diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenGeneral.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenGeneral.h new file mode 100644 index 0000000000000000000000000000000000000000..9b787a2163e87c903ce0bd034b424eb1773c644d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenGeneral.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenOpList.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenOpList.h new file mode 100644 index 0000000000000000000000000000000000000000..1419376a9017db4c7ca788816bd0c9a6d65a82fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATenOpList.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace c10 { +struct OperatorName; +} + +namespace at { + +// check if an op is a custom op (i.e. did not come from native_functions.yaml) +TORCH_API bool is_custom_op(const c10::OperatorName& opName); +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_fwd.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_fwd.h new file mode 100644 index 0000000000000000000000000000000000000000..a66523ab183923e1c15862268b1b0260ccfb9e14 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_fwd.h @@ -0,0 +1,46 @@ +#pragma once +#include + +// Forward declarations of core ATen types used in dispatch functions +namespace c10 { + +template +class List; +template +class IListRef; +class Stream; +class Scalar; +class SymInt; +class SymIntList; +struct Storage; +struct TensorOptions; +template +class ArrayRef; +template +class OptionalArrayRef; + +} // namespace c10 + +namespace at { + +class Tensor; +class OptionalTensorRef; +struct Dimname; +struct Generator; +using TensorList = c10::ArrayRef; +using ITensorListRef = c10::IListRef; +using IOptTensorListRef = c10::IListRef; +using DimnameList = c10::ArrayRef; 
+using IntArrayRef = c10::ArrayRef; +using OptionalIntArrayRef = c10::OptionalArrayRef; +using OptionalSymIntArrayRef = c10::OptionalArrayRef; + +using c10::Stream; +using c10::Storage; +using c10::QScheme; +using c10::Scalar; +using c10::SymInt; +using c10::SymIntList; +using c10::TensorOptions; + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_pch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_pch.h new file mode 100644 index 0000000000000000000000000000000000000000..9a097845d003b48a84f1fc31d7f496e51f09d515 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ATen_pch.h @@ -0,0 +1,161 @@ +// This global header must not depend on native_functions.yaml or +// incremental builds will be next to useless +#pragma push_macro("TORCH_ASSERT_NO_OPERATORS") +#define TORCH_ASSERT_NO_OPERATORS + +#include + +// This list of headers was generated using a script that finds +// high-impact headers and then manually tweaked to remove OS specific +// or duplicate headers (e.g. and ) and to remove +// "impl" headers (e.g BFloat16-inl.h or complex_math.h in c10). + +// To generate the initial list: +// 1. Build pytorch from scratch with all build caching disabled +// 2. Generate a build trace with ninjatracing (https://github.com/nico/ninjatracing) +// $ ninjatracing /path/to/pytorch/build/.ninja_log > trace_all.json +// 3. Run pch_gen.py from https://github.com/peterbell10/build_analysis/ +// $ python pch_gen.py --threshold .80 --target torch_cpu --build_dir /path/to/pytorch/build --trace trace_all.json +// Where the threshold can be tweaked until c10 and some of ATen +// core are included but TORCH_ASSERT_NO_OPERATORS still passes. 
+ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#pragma pop_macro("TORCH_ASSERT_NO_OPERATORS") diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Array.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Array.h new file mode 100644 index 0000000000000000000000000000000000000000..5f3f7bc9d48747c656ce3c9c60fccce86afafc6c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Array.h @@ -0,0 +1,48 @@ +#pragma once + +// A fixed-size array type usable from both host and +// device code. 
+ +#include +#include + +namespace at::detail { + +template +struct Array { + // NOLINTNEXTLINE(*c-array*) + T data[size_]; + + C10_HOST_DEVICE T operator[](int i) const { + return data[i]; + } + C10_HOST_DEVICE T& operator[](int i) { + return data[i]; + } +#if defined(USE_ROCM) + C10_HOST_DEVICE Array() = default; + C10_HOST_DEVICE Array(const Array&) = default; + C10_HOST_DEVICE Array& operator=(const Array&) = default; + C10_HOST_DEVICE Array(Array&&) = default; + C10_HOST_DEVICE Array& operator=(Array&&) = default; + C10_HOST_DEVICE ~Array() = default; +#else + Array() = default; + Array(const Array&) = default; + Array& operator=(const Array&) = default; + Array(Array&&) noexcept = default; + Array& operator=(Array&&) noexcept = default; + ~Array() = default; +#endif + static constexpr int size() { + return size_; + } + // Fill the array with x. + C10_HOST_DEVICE Array(T x) { + for (int i = 0; i < size_; i++) { + data[i] = x; + } + } +}; + +} // namespace at::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Backtrace.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Backtrace.h new file mode 100644 index 0000000000000000000000000000000000000000..ac728968750297227c1be4aa3e444557c1899b03 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Backtrace.h @@ -0,0 +1,2 @@ +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CachingHostAllocator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CachingHostAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..5049018d731e1e2a54f9dde875f17678eb156d9d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CachingHostAllocator.h @@ -0,0 +1,737 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") +namespace at { + +using 
c10::CachingAllocator::Stat; +using c10::CachingAllocator::DurationStat; + +/** + * HostBlock is typically a fundamental memory block used in pinned memory. It + * is likely related to Event and Stream of device runtime. It is probably a + * base struct or interface that can be inherited and extended by each backend. + */ +template +struct HostBlock { + // constructor for search key + HostBlock(size_t size) : size_(size) {} + + HostBlock(size_t size, void* ptr) : size_(size), ptr_(ptr) {} + + std::mutex mutex_; + size_t size_{0}; // block size in bytes + void* ptr_{nullptr}; // memory address + bool allocated_{false}; // in-use flag + size_t event_count_{0}; // number of related events + ska::flat_hash_set streams_; // streams on which the block was used +}; + +template +struct alignas(64) FreeBlockList { + std::mutex mutex_; + std::deque list_; +}; + +namespace { + // Max cached block sizes: (1 << MAX_SIZE_INDEX) bytes + // NOLINTNEXTLINE(misc-definitions-in-headers) + constexpr size_t MAX_SIZE_INDEX = 64; +} + +// Struct containing memory allocator summary statistics for host. +struct TORCH_API HostStats { + // COUNT: allocations requested by client code. Note that active + // count can be extracted by looking at current allocations + Stat allocation; + // COUNT: number of allocated segments from host memory allocation. + Stat segment; + + // SUM: bytes allocated by this memory alocator. Note that active bytes + // can be extracted by looking at current bytes allocated + Stat allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + Stat reserved_bytes; + + // SUM: time spent in cudaHostAlloc/cudaHostRegister in microseconds + DurationStat host_alloc_time; + + // SUM: time spent in cudaHostFree/cudaHostUnregister in microseconds + DurationStat host_free_time; + + // COUNT: number of times cudaHostAlloc/cudaHostRegister was called because + // the request could not be satisfied from existing free blocks. 
+ int64_t num_host_alloc = 0; // This is derived from segment or timing + + // COUNT: number of times cudaHostFree/cudaHostUnregister was called. + int64_t num_host_free = 0; // This is derived from segment or timing +}; + +// Struct containing memory allocator summary statistics for host, as they +// are staged for reporting. This is a temporary struct that is used to +// avoid locking the allocator while collecting stats. +struct alignas(64) HostStatsStaged { + std::mutex timing_mutex_; + // COUNT: allocations requested by client code resulting in a new segment/block allocation + // LOCK: access to this stat is protected by the allocator's blocks_mutex_ + Stat allocation; + // SUM: bytes within active memory blocks, including blocks that are + // currently in the free list. + // LOCK: access to this stat is protected by the allocator's blocks_mutex_ + Stat allocated_bytes; + // COUNT: number of allocations per bucket + // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_ + std::vector allocation_bucket_stats = std::vector(MAX_SIZE_INDEX); + // SUM: bytes of allocation per bucket + // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_ + std::vector allocated_bytes_bucket_stats = std::vector(MAX_SIZE_INDEX); + // SUM: time spent in cudaHostAlloc/cudaHostRegister + // LOCK: access to this stat is protected by the timing_mutex_ + DurationStat host_alloc_time; + // SUM: time spent in cudaHostFree/cudaHostUnregister + // LOCK: access to this stat is protected by the timing_mutex_ + DurationStat host_free_time; +}; + +/** + * Note [HostAllocator design] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * We have three key data structures - the free list which stores blocks that + * are not currently used, the block list which stores all blocks that have been + * allocated, and the event queue which stores runtime events and their + * corresponding blocks. 
+ * + * Each of these are protected by a separate mutex. The key design principles + * are to 1) only hold each mutex for the minimal amount of time possible, 2) + * never do any possible expensive operations (such as CUDA runtime API calls) + * while holding the lock. + * + * There are four public methods: allocate, free, record_event and empty_cache. + * 1) In the allocate path, we first check to see if we can service our + * request from this free list, and otherwise we create a new block with + * allocate_host_memory. + * 2) In the free path, we insert events (if required) into the event queue, + * and if possible insert our block back into the free list. In allocate, we + * first eagerly query events until we find one that is not ready, and insert + * the corresponding block onto the free list if all the events recorded for a + * block are ready. + * 3) In the record_event path, we simply insert the given stream into the set + * of streams tracked by the specified block. This set of streams is then + * consumed in the free path. + * 4) In the empty_cache path, we flush any available blocks into the free + * list. Remove all element of free list, then remove them from block list and + * release the associated pinned memory allocation via free_block. + * + * We generalize the caching host allocator into two parts: interface and + * implementation. For any new backend looking to integrate with host allocator + * and reuse caching mechanism, these two parts are necessary to be specialized. + * + * For the implementation, we provide a CachingHostAllocatorImpl struct + * to abstract the caching mechanism. Any backend needs to provide a customized + * implementation by specializing its own public functions and the related + * runtime functions. Its template parameter S represents runtime Stream, E + * denotes runtime Event, B indicates the fundamental memory block. + * + * For the interface, we provide a CachingHostAllocatorInterface struct as an + * interface. 
Any backend needs to derive its own host allocator from this + * interface. Its template parameter T refers to an implementation that + * inherited from CachingHostAllocatorImpl. + * + * So this design can share the caching mechanism across each backend, and + * provide flexibility to each backend. A backend can choose to follow this + * implementation or reuse them by extending and overriding them as necessary. + * Taking CUDA as an example, it specializes runtime related functions to reuse + * the caching mechanism. Additionally, it extends the allocator's functionality + * by adding the allocWithCudaHostRegister function to support page-locking the + * memory range used by CUDA. Of course, you can also refer to + * XPUCachingHostAllocator, which is a host caching allocator supported on XPU + * backend, to implement a basic host caching allocator. + * + * Some of the invariants here are less strict than they could be - for example, + * we do not enforce that free(Block* block) => block->event_count == 0. This is + * for compatibility reasons, and we can explore enforcing these in subsequent + * versions. + * + * Note that this caching host allocator does not split larger allocations into + * smaller blocks, unlike the caching device allocator. + * + * In order to gather statistics about caching host allocator while minimally + * impacting performance, we use a HostStatsStaged struct to stage the stats + * before reporting them. This is done to avoid adding new locks to the allocator. + * Collecting stats is carefully done under existing locks, and then the staged + * stats are converted to the final stats when getStats is called. At that time + * we hold the same locks as empty_cache, to ensure the fidelity of the stats. 
+ */ + +template < + typename S, + typename E, + typename B = HostBlock> +struct CachingHostAllocatorImpl { + virtual ~CachingHostAllocatorImpl() { + active_ = false; + if (pinned_use_background_threads()) { + getBackgroundThreadPool()->waitWorkComplete(); + } + } + + public: + // return data_ptr and block pair. + virtual std::pair allocate(size_t size) { + if (size == 0) { + return {nullptr, nullptr}; + } + + // If we are using background threads, we can process events in the + // background. + if (!pinned_use_background_threads()) { + process_events(); + } + + // Round up the allocation to the nearest power of two to improve reuse. + // These power of two sizes are also used to index into the free list. + size_t roundSize = c10::llvm::PowerOf2Ceil(size); + + // First, try to allocate from the free list + auto* block = get_free_block(roundSize); + if (block) { + return {block->ptr_, reinterpret_cast(block)}; + } + + // Check in the recently freed blocks with pending events to see if we + // can reuse them. Call get_free_block again after processing events + if (pinned_use_background_threads()) { + process_events_for_specific_size(roundSize); + block = get_free_block(roundSize); + if (block) { + return {block->ptr_, reinterpret_cast(block)}; + } + + // Launch the background thread and process events in a loop. + static bool background_thread_flag [[maybe_unused]] = [this] { + getBackgroundThreadPool()->run([&]() { + while (active_) { + process_events(); + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + }); + return true; + }(); + } + + // Slow path: if we can't allocate from the cached free list, we need + // to create a new block. + void* ptr = nullptr; + allocate_host_memory(roundSize, &ptr); + + // Then, create a new block. 
+ block = new B(roundSize, ptr); + block->allocated_ = true; + + add_allocated_block(block); + return {block->ptr_, reinterpret_cast(block)}; + } + + virtual void free(void* ctx) { + if (!ctx) { + return; + } + + // Note: we can assume that free is correctly paired with alloc, and thus we + // do not need to look up the ctx in blocks_. + auto* block = reinterpret_cast(ctx); + + std::optional> events; + { + std::lock_guard g(block->mutex_); + block->allocated_ = false; + if (block->streams_.empty()) { + TORCH_INTERNAL_ASSERT(block->event_count_ == 0); + } else { + events = std::vector(); + events->reserve(block->streams_.size()); + for (auto stream : block->streams_) { + record_stream(events, stream); + } + block->event_count_ += events->size(); + block->streams_.clear(); + } + } + + if (!events) { + auto index = size_index(block->size_); + std::lock_guard g(free_list_[index].mutex_); + free_list_[index].list_.push_back(block); + stats_.allocation_bucket_stats[index].decrease(1); + stats_.allocated_bytes_bucket_stats[index].decrease(block->size_); + } else { + // restore these events that record by used streams. + std::lock_guard g(events_mutex_); + for (auto&& event : *events) { + events_.emplace_front(std::move(event), block); + } + } + } + + virtual bool record_event(void* ptr, void* ctx, c10::Stream s) { + S stream = S(s); + auto* block = reinterpret_cast(ctx); + + // Note: we need to check if the passed-in `ctx` is valid. This is because + // `record_event` (via `CachingHostAllocator_recordEvent`) can be invoked on + // an arbitrary tensor, and is not guaranteed to correspond to a pinned + // memory allocation. Therefore, we need to check that `ctx` is valid before + // proceeding. + { + std::lock_guard g(blocks_mutex_); + if (blocks_.find(block) != blocks_.end()) { + // Now we know this object is safe to access. 
+ std::lock_guard gb(block->mutex_); + TORCH_INTERNAL_ASSERT(block->allocated_); + block->streams_.insert(stream); + return true; + } + auto it = ptr_to_block_.find(ptr); + if (it != ptr_to_block_.end()) { + block = it->second; + std::lock_guard g(block->mutex_); + TORCH_INTERNAL_ASSERT(block->allocated_); + block->streams_.insert(stream); + return true; + } + } + + return false; + } + + virtual void empty_cache() { + // Flush any available blocks into the free_list. + process_events(); + + // Remove all elements from the free list, remove them from the blocks + // list, and free the associated pinned memory allocation. This requires + // concurrently holding both the free list mutexes and the blocks mutex, and + // is the only function that concurrently holds multiple mutexes. + for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + std::vector blocks_to_remove(free_list_[i].list_.begin(), free_list_[i].list_.end()); + free_list_[i].list_.clear(); + + for (auto* block : blocks_to_remove) { + blocks_.erase(block); + ptr_to_block_.erase(block->ptr_); + stats_.allocation.decrease(1); + stats_.allocated_bytes.decrease(block->size_); + free_block(block); + delete block; + } + } + } + + inline size_t size_index(size_t size) { + return c10::llvm::Log2_64_Ceil(size); + } + + virtual bool pinned_use_background_threads() { + return false; + } + + virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data"); + } + + HostStats getStats() { + HostStats stats; + + // To keep getStats lightweight we do *not* flush any available blocks + // into the free_list. This may skew the stats a bit. 
+ + auto add_bucket_stats = [](Stat& accumulator, const Stat& other) { + accumulator.allocated += other.allocated; + accumulator.current += other.current; + accumulator.freed += other.freed; + // Since peaks are measured per bucket independently, we add them up + // to estimate the total peak. This is not strictly correct, but it is + // the best approximation we can get after the fact. + accumulator.peak += other.peak; + }; + + // Accurate reading of memory stats requires concurrently holding both the + // free list mutexes and the blocks mutex. Previously, this was only done in + // empty_cache function. + for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + // We collect the slow-path stats only once, since they are not collected + // per bucket (we pick index 0 arbitrarily). These are also all the host + // allocations, not taking into account caching and free lists. + if (i == 0) { + stats.segment = stats_.allocation; + stats.reserved_bytes = stats_.allocated_bytes; + stats.num_host_alloc = stats.segment.allocated; + stats.num_host_free = stats.segment.freed; + } + + // Bucket stats need to be merged with the slow-path stats. We do this in + // a best effort manner, since we can't really replay the cached events per bucket. + add_bucket_stats(stats.allocation, stats_.allocation_bucket_stats[i]); + add_bucket_stats(stats.allocated_bytes, stats_.allocated_bytes_bucket_stats[i]); + } + + // Get the timing stats + { + std::lock_guard g(stats_.timing_mutex_); + + stats.host_alloc_time = stats_.host_alloc_time; + stats.host_free_time = stats_.host_free_time; + } + + return stats; + } + + void resetAccumulatedStats() { + // Reseting accumulated memory stats requires concurrently holding both the + // free list mutexes and the blocks mutex. Previously, this was only done in + // empty_cache function. 
+ for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + if (i == 0) { + stats_.allocation.reset_accumulated(); + stats_.allocated_bytes.reset_accumulated(); + } + stats_.allocation_bucket_stats[i].reset_accumulated(); + stats_.allocated_bytes_bucket_stats[i].reset_accumulated(); + } + + // Also reset timing stats + { + std::lock_guard g(stats_.timing_mutex_); + stats_.host_alloc_time.reset_accumulated(); + stats_.host_free_time.reset_accumulated(); + } + } + + void resetPeakStats() { + // Reseting peak memory stats requires concurrently holding both the + // free list mutexes and the blocks mutex. Previously, this was only done in + // empty_cache function. + for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + if (i == 0) { + stats_.allocation.reset_peak(); + stats_.allocated_bytes.reset_peak(); + } + stats_.allocation_bucket_stats[i].reset_peak(); + stats_.allocated_bytes_bucket_stats[i].reset_peak(); + } + + // Also reset timing stats + { + std::lock_guard g(stats_.timing_mutex_); + stats_.host_alloc_time.reset_peak(); + stats_.host_free_time.reset_peak(); + } + } + + private: + virtual void add_allocated_block(B* block) { + std::lock_guard g(blocks_mutex_); + blocks_.insert(block); + stats_.allocation.increase(1); + stats_.allocated_bytes.increase(block->size_); + ptr_to_block_.insert({block->ptr_, block}); + + // Unfortunately, we have to, on the slow path, quickly + // lock the bucket to record the allocation. This should + // be a rare event once the cache is warmed up. 
+ auto size = block->size_; + auto index = size_index(size); + { + std::lock_guard g(free_list_[index].mutex_); + stats_.allocation_bucket_stats[index].increase(1); + stats_.allocated_bytes_bucket_stats[index].increase(size); + } + } + + virtual B* get_free_block(size_t size) { + auto index = size_index(size); + std::lock_guard g(free_list_[index].mutex_); + if (!free_list_[index].list_.empty()) { + B* block = free_list_[index].list_.back(); + free_list_[index].list_.pop_back(); + block->allocated_ = true; + stats_.allocation_bucket_stats[index].increase(1); + stats_.allocated_bytes_bucket_stats[index].increase(size); + return block; + } + return nullptr; + } + + virtual void process_events() { + // process all events until the last unready event, not for specific size. + process_events_for_specific_size(-1); + } + + // If size is -1, process all events from backwards until the last unready + // event. Otherwise, process events for a specific size and on first ready block + // is found, add it to the free list and return. + virtual void process_events_for_specific_size(int64_t size) { + size_t event_count = 0; + size_t max_events = 0; + { + std::lock_guard g(events_mutex_); + max_events = events_.size(); + } + + while (true) { + // Avoid calling cudaEventDestroy while holding a mutex, so move + // intermediate events out of the lock into this object. + // process the last event + std::optional> processed; + { + std::lock_guard g(events_mutex_); + if (!events_.empty()) { + processed = std::move(events_.back()); + events_.pop_back(); + } + } + + if (!processed) { + return; + } + + if (size != -1) { + if (event_count++ > max_events) { + { + std::lock_guard g(events_mutex_); + events_.push_front(std::move(*processed)); + } + return; + } + if (size != (int64_t)processed->second->size_) { + // if we are processing a specific size, and the size of the block + // doesn't match, we can't use it. 
+ { + std::lock_guard g(events_mutex_); + events_.push_front(std::move(*processed)); + } + continue; + } + } + + // otherwise, query the event + { + // now, see if we can handle this element + auto& event = processed->first; + if (!query_event(event)) { + // push the event onto the back if it's not ready. + { + std::lock_guard g(events_mutex_); + if (size == -1) { + events_.push_back(std::move(*processed)); + return; + } else { + events_.push_front(std::move(*processed)); + continue; + } + } + } + } + + // Process the events. + TORCH_INTERNAL_ASSERT(processed); + auto* block = processed->second; + bool available = false; + { + std::lock_guard g(block->mutex_); + TORCH_INTERNAL_ASSERT(!block->allocated_) + block->event_count_--; + if (block->event_count_ == 0) { + available = true; + } + } + + if (available) { + auto index = size_index(block->size_); + std::lock_guard g(free_list_[index].mutex_); + free_list_[index].list_.push_back(block); + stats_.allocation_bucket_stats[index].decrease(1); + stats_.allocated_bytes_bucket_stats[index].decrease(size); + if (size != -1) { + return; + } + } + } + } + + TaskThreadPool* getBackgroundThreadPool() { + static TaskThreadPool* pool = new TaskThreadPool(1); + return pool; + } + + /* These following functions are runtime-related. */ + + // Allocate page-locked memory on the host. + virtual void allocate_host_memory(size_t size, void** ptr) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "Not implemented for allocate_host_memory"); + } + + // Free block and release the pointer contained in block. + virtual void free_block(B* block) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block"); + } + + // Record an event on stream and store event into events. + virtual void record_stream(std::optional>& events, S stream) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream"); + } + + // Query event if it is completed. 
+ virtual bool query_event(E& event) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); + } + + alignas(64) std::mutex blocks_mutex_; + ska::flat_hash_set blocks_; // block list + ska::flat_hash_map ptr_to_block_; + + // We keep free list as a vector of free lists, one for each power of two + // size. This allows us to quickly find a free block of the right size. + // We use deque to store per size free list and guard the list with its own + // mutex. + alignas(64) std::vector> free_list_ = + std::vector>(MAX_SIZE_INDEX); + + alignas(64) std::mutex events_mutex_; + std::deque> events_; // event queue paired with block + + // Indicates whether the object is active. + // Set to false in the destructor to signal background threads to stop. + std::atomic active_{true}; +protected: + alignas(64) HostStatsStaged stats_; +}; + +struct TORCH_API HostAllocator : public at::Allocator { + // Associates the pinned memory allocation with a stream to track + // dependencies. This ensures the memory won't be reused until the stream's + // operations complete + virtual bool record_event(void* ptr, void* ctx, c10::Stream stream) = 0; + + // Frees all cached pinned memory and returns it to the system, clearing the + // allocator's internal cache + virtual void empty_cache() = 0; + + // Returns comprehensive statistics about the allocator's memory usage, + // allocation patterns, and timing metrics + virtual HostStats get_stats() = 0; + + // Resets the cumulative allocation statistics + virtual void reset_accumulated_stats() = 0; + + // Resets the peak memory usage metrics + virtual void reset_peak_stats() = 0; +}; + +template +struct CachingHostAllocatorInterface : public HostAllocator { + CachingHostAllocatorInterface() : impl_(std::make_unique()) {} + + at::DataPtr allocate(size_t size) override { + auto ptr_and_ctx = impl_->allocate(size); + return { + ptr_and_ctx.first, + ptr_and_ctx.second, + deleteFunc, // Use the template parameter deleter function + 
at::DeviceType::CPU}; + } + + void free(void* ctx) { + impl_->free(ctx); + } + + bool record_event(void* ptr, void* ctx, c10::Stream stream) override { + return impl_->record_event(ptr, ctx, stream); + } + + void empty_cache() override { + impl_->empty_cache(); + } + + void copy_data(void* dest, const void* src, std::size_t count) + const override { + impl_->copy_data(dest, src, count); + } + + HostStats get_stats() override { + return impl_->getStats(); + } + + void reset_accumulated_stats() override { + impl_->resetAccumulatedStats(); + } + + void reset_peak_stats() override { + impl_->resetPeakStats(); + } + + std::unique_ptr impl_; +}; + +#define DECLARE_HOST_ALLOCATOR(name, impl, deleter, instance) \ + void deleter(void* ptr); \ + struct name final \ + : public at::CachingHostAllocatorInterface {}; \ + static name instance; \ + void deleter(void* ptr) { \ + instance.free(ptr); \ + } + +/** + * Set the host allocator for DeviceType `device_type`. This allocator manages + * pinned memory on the host that can be accessed efficiently by the specified + * device type. Note that this function is not thread-safe. 
+ */ +TORCH_API void setHostAllocator( + at::DeviceType device_type, + at::HostAllocator* allocator, + uint8_t priority = 0); + +TORCH_API at::HostAllocator* getHostAllocator(at::DeviceType device_type); + +template +struct HostAllocatorRegistry { + explicit HostAllocatorRegistry(HostAllocator* allocator) { + at::setHostAllocator(device_type, allocator); + } +}; + +#define REGISTER_HOST_ALLOCATOR(device_type, allocator) \ + namespace { \ + static at::HostAllocatorRegistry \ + g_host_allocator_registry_instance(allocator); \ + } + +} // namespace at +C10_DIAGNOSTIC_POP() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CheckMemoryFormat.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CheckMemoryFormat.h new file mode 100644 index 0000000000000000000000000000000000000000..860eec8e7a1f9399467e8db0ee9356468cd0f5b3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/CheckMemoryFormat.h @@ -0,0 +1,24 @@ +#include + +namespace c10::impl { + +inline std::optional +check_tensor_options_and_extract_memory_format( + const TensorOptions& options, + std::optional memory_format) { + TORCH_CHECK( + options.requires_grad_opt() != true, + "Operators taking TensorOptions cannot take a TensorOptions with " + "options.requires_grad set as true. 
This isn't implemented yet."); + TORCH_CHECK( + !(options.has_memory_format() && memory_format.has_value()), + "Cannot set memory_format both in TensorOptions and explicit argument; please delete " + "the redundant setter."); + if (memory_format.has_value()) { + return memory_format; + } else { + return options.memory_format_opt(); + } +} + +} // namespace impl namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h new file mode 100644 index 0000000000000000000000000000000000000000..a945761e8ff97223d47456f1514bb18a3f4ac8bc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h @@ -0,0 +1,139 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + + +namespace at { + +class Tensor; + +// This class specifies a Backend and a ScalarType. Currently, it primarily +// serves as a replacement return value for Tensor::type(). Previously, +// Tensor::type() returned Type&, but we are changing Type to not be +// dtype-specific. 
+class TORCH_API DeprecatedTypeProperties { + public: + DeprecatedTypeProperties(Backend backend, ScalarType scalar_type) + : backend_(backend), scalar_type_(scalar_type) {} + + Backend backend() const { + return backend_; + } + + Layout layout() const { + return layout_from_backend(backend_); + } + + bool is_sparse() const { + return layout_from_backend(backend()) == kSparse; + } + + bool is_sparse_csr() const { + return layout_from_backend(backend()) == kSparseCsr; + } + + c10::DeviceType device_type() const { + return backendToDeviceType(backend_); + } + + bool is_cuda() const { + return backendToDeviceType(backend_) == kCUDA; + } + + ScalarType scalarType() const { + return scalar_type_; + } + + caffe2::TypeMeta typeMeta() const { + return scalarTypeToTypeMeta(scalar_type_); + } + + bool operator==(const DeprecatedTypeProperties& other) const { + return backend_ == other.backend() && scalar_type_ == other.scalarType(); + } + + bool operator!=(const DeprecatedTypeProperties& other) const { + return !(*this == other); + } + + std::string toString() const { + std::string base_str; + if (backend_ == Backend::Undefined || scalar_type_ == ScalarType::Undefined) { + base_str = "UndefinedType"; + } else { + base_str = std::string(at::toString(backend_)) + at::toString(scalar_type_) + "Type"; + } + return base_str; + } + + DeprecatedTypeProperties & toBackend(Backend b) const { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + b, scalar_type_); + } + + DeprecatedTypeProperties & toScalarType(ScalarType s) const { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + backend_, s); + } + + DeprecatedTypeProperties & cpu() const { + return toBackend(Backend::CPU); + } + + DeprecatedTypeProperties & cuda() const { + return toBackend(Backend::CUDA); + } + + DeprecatedTypeProperties & hip() const { + return toBackend(Backend::HIP); + } + + DeprecatedTypeProperties & privateUser1() const { + return 
toBackend(Backend::PrivateUse1); + } + + /// Constructs the `TensorOptions` from a type and a `device_index`. + TensorOptions options(int16_t device_index = -1) const { + return TensorOptions().dtype(typeMeta()) + .device(device_type(), static_cast(device_index)) + .layout(layout()); + } + + /// Constructs the `TensorOptions` from a type and a Device. Asserts that + /// the device type matches the device type of the type. + TensorOptions options(std::optional device_opt) const { + if (!device_opt.has_value()) { + return options(-1); + } else { + Device device = device_opt.value(); + AT_ASSERT(device.type() == device_type()); + return options(device.index()); + } + } + + operator TensorOptions() const { + return options(); + } + + int64_t id() const { + return static_cast(backend()) * + static_cast(ScalarType::NumOptions) + + static_cast(scalarType()); + } + + Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const; + Storage unsafeStorageFromTH(void * th_pointer, bool retain) const; + Tensor copy(const Tensor & src, bool non_blocking=false, std::optional to_device={}) const; + + private: + Backend backend_; + ScalarType scalar_type_; +}; + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h new file mode 100644 index 0000000000000000000000000000000000000000..78f0cfdfa553050d4437e2873f2ff906627a54c4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h @@ -0,0 +1,33 @@ +#pragma once + +// In order to preserve bc, we make DeprecatedTypeProperties instances unique +// just like they are for Type. 
+ +#include +#include +#include + +namespace at { + +class DeprecatedTypeProperties; + +struct TORCH_API DeprecatedTypePropertiesDeleter { + void operator()(DeprecatedTypeProperties * ptr); +}; + +class TORCH_API DeprecatedTypePropertiesRegistry { + public: + DeprecatedTypePropertiesRegistry(); + + DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) const; + +private: + // NOLINTNEXTLINE(*c-array*) + std::unique_ptr registry + [static_cast(Backend::NumOptions)] + [static_cast(ScalarType::NumOptions)]; +}; + +TORCH_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry(); + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict.h new file mode 100644 index 0000000000000000000000000000000000000000..96cd25fec10bff630686e9917edbb2d33deea41b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict.h @@ -0,0 +1,396 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +struct IValue; +template class Dict; +struct Type; + +namespace impl { + +using valid_dict_key_types = guts::typelist::typelist< + int64_t, + std::string, + double, + c10::complex, + bool, + at::Tensor +>; +} + +namespace detail { + +struct DictKeyHash { + size_t operator()(const IValue& ivalue) const; +}; + +struct DictKeyEqualTo { + bool operator()(const IValue& lhs, const IValue& rhs) const; +}; + +struct DictImpl final : public c10::intrusive_ptr_target { + using dict_map_type = ska_ordered::order_preserving_flat_hash_map; + struct DictElementTypes final { + TypePtr keyType; + TypePtr valueType; + }; + + explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_) + : dict(std::move(dict_)) + , elementTypes(std::move(elementTypes_)) {} + dict_map_type dict; + + DictElementTypes elementTypes; + + intrusive_ptr copy() const; + friend 
TORCH_API bool operator==(const DictImpl& lhs, const DictImpl& rhs); +}; + +} + +namespace impl { +template class DictIterator; + +/** + * A reference to an entry in the Dict. + * Use the `key()` and `value()` methods to read the element. + */ +template +class DictEntryRef final { +public: + explicit DictEntryRef(Iterator iterator) + : iterator_(std::move(iterator)) {} + + decltype(auto) key() const { + return iterator_->first.template to(); + } + + decltype(auto) value() const { + return iterator_->second.template to(); + } + + template + void setValue(Value_&& value) const { + static_assert(std::is_constructible_v, "Wrong type for the value argument of setValue()"); + iterator_->second = Value(std::forward(value)); + } + ~DictEntryRef() = default; + +private: + // allow copying and moving, but only our friends (i.e. the Dict class) can do + // it. Copying/moving this reference wrapper would be too ambiguous to allow it + // in the public API. + DictEntryRef(const DictEntryRef&) = default; + DictEntryRef& operator=(const DictEntryRef&) = default; + DictEntryRef(DictEntryRef&&) noexcept = default; + DictEntryRef& operator=(DictEntryRef&& rhs) & noexcept = default; + + Iterator iterator_; + friend class DictIterator; + friend class Dict; +}; + +// this wraps map_type::iterator to make sure user code can't rely +// on it being the type of the underlying map. 
+template +class DictIterator final { +public: + // C++17 friendly std::iterator implementation + using iterator_category = std::forward_iterator_tag; + using value_type = DictEntryRef; + using difference_type = std::ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + + explicit DictIterator() = default; + ~DictIterator() = default; + + DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {} + DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {} + DictIterator& operator=(const DictIterator& rhs) = default; + DictIterator& operator=(DictIterator&& rhs) noexcept { + entryRef_ = std::move(rhs.entryRef_); + return *this; + } + + DictIterator& operator++() { + ++entryRef_.iterator_; + return *this; + } + + DictIterator operator++(int) { + DictIterator copy(*this); + ++*this; + return copy; + } + + const DictEntryRef& operator*() const { + return entryRef_; + } + + const DictEntryRef* operator->() const { + return &entryRef_; + } + + friend difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_; + } + +private: + explicit DictIterator(Iterator iterator): entryRef_(std::move(iterator)) {} + + const Iterator& get_iterator_() const { + return entryRef_.iterator_; + } + + friend bool operator==(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() == rhs.get_iterator_(); + } + + friend bool operator!=(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() != rhs.get_iterator_(); + } + + friend bool operator<(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() < rhs.get_iterator_(); + } + + friend bool operator<=(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() <= rhs.get_iterator_(); + } + + friend bool operator>(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() > rhs.get_iterator_(); + } 
+ + friend bool operator>=(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() >= rhs.get_iterator_(); + } + + DictEntryRef entryRef_; + + friend class DictIterator; + friend class Dict; +}; + +template Dict toTypedDict(Dict dict); +template Dict toGenericDict(Dict dict); +} + +/** + * An object of this class stores a map from Key to Value. + * + * This is a pointer type. After a copy, both Dicts + * will share the same storage: + * + * > Dict a; + * > Dict b = a; + * > b.insert(3, "three"); + * > ASSERT("three" == a.at(3)); + * + * We use this class in the PyTorch kernel API because that + * allows us to do optimizations and switch out the underlying + * map implementation without breaking backwards compatibility + * for the kernel API. + */ +template +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class Dict final { +private: + static_assert((std::is_same_v && std::is_same_v) || guts::typelist::contains::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string."); + + // impl_ stores the underlying map as a ska_ordered::order_preserving_flat_hash_map. + // We intentionally don't offer conversion from/to + // order_preserving_flat_hash_map, return references to it or something like that, + // because such operations would get expensive if we switch out + // the actual map implementation. + // This is an intrusive_ptr because Dict is a pointer type. + // Invariant: This will never be a nullptr, there will always be a valid + // DictImpl. + c10::intrusive_ptr impl_; + + explicit Dict(c10::intrusive_ptr&& impl); + friend struct IValue; + template friend Dict impl::toTypedDict(Dict); + template friend Dict impl::toGenericDict(Dict); + +public: + using key_type = Key; + using mapped_type = Value; + using size_type = typename detail::DictImpl::dict_map_type::size_type; + using iterator = impl::DictIterator; + + /** + * Creates an empty dict. 
+ */ + explicit Dict(); + + /** + * Create a generic dict with runtime type information. + * This only works for c10::impl::GenericDict and is not part of the public API + * but only supposed to be used internally by PyTorch. + */ + explicit Dict(TypePtr keyType, TypePtr valueType); + + ~Dict() = default; + + Dict(const Dict&) = default; + Dict& operator=(const Dict&) = default; + + /** + * Create a new Dict pointing to a deep copy of the same data. + * The Dict returned is a new dict with separate storage. + * Changes in it are not reflected in the original dict or vice versa. + */ + Dict copy() const; + + /** + * Returns an iterator to the first element of the container. + * If the container is empty, the returned iterator will be equal to end(). + */ + iterator begin() const; + + /** + * Returns an iterator to the element following the last element of the container. + * This element acts as a placeholder; attempting to access it results in undefined behavior. + */ + iterator end() const; + + /** + * Checks if the container has no elements. + */ + bool empty() const; + + /** + * Returns the number of elements in the container. + */ + size_type size() const; + + /** + * Erases all elements from the container. After this call, size() returns zero. + * Invalidates any references, pointers, or iterators referring to contained elements. May also invalidate past-the-end iterators. + */ + void clear() const; + + /** + * Inserts element(s) into the container, if the container doesn't already contain an element with an equivalent key. + * May invalidate any references, pointers, or iterators referring to contained elements. + * + * @return A pair consisting of an iterator to the inserted element (or to the element that prevented the insertion) and a bool denoting whether the insertion took place. + */ + template + std::pair insert(Key_&& key, Value_&& value) const; + + /** + * If an element with the given key already exists, it is overwritten with the given value. 
+ * Otherwise, a new element with the given key and value are inserted. + * May invalidate any references, pointers, or iterators referring to contained elements. + * + * @return The bool component is true if the insertion took place and false if the assignment took place. The iterator component is pointing at the element that was inserted or updated. + */ + template + std::pair insert_or_assign(Key_&& key, Value_&& value) const; + + /** + * Removes the element pointed to by iter. + * May invalidate any references, pointers, or iterators referring to contained elements. + * The iterator iter must be valid and dereferenceable. Thus the end() iterator (which is valid, but is not dereferenceable) cannot be used as a value for iter. + */ + void erase(iterator iter) const; + + /** + * Removes the element with the given key, if it exists. + * May invalidate any references, pointers, or iterators referring to contained elements. + * + * @return The number of elements removed. This is either '1' if an element with the key existed, or '0' if it didn't. + */ + [[nodiscard]] size_t erase(const Key& key) const; + + /** + * Returns the mapped value of the element with key equivalent to key. + * If no such element exists, an exception of type std::out_of_range is thrown. + */ + Value at(const Key& key) const; + + /** + * Finds an element with key equivalent to key. + * + * @return Iterator to an element with key equivalent to key. + * If no such element is found, past-the-end (see end()) iterator is returned. + */ + iterator find(const Key& key) const; + + /** + * Checks if there is an element with key equivalent to key in the container. + * + * @return true if there is such an element, otherwise false. + */ + bool contains(const Key& key) const; + + /** + * Increase the capacity so that at least count elements can be stored without + * having to reallocate or rehash. + */ + void reserve(size_type count) const; + + /** + * Value equality comparison. 
This function implements Python-like semantics for + * equality: two dicts with the same identity (e.g. same pointer) trivially + * compare equal, otherwise each element is compared for equality. + */ + template + friend bool operator==( + const Dict& lhs, + const Dict& rhs); + template + friend bool operator!=( + const Dict& lhs, + const Dict& rhs); + + /** + * Identity comparison. Returns true if and only if `rhs` represents the same + * Dict object as `this`. + */ + bool is(const Dict& rhs) const; + + // private API for now because the return type will change to TypePtr + // instead of std::optional once types are mandatory. + TypePtr keyType() const; + TypePtr valueType() const; + + // [unsafe set type] + // These functions mutate the tagged type of this dictionary in place. + // There is no checking that the members of the dictionary are instances + // of the new types, nor is there a check that other IValues which + // hold references to this dictionary have the right static type. + // This functionality is used only in the unpickler, where at + // creation type the real type of the dictionary is unknown, but + // then later recovered from the static type information of the + // unpickled object. + void unsafeSetKeyType(TypePtr t); + void unsafeSetValueType(TypePtr t); +}; + +namespace impl { +// GenericDict is how IValue stores dicts. It is, however, not part of the +// public API. Kernels should use Dicts with concrete Key, Value types instead +// (maybe except for some internal prim ops). 
+using GenericDict = Dict; + +} +} + +namespace torch { + template using Dict = c10::Dict; +} + +#include // IWYU pragma: keep diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict_inl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..088fec4e85e28e9d041ca7f1ec7412f8df41bca1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dict_inl.h @@ -0,0 +1,208 @@ +#pragma once + +#include +#include + +namespace c10 { +namespace detail { +inline bool DictKeyEqualTo::operator()(const IValue& lhs, const IValue& rhs) const { + if (lhs.isTensor() && rhs.isTensor()) { + // for tensors, we compare only by identity (following how it's done in Python). + return lhs.is(rhs); + } + // Otherwise, we first compare by identity for efficiency, then by value (see: + // [container equality]) + return _fastEqualsForContainer(lhs, rhs); +} +} + +template decltype(auto) getTypePtr(); +std::string toString(const Type& type); + +namespace impl { + +template +Dict toTypedDict(GenericDict dict) { + TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes.keyType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr()), ", ", toString(*getTypePtr()), ">. Key types mismatch."); + TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes.valueType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr()), ", ", toString(*getTypePtr()), ">. 
Value types mismatch."); + + return Dict(std::move(dict.impl_)); +} + +template +GenericDict toGenericDict(Dict dict) { + return GenericDict(std::move(dict.impl_)); +} +} + +namespace detail { + +inline size_t DictKeyHash::operator()(const IValue& ivalue) const { + if (ivalue.isInt()) { + return std::hash()(ivalue.toInt()); + } else if (ivalue.isString()) { + return std::hash()(ivalue.toStringView()); + } else if (ivalue.isDouble()) { + return std::hash()(ivalue.toDouble()); + } else if (ivalue.isComplexDouble()) { + return c10::hash>()(ivalue.toComplexDouble()); + } else if (ivalue.isBool()) { + return std::hash()(ivalue.toBool()); + } else if (ivalue.isTensor()) { + return std::hash()(ivalue.toTensor().unsafeGetTensorImpl()); + } else if (ivalue.isDevice()) { + return std::hash()(ivalue.toDevice()); + } else { + TORCH_CHECK(false, "Can't hash IValues with tag '", ivalue.tagKind(), "'"); + } +} + +inline intrusive_ptr DictImpl::copy() const { + return make_intrusive(dict, elementTypes); +} + +} + +template +Dict::Dict() + :Dict(make_intrusive( + detail::DictImpl::dict_map_type(), + detail::DictImpl::DictElementTypes{getTypePtr(), getTypePtr()})) { + static_assert(!std::is_same_v, "This constructor is not valid for Dict. Please use c10::impl::GenericDict(keyType, valueType) instead."); + static_assert(!std::is_same_v, "This constructor is not valid for Dict<_, IValue>. 
Please use c10::impl::GenericDict(keyType, valueType) instead."); +} + +template +Dict::Dict(TypePtr keyType, TypePtr valueType) +: Dict(make_intrusive( + detail::DictImpl::dict_map_type(), + detail::DictImpl::DictElementTypes {std::move(keyType), std::move(valueType)})) { + static_assert(std::is_same_v, "This constructor is only valid for c10::impl::GenericDict."); + static_assert(std::is_same_v, "This constructor is only valid for c10::impl::GenericDict."); +} + +template +Dict::Dict(c10::intrusive_ptr&& impl): impl_(std::move(impl)) {} + +template +Dict Dict::copy() const { + return Dict(impl_->copy()); +} + +template +typename Dict::iterator Dict::begin() const { + return iterator{impl_->dict.begin()}; +} + +template +typename Dict::iterator Dict::end() const { + return iterator{impl_->dict.end()}; +} + +template +bool Dict::empty() const { + return impl_->dict.empty(); +} + +template +typename Dict::size_type Dict::size() const { + return impl_->dict.size(); +} + +template +void Dict::clear() const { + impl_->dict.clear(); +} + +template +template +std::pair::iterator, bool> Dict::insert(Key_&& key, Value_&& value) const { + static_assert(std::is_constructible_v, "Wrong type for the key argument of Dict::insert"); + static_assert(std::is_constructible_v, "Wrong type for the value argument of Dict::insert"); + auto inserted = impl_->dict.emplace( + Key(std::forward(key)), + Value(std::forward(value))); + return {iterator{inserted.first}, inserted.second}; +} + +template +template +std::pair::iterator, bool> Dict::insert_or_assign(Key_&& key, Value_&& value) const { + static_assert(std::is_constructible_v, "Wrong type for the key argument of Dict::insert_or_assign"); + static_assert(std::is_constructible_v, "Wrong type for the value argument of Dict::insert_or_assign"); + auto inserted = impl_->dict.insert_or_assign( + Key(std::forward(key)), + Value(std::forward(value))); + return {iterator{inserted.first}, inserted.second}; +} + +template +void 
Dict::erase(iterator iter) const { + impl_->dict.erase(iter.entryRef_.iterator_); +} + +template +[[nodiscard]] size_t Dict::erase(const Key& key) const { + return impl_->dict.erase(key); +} + +template +Value Dict::at(const Key& key) const { + return impl_->dict.at(key).template to(); +} + +template +typename Dict::iterator Dict::find(const Key& key) const { + return iterator{impl_->dict.find(key)}; +} + +template +bool Dict::contains(const Key& key) const { + return end() != find(key); +} + +template +void Dict::reserve(size_type count) const { + impl_->dict.reserve(count); +} + +template +TypePtr Dict::keyType() const { + return impl_->elementTypes.keyType; +} + +template +TypePtr Dict::valueType() const { + return impl_->elementTypes.valueType; +} +template +void Dict::unsafeSetKeyType(TypePtr t) { + impl_->elementTypes.keyType = std::move(t); +} + +template +void Dict::unsafeSetValueType(TypePtr t) { + impl_->elementTypes.valueType = std::move(t); +} + +template +bool operator==(const Dict& lhs, const Dict& rhs) { + // Dicts with the same identity trivially compare equal. + if (lhs.impl_ == rhs.impl_) { + return true; + } + + // Otherwise compare the values + return *lhs.impl_ == *rhs.impl_; +} + +template +bool operator!=(const Dict& lhs, const Dict& rhs) { + return !(lhs == rhs); +} + +template +bool Dict::is(const Dict& rhs) const { + return this->impl_ == rhs.impl_; +} +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DimVector.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DimVector.h new file mode 100644 index 0000000000000000000000000000000000000000..576b9e142ebf17d048262763c888845a7d0386e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DimVector.h @@ -0,0 +1,13 @@ +#pragma once +#include + +namespace at { + +// Re-declaring 'DimVector' type and size inside 'at' namespace. +// This is done to avoid modifying every use into their 'c10' +// equivalent. 
+ +using c10::kDimVectorStaticSize; +using c10::DimVector; + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dimname.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dimname.h new file mode 100644 index 0000000000000000000000000000000000000000..e5bb3c13922866931a59bd49b7f96cdbac81fc5d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Dimname.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { + +enum class NameType: uint8_t { BASIC, WILDCARD }; + +struct TORCH_API Dimname { + static Dimname fromSymbol(Symbol name); + static Dimname wildcard(); + static bool isValidName(const std::string& name); + + NameType type() const { return type_; } + Symbol symbol() const { return name_; } + + bool isBasic() const { return type_ == NameType::BASIC; } + bool isWildcard() const { return type_ == NameType::WILDCARD; } + + bool matches(Dimname other) const; + std::optional unify(Dimname other) const; + + private: + Dimname(Symbol name) + : name_(name), type_(NameType::BASIC) {} + Dimname(Symbol name, NameType type) + : name_(name), type_(type) {} + + Symbol name_; + NameType type_; +}; + +using DimnameList = c10::ArrayRef; + +TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname); + +inline bool operator==(const Dimname& lhs, const Dimname& rhs) { + return lhs.symbol() == rhs.symbol(); +} + +inline bool operator!=(const Dimname& lhs, const Dimname& rhs) { + return !(lhs == rhs); +} + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DistributionsHelper.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DistributionsHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..bbf8c648fca505c1d59276b050f8903978ac6555 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/DistributionsHelper.h @@ -0,0 +1,332 @@ +#pragma once + +#include 
+#include +#include +#include +#include + +#include +#include +#include +#include + +/** + * Distributions kernel adapted from THRandom.cpp + * The kernels try to follow std::random distributions signature + * For instance: in ATen + * auto gen = at::detail::createCPUGenerator(); + * at::uniform_real_distribution uniform(0, 1); + * auto sample = uniform(gen.get()); + * + * vs std::random + * + * std::mt19937 gen; + * std::uniform_real_distribution uniform(0, 1); + * auto sample = uniform(gen); + */ + + +namespace at { +namespace { + +/** + * Samples a discrete uniform distribution in the range [base, base+range) of type T + */ +template +struct uniform_int_from_to_distribution { + + C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base) : range_(range), base_(base) {} + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { +#ifdef FBCODE_CAFFE2 + if (( + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) && range_ >= 1ULL << 32) +#else + if (range_ >= 1ULL << 28) // allow approx 5% skew in uniform int generation using % +#endif + { + return transformation::uniform_int_from_to(generator->random64(), range_, base_); + } else { + return transformation::uniform_int_from_to(generator->random(), range_, base_); + } + } + + private: + uint64_t range_; + int64_t base_; +}; + +/** + * Samples a discrete uniform distribution in the range [min_value(int64_t), max_value(int64_t)] + */ +template +struct uniform_int_full_range_distribution { + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + return transformation::uniform_int_full_range(generator->random64()); + } + +}; + +/** + * Samples a discrete uniform distribution in the range [0, max_value(T)] for integral types + * and [0, 2^mantissa] for floating-point types. 
+ */ +template +struct uniform_int_distribution { + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + if constexpr (std::is_same_v || std::is_same_v) { + return transformation::uniform_int(generator->random64()); + } else { + return transformation::uniform_int(generator->random()); + } + } + +}; + +/** + * Samples a uniform distribution in the range [from, to) of type T + */ +template +struct uniform_real_distribution { + + C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) : from_(from), to_(to) { + TORCH_CHECK_IF_NOT_ON_CUDA(from <= to); + TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits::max()); + } + + template + C10_HOST_DEVICE inline dist_acctype operator()(RNG generator){ + if constexpr (std::is_same_v) { + return transformation::uniform_real(generator->random64(), from_, to_); + } else { + return transformation::uniform_real(generator->random(), from_, to_); + } + } + + private: + T from_; + T to_; +}; + +// The SFINAE checks introduced in #39816 looks overcomplicated and must revisited +// https://github.com/pytorch/pytorch/issues/40052 +#define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member) \ +template \ +struct has_member_##member \ +{ \ + typedef char yes; \ + typedef long no; \ + template static yes test(decltype(&U::member)); \ + template static no test(...); \ + static constexpr bool value = sizeof(test(0)) == sizeof(yes); \ +} + +DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample); +DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample); +DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample); +DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample); + +#define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE) \ + \ +template ::value && \ + has_member_set_next_##TYPE##_normal_sample::value \ + ), int> = 0> \ +C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) { \ + if 
(generator->next_##TYPE##_normal_sample()) { \ + *ret = *(generator->next_##TYPE##_normal_sample()); \ + generator->set_next_##TYPE##_normal_sample(std::optional()); \ + return true; \ + } \ + return false; \ +} \ + \ +template ::value || \ + !has_member_set_next_##TYPE##_normal_sample::value \ + ), int> = 0> \ +C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type* /*ret*/) { \ + return false; \ +} \ + \ +template ::value \ + ), int> = 0> \ +C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \ + generator->set_next_##TYPE##_normal_sample(cache); \ +} \ + \ +template ::value \ + ), int> = 0> \ +C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type /*cache*/) { \ +} + +DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double) +DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float) + +/** + * Samples a normal distribution using the Box-Muller method + * Takes mean and standard deviation as inputs + * Note that Box-muller method returns two samples at a time. + * Hence, we cache the "next" sample in the CPUGeneratorImpl class. 
+ */ +template +struct normal_distribution { + + C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) { + TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be positive: ", stdv_in); + } + + template + C10_HOST_DEVICE inline dist_acctype operator()(RNG generator){ + dist_acctype ret; + // return cached values if available + if constexpr (std::is_same_v) { + if (maybe_get_next_double_normal_sample(generator, &ret)) { + return transformation::normal(ret, mean, stdv); + } + } else { + if (maybe_get_next_float_normal_sample(generator, &ret)) { + return transformation::normal(ret, mean, stdv); + } + } + // otherwise generate new normal values + uniform_real_distribution uniform(0.0, 1.0); + const dist_acctype u1 = uniform(generator); + const dist_acctype u2 = uniform(generator); + const dist_acctype r = ::sqrt(static_cast(-2.0) * ::log1p(-u2)); + const dist_acctype theta = static_cast(2.0) * c10::pi * u1; + if constexpr (std::is_same_v) { + maybe_set_next_double_normal_sample(generator, r * ::sin(theta)); + } else { + maybe_set_next_float_normal_sample(generator, r * ::sin(theta)); + } + ret = r * ::cos(theta); + return transformation::normal(ret, mean, stdv); + } + + private: + T mean; + T stdv; +}; + +template +struct DiscreteDistributionType { using type = float; }; + +template <> struct DiscreteDistributionType { using type = double; }; + +/** + * Samples a bernoulli distribution given a probability input + */ +template +struct bernoulli_distribution { + + C10_HOST_DEVICE inline bernoulli_distribution(T p_in) : p(p_in) { + TORCH_CHECK_IF_NOT_ON_CUDA(p_in >= 0 && p_in <= 1); + } + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + uniform_real_distribution uniform(0.0, 1.0); + return transformation::bernoulli(uniform(generator), p); + } + + private: + T p; +}; + +/** + * Samples a geometric distribution given a probability input + */ +template +struct geometric_distribution { + + C10_HOST_DEVICE inline 
geometric_distribution(T p_in) : p(p_in) { + TORCH_CHECK_IF_NOT_ON_CUDA(p_in > 0 && p_in < 1); + } + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + uniform_real_distribution uniform(0.0, 1.0); + return transformation::geometric(uniform(generator), p); + } + + private: + T p; +}; + +/** + * Samples an exponential distribution given a lambda input + */ +template +struct exponential_distribution { + + C10_HOST_DEVICE inline exponential_distribution(T lambda_in) : lambda(lambda_in) {} + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + uniform_real_distribution uniform(0.0, 1.0); + return transformation::exponential(uniform(generator), lambda); + } + + private: + T lambda; +}; + +/** + * Samples a cauchy distribution given median and sigma as inputs + */ +template +struct cauchy_distribution { + + C10_HOST_DEVICE inline cauchy_distribution(T median_in, T sigma_in) : median(median_in), sigma(sigma_in) {} + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + uniform_real_distribution uniform(0.0, 1.0); + return transformation::cauchy(uniform(generator), median, sigma); + } + + private: + T median; + T sigma; +}; + +/** + * Samples a lognormal distribution + * Takes mean and standard deviation as inputs + * Outputs two samples at a time + */ +template +struct lognormal_distribution { + + C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) { + TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0); + } + + template + C10_HOST_DEVICE inline T operator()(RNG generator){ + normal_distribution normal(mean, stdv); + return transformation::log_normal(normal(generator)); + } + + private: + T mean; + T stdv; +}; +} +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Formatting.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Formatting.h new file mode 100644 index 
0000000000000000000000000000000000000000..db3d461915bde6b5aebd0ff0b65b36de386c730a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Formatting.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10 { +TORCH_API std::ostream& operator<<(std::ostream& out, Backend b); +TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s); +TORCH_API std::string toString(const Scalar& s); +} +namespace at { + +TORCH_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t); +TORCH_API std::ostream& print( + std::ostream& stream, + const Tensor& tensor, + int64_t linesize); +inline std::ostream& operator<<(std::ostream & out, const Tensor & t) { + return print(out,t,80); +} +TORCH_API void print(const Tensor & t, int64_t linesize=80); +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Generator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Generator.h new file mode 100644 index 0000000000000000000000000000000000000000..297b805f407b15db6d1e7a10ea199f3559fc9f63 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Generator.h @@ -0,0 +1,191 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +// For the record I don't think this is a correct pimpl idiom. +// Including Impl header in interface header defeats the purpose +// because you can't change Impl private members without forcing +// everything that included the interface to rebuild. +// Impl should be forward-declared in the interface header instead. +#include + +/** + * Note [Generator] + * ~~~~~~~~~~~~~~~~ + * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to + * generate a seemingly random sequence of numbers, that may be later be used in creating + * a random distribution. 
Such an engine almost always maintains a state and requires a + * seed to start off the creation of random numbers. Often times, users have + * found it beneficial to be able to explicitly create, retain, and destroy + * PRNG states and also be able to have control over the seed value. + * + * A Generator in ATen gives users the ability to read, write and modify a PRNG engine. + * For instance, it does so by letting users seed a PRNG engine, fork the state of the + * engine, etc. + * + * By default, there is one generator per device, and a device's generator is + * lazily created. A user can use the torch.Generator() api to create their own generator. + */ + +/** + * Note [Acquire lock when using random generators] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Generator and its derived classes are NOT thread-safe. Please note that most of the + * places where we have inserted locking for generators are historically based, and we + * haven't actually checked that everything is truly thread safe (and it probably isn't). + * Please use the public mutex_ when using any methods from these classes, except for the + * read-only methods. You can learn about the usage by looking into the unittests + * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard. + * + * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making + * them non-thread safe and instead making the generator state splittable, to accommodate + * forks into other threads). 
+ */ + +namespace at { + +class Tensor; + +struct TORCH_API Generator { + Generator() = default; + + explicit Generator(c10::intrusive_ptr gen_impl) + : impl_(std::move(gen_impl)) { + if (impl_.get() == nullptr) { + throw std::runtime_error("GeneratorImpl with nullptr is not supported"); + } + } + + bool operator==(const Generator& rhs) const { + return this->impl_ == rhs.impl_; + } + + bool operator!=(const Generator& rhs) const { + return !((*this) == rhs); + } + + bool defined() const { + return static_cast(impl_); + } + + c10::GeneratorImpl* unsafeGetGeneratorImpl() const { + return impl_.get(); + } + + c10::GeneratorImpl* unsafeReleaseGeneratorImpl() { + return impl_.release(); + } + + const c10::intrusive_ptr& getIntrusivePtr() const { + return impl_; + } + + void set_current_seed(uint64_t seed) { impl_->set_current_seed(seed); } + // Sets the offset of Generator state to the desired offset. This is currently + // supported for only Philox based Generators, i.e., CUDA and MPS. + void set_offset(uint64_t offset) { impl_->set_offset(offset); } + + // Returns the offset of Generator state. This is currently supported for only + // Philox based Generators, i.e., CUDA and MPS. 
+ uint64_t get_offset() const { return impl_->get_offset(); } + + uint64_t current_seed() const { return impl_->current_seed(); } + + uint64_t seed() { return impl_->seed(); } + + // Implementation not inlined to prevent cycle reference between + // `ATen/core/Generator.h` and `ATen/core/Tensor.h` + void set_state(const at::Tensor& new_state); + + at::Tensor get_state() const; + + void graphsafe_set_state(const Generator& new_state); + + Generator graphsafe_get_state() const; + + std::mutex& mutex() { + return impl_->mutex_; + } + + DispatchKeySet key_set() const { + return impl_->key_set(); + } + + Device device() const { return impl_->device(); } + + inline void set_pyobj(PyObject* pyobj) const noexcept { + impl_->set_pyobj(pyobj); + } + + inline PyObject* pyobj() const noexcept { + return impl_->pyobj(); + } + + template + T* get() const { return static_cast(impl_.get()); } + + Generator clone() const { + return Generator(impl_->clone()); + } + + private: + c10::intrusive_ptr impl_; +}; + +template +Generator make_generator(Args&&... args) { + return Generator(c10::make_intrusive(std::forward(args)...)); +} + +/** + * Utility function to static cast input Generator* to + * the backend generator type (CPU/CUDAGeneratorImpl etc.) + */ +template +inline T * check_generator(std::optional gen) { + TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt"); + TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed"); + TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'"); + return gen->get(); +} + +/** + * Utility function used in tensor implementations, which + * supplies the default generator to tensors, if an input generator + * is not supplied. The input Generator* is also static casted to + * the backend generator type (CPU/CUDAGeneratorImpl etc.) 
+ */ +template +inline T* get_generator_or_default(const std::optional& gen, const Generator& default_gen) { + return gen.has_value() && gen->defined() ? check_generator(gen) : check_generator(default_gen); +} + +namespace detail { + +/** + * Helper function for checking the validity of new random generator + * state. Right now following conditions are checked: + * + * - The new state tensor must be a torch.ByteTensor + * - Data of the new state tensor must be contiguous + */ +inline void check_rng_state(const c10::TensorImpl& new_state) { + TORCH_CHECK_TYPE( + new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte, + "RNG state must be a torch.ByteTensor" + ); + + TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous"); +} + +} // namespace detail + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h new file mode 100644 index 0000000000000000000000000000000000000000..a4879a1f5f5c78b65e353d342c7eeff6e2dac259 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include + +namespace at { + +using GeneratorFuncType = std::function; + +TORCH_API std::optional& GetGeneratorPrivate(); + +class TORCH_API _GeneratorRegister { + public: + explicit _GeneratorRegister(const GeneratorFuncType& func); +}; + +TORCH_API at::Generator GetGeneratorForPrivateuse1( + c10::DeviceIndex device_index); + +/** + * This is used to register Generator to PyTorch for `privateuse1` key. + * + * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1) + * + * class CustomGeneratorImpl : public c10::GeneratorImpl { + * CustomGeneratorImpl(DeviceIndex device_index = -1); + * explicit ~CustomGeneratorImpl() override = default; + * ... 
+ * }; + * + * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) { + * return at::make_generator(id); + * } + */ + +#define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \ + static auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate); + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef.h new file mode 100644 index 0000000000000000000000000000000000000000..aa90faf838786cf23b1e82087ca70c6ce28669eb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef.h @@ -0,0 +1,631 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +/* + * [Note: IListRef] + * Wrapper around different API containers (e.g. boxed and unboxed). + * + * What is it? + * =========== + * It is a tagged union of both boxed and unboxed API containers. + * Working implementations: + * + * - `IListRef` + * - `IListRef` + * + * Note that `IListRef` is a view type. Meaning that it won't own the + * tensors it holds. It's intended to be used only as argument parameters. + * Specifically, where these 2 worlds overlap. + * + * What is this for? + * ================= + * Historically, PyTorch has maintained 2 different APIs: the unboxed + * (called from C++ API and Python eager mode) and boxed APIs (called + * from the TorchScript JIT, mobile interpreter, and boxed fallbacks). + * + * Calling unboxed kernels from the boxed "world" and vice-versa may + * result in non-negligible overhead. Lists are one of those types: + * + * - Boxed world: `c10::List` + * - Unboxed world: `c10::ArrayRef` + * + * In this context, `c10::IListRef` solves this problem by wrapping those + * 2 container types, so that we don't need to convert from one to + * the other. + * + * (see https://github.com/pytorch/pytorch/issues/66328) + * + * What does it do? 
+ * ================ + * This container wraps around the different tagged containers + * (currently, only boxed and unboxed), without incurring in extra + * overhead for converting from one to another. It does so while + * exposing usual container methods, which dispatch to corresponding + * implementations. + * + * While it works with different container types, it introduces + * overhead for repeatedly calling member functions (since those will + * get dispatched, again). Therefore, you should only use it to iterate + * through the list up to one time. If you need to do more complex things, + * call `materialize()` first. + * + * Adding support for a new Tag + * ============================ + * Suppose we want to add a new tag: `Chest`. Here are the steps + * we would have to go through: + * + * 1. Add a line for it in the macro `TORCH_ILISTREF_FORALL_TAGS`. + * + * #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \ + * ... + * _(Chest, ##__VA_ARGS__) + * + * 2. Add type aliases, union members, and constructors. + * + * template + * class IListRef { + * ... + * using chest_type = + * typename detail::IListRefTagImpl::list_type; + * ... + * IListRef(...) : tag_(IListRefTag::Chest) { + * ... + * } + * ... + * union Payload { + * ... + * chest_type chest; + * ... + * }; + * ... + * }; + * + * 3. Add a default implementation for it (in 'IListRef_inl.h'). It's + * preferable to make the default implementation work for `T = Tensor` + * (both `Unboxed` and `Boxed` do it). + * + * template + * class IListRefTagImplBase { + * public: + * using elem_type = ListElemT; + * using list_type = ChestContainer; + * + * static const list_type& unwrap(const IListRef& ilist) { ... } + * + * static typename list_type::const_iterator& unwrap( + * IListRefIterator& it) { ... } + * + * static const typename list_type::const_iterator& unwrap( + * const IListRefIterator& it) { ... } + * + * static IListRefConstRef iterator_get( + * const typename list_type::const_iterator& it) { ... 
} + * } + * + * 4. Add an specialization for each of the already supported types. + * Finally, for consistency, add them to the tracking list. + * (see [Note: IListRefTagImpl Specializations]) + * + * template <> + * class IListRefTagImpl + * : public IListRefTagImplBase {}; + * + * Adding support for a new Type + * ============================= + * Suppose we want to add support for a new type: `Matrix`. + * Here are the steps we would have to go through: + * + * 1. Add an specialization for each of the existing tags. + * For consistency, add them to the tracking list. + * (see [Note: IListRefTagImpl Specializations]) + * + * template <> + * class IListRefTagImpl + * : public IListRefTagImplBase {}; + * + * template <> + * class IListRefTagImpl + * : public IListRefTagImplBase {}; + * + * Common Problems + * =============== + * 1. One of `IListRef(Iterator)` methods are failing to compile. + * + * That may be happening because the container type you added + * is not compatible with the code written for that method. If + * that's true, then you might have to transform that code into + * a static method call (see `List::operator[]` method). + * + * 2. Can't make `IListRefIterator::operator*` return a const-reference. + * + * First, keep in mind that we assume that boxed containers will + * have to deal with `IValue` (e.g. `c10::List`). In this context, + * what may be happening is that `IValue` doesn't store internally + * your type `T`. Instead, it constructs a type new `T` everytime + * you try to get `T` for it (see `IListRef`). + */ + +namespace c10 { +template +class IListRef; + +/* + * Applies arbitrary macros to each `IListRefTag`. + */ +#define TORCH_ILISTREF_FORALL_TAGS(_, ...) \ + _(Unboxed, ##__VA_ARGS__) \ + _(Boxed, ##__VA_ARGS__) \ + _(Materialized, ##__VA_ARGS__) + +/* + * Defines a "switch-case" for `TAG`. 
Inside, it executes `BODY`, + * while bringing to scope: + * + * - `ImplT`: the implementation class for `TAG` + * - `this_`: the result of unwrapping `this` + */ +#define TORCH_ILISTREF_UNWRAP_CASE(TAG, BODY) \ + case c10::IListRefTag::TAG: { \ + using ImplT = c10::detail::IListRefTagImpl; \ + auto& this_ = ImplT::unwrap(*this); \ + BODY \ + } break; + +/* + * Dispatches the unwrap call, depending on `TAG`, followed by + * the execution of `BODY`. It aborts if `TAG` is not a `IListRefTag`. + * + * This macro is useful because it allows us to handle different + * types (that correspond to different tags) to be implemented + * only once. We can do it even when the implementation of the + * different tags aren't syntatically the same, by dispatching + * it to a function (e.g. `ImplT::(this_)`). + */ +#define TORCH_ILISTREF_UNWRAP(TAG, BODY) \ + switch (TAG) { \ + TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \ + break; \ + default: \ + TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \ + } + +enum class IListRefTag { +#define DEFINE_TAG(tag, ...) tag, + TORCH_ILISTREF_FORALL_TAGS(DEFINE_TAG) +#undef DEFINE_TAG + None +}; + +namespace detail { +/* + * Type alias that specifies whether we return a reference or a copy of `T`. + * + * What is this for? + * ================= + * Since values in the boxed world are represented by an `IValue`, we also + * depend on whether it can be converted to a const-reference (`Tensor`) or + * has to create a new copy of `T` (`OptionalTensorRef`). + */ +template +using IListRefConstRef = typename ivalue_to_const_ref_overload_return::type; + +/* + * Interface that implements key functions for each `IListRefTag` type. + * + * What is this for? + * ================= + * Given an `IListRef(Iterator)`, some methods have to be implemented + * differently for each `TAG`. Therefore, the methods inside this class + * are used as dispatch targets for the different `IListRefTag` values. 
+ * + * You should create an specialization of this class for each possible + * combination of `IListRefTag` type (except `None`) and element types + * (e.g. `Tensor`). + * + * What does it do? + * ================ + * 1. defines static methods to be used as dispatch targets by both + * `IListRef` and `IListRefIterator` (see the implementation of + * `IListRefTagImplBase`). + * + * 2. defines the `elem_type` and `list_type` aliases that will be + * used in the definition of `IListRef`. In general, we should do + * so by inheriting from `IListRefTagImplBase`. + * + * [Note: IListRefTagImpl Specialization] + * ====================================== + * For `IListRef(Iterator)`: + * - + * - + * - + * + * For `IListRef(Iterator)`: + * - + * - + * - + */ +template +class IListRefTagImpl {}; + +/* + * Base implementation of `IListRefTagImpl` methods. + * + * What is this for? + * ================= + * This should make adding specializations for new types easier. For + * example, one should be able to add a new type just by making its + * `IListRefTagImpl` specialization inherit from `IListRefTagImplBase`. + * + * You should create a partial specialization for this class only if + * you introduce a new `IListRefTag`. The idea being that there is one + * default implementation for each possible value of `IListRefTag`. + * + * What does it do? + * ================ + * 1. defines `elem_type` as an alias to `ListElemT`. + * + * 1. defines `list_type` as an alias to the default container type + * that will hold a collection of `elem_type`. The idea being that + * all types tagged as `TAG` will have `list_type` as its container, + * with different `elem_type`. + * + * 3. defines the default implementation for each of the methods that + * are supposed to be defined on `IListRefTagImpl` specializations. + * + * 4. inheriting from `IListRefTagImplBase` also means + * that the payload of the type `IListRef` will be of type `list_type` + * when it is tagged as `TAG`. 
+ */ +template +class IListRefTagImplBase {}; + +/* + * Materialized container for `IListRef`. + * + * What is this for? + * ================= + * Container that groups `T` references together. This exchanges the + * overhead of every method call from `IListRef` for a dynamic allocation. + * + * You should use this container instead of `IListRef` if: + * + * - You are going to iterate the list more than once + * - You need to repeatedly access arbitrary elements (using `operator[]`) + * What does it do? + + * ================ + * Removes the reference (&) from the type, and wraps it into a + * `std::reference_wrapper`. If `IListRefConstRef` is not a + * reference type, then it's left unchanged. + */ +template +using _MaterializedIListRefElem = std::conditional_t< + std::is_reference_v, + typename std::reference_wrapper>, + T>; + +template +using MaterializedIListRefElem = _MaterializedIListRefElem>; + +template +using MaterializedIListRef = std::vector>; + +} // namespace detail + +/* + * Iterator for `IListRef`. + * + * What is it? + * =========== + * Currently, a `std::bidirectional_iterator` that wraps the iterator + * types defined for each of the `IListRefTag`. + * + * One should be able to use it, as if it were the unwrapped + * iterators themselves. + + * What does it do? + * ================ + * Similarly to `IListRef`, this is a wrapper class. Specifically, it + * wraps each container's `const_iterator` type alias. So, for example, + * given that the container for `IListRefTag::Boxed` is `c10::List`, this + * iterator will wrap a `c10::List::const_iterator`. + * + * [Note: MSVC Iterator Debug] + * =========================== + * MSVC `vector::iterator` implementation (used in the boxed variant) + * makes it so this union's destructor, copy-constructor (assignment), and + * move-constructor (assignment) are implicitly deleted. + * + * Therefore, we need to explicitly define them as needed. 
Follows a list + * of places where these are needed and their reason: + * + * - `Payload` destructor: + * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is set to 2. + * + * - `IListRefIterator` destructor: + * same as above. However, we need to explicitly call the variant + * destructor explicitly. + * + * - `IListRefIterator` copy-constructor: + * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is different + * than 0. + */ +template +class IListRefIterator { + private: +#define DEFINE_FRIEND_CLASS(TAG, ...) \ + friend class detail::IListRefTagImpl; \ + friend class detail::IListRefTagImplBase< \ + IListRefTag::TAG, \ + T, \ + typename detail::IListRefTagImpl::elem_type>; + TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS) +#undef DEFINE_FRIEND_CLASS + + public: + // C++17 friendly std::iterator implementation + using iterator_category = std::bidirectional_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T*; + using reference = T&; + + using unboxed_iterator_type = typename detail:: + IListRefTagImpl::list_type::const_iterator; + using boxed_iterator_type = typename detail:: + IListRefTagImpl::list_type::const_iterator; + using materialized_iterator_type = + typename detail::MaterializedIListRef::const_iterator; + + IListRefIterator() : tag_(IListRefTag::None) {} + +#if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL != 0 + // See [Note: MSVC Iterator Debug] + IListRefIterator(const IListRefIterator& iterator) + : tag_(iterator.tag_) { + switch (tag_) { + case IListRefTag::Boxed: + payload_.boxed_iterator = iterator.payload_.boxed_iterator; + break; + case IListRefTag::Unboxed: + payload_.unboxed_iterator = iterator.payload_.unboxed_iterator; + break; + case IListRefTag::Materialized: + payload_.materialized_iterator = iterator.payload_.materialized_iterator; + break; + default: + TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); + } + } +#endif + +#if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2 
+ // See [Note: MSVC Iterator Debug] + ~IListRefIterator() noexcept(false) { + switch (tag_) { + case IListRefTag::Boxed: + payload_.boxed_iterator.~boxed_iterator_type(); + break; + case IListRefTag::Unboxed: + payload_.unboxed_iterator.~unboxed_iterator_type(); + break; + case IListRefTag::Materialized: + payload_.materialized_iterator.~materialized_iterator_type(); + break; + default: + TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); + } + } +#endif + + IListRefIterator(boxed_iterator_type boxed) : tag_(IListRefTag::Boxed) { + payload_.boxed_iterator = boxed; + } + + IListRefIterator(unboxed_iterator_type unboxed) : tag_(IListRefTag::Unboxed) { + payload_.unboxed_iterator = unboxed; + } + + IListRefIterator(materialized_iterator_type materialized) : tag_(IListRefTag::Materialized) { + payload_.materialized_iterator = materialized; + } + + detail::IListRefConstRef operator*() const { + TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::iterator_get(this_); }); + } + + IListRefIterator& operator++() { + TORCH_ILISTREF_UNWRAP(tag_, { ++this_; }); + return *this; + } + + IListRefIterator operator++(int) { + auto old = *this; + TORCH_ILISTREF_UNWRAP(tag_, { ++this_; }); + return old; + } + + IListRefIterator& operator--() { + TORCH_ILISTREF_UNWRAP(tag_, { --this_; }); + return *this; + } + + IListRefIterator operator--(int) { + auto old = *this; + TORCH_ILISTREF_UNWRAP(tag_, { --this_; }); + return old; + } + + bool operator==(const IListRefIterator& rhs) const { + if (tag_ != rhs.tag_) { + return false; + } + TORCH_ILISTREF_UNWRAP(tag_, { + auto& rhs_it = ImplT::unwrap(rhs); + return this_ == rhs_it; + }); + } + + bool operator!=(const IListRefIterator& rhs) const { + return !(*this == rhs); + } + + private: + union Payload { + boxed_iterator_type boxed_iterator; + unboxed_iterator_type unboxed_iterator; + materialized_iterator_type materialized_iterator; + void* _init_ptr; + Payload() : _init_ptr(nullptr) {} +#if defined(_MSC_VER) + // See [Note: MSVC Iterator 
Debug] + ~Payload() {} +#endif + }; + + Payload payload_; + IListRefTag tag_; +}; + +/* + * See [Note: IListRef] + */ +template +class IListRef { + private: +#define DEFINE_FRIEND_CLASS(TAG, ...) \ + friend class detail::IListRefTagImpl; \ + friend class detail::IListRefTagImplBase< \ + IListRefTag::TAG, \ + T, \ + typename detail::IListRefTagImpl::elem_type>; + TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS) +#undef DEFINE_FRIEND_CLASS + + public: + using unboxed_type = + typename detail::IListRefTagImpl::list_type; + using boxed_type = + typename detail::IListRefTagImpl::list_type; + using materialized_type = + typename detail::MaterializedIListRef; + + using iterator = IListRefIterator; + using const_iterator = IListRefIterator; + using reverse_iterator = std::reverse_iterator; + using value_type = typename iterator::value_type; + + IListRef() : tag_(IListRefTag::None) {} + + IListRef(const boxed_type& boxed) : tag_(IListRefTag::Boxed) { + payload_.boxed = &boxed; + } + + IListRef(const unboxed_type& unboxed) : tag_(IListRefTag::Unboxed) { + payload_.unboxed = unboxed; + } + + IListRef(const std::initializer_list& list) : tag_(IListRefTag::Unboxed) { + payload_.unboxed = at::ArrayRef(list); + } + + template < + typename... UnboxedConstructorArgs, + typename = std::enable_if_t< + std::is_constructible_v>> + IListRef(UnboxedConstructorArgs&&... 
args) : tag_(IListRefTag::Unboxed) { + payload_.unboxed = unboxed_type(std::forward(args)...); + } + + IListRef(const materialized_type& materialized) : tag_(IListRefTag::Materialized) { + payload_.materialized = &materialized; + } + + size_t size() const { + TORCH_ILISTREF_UNWRAP(tag_, { return this_.size(); }); + } + + bool empty() const { + return size() == 0; + } + + iterator begin() const { + TORCH_ILISTREF_UNWRAP(tag_, { return this_.begin(); }); + } + + iterator end() const { + TORCH_ILISTREF_UNWRAP(tag_, { return this_.end(); }); + } + + detail::IListRefConstRef front() const { + TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::front(this_); }); + } + + /* + * Materializes the `IListRef` into a `std::vector`. + * + * This should be used when one wishes to either: + * + * - iterate over the list more than once: each `IListRefIterator` + * member function call has to go through a switch, introducing + * non-negligible overhead + * + * - randomly access an arbitrary element using `operator[]`: + * same reason as above + */ + detail::MaterializedIListRef materialize() const { + if (isMaterialized()) { + return toMaterialized(); + } + + detail::MaterializedIListRef materialized; + materialized.reserve(size()); + for (const auto& t : *this) { + materialized.emplace_back(t); + } + return materialized; + } + +#define DEFINE_CHECK(TAG, ...) \ + bool is##TAG() const { \ + return tag_ == IListRefTag::TAG; \ + } + TORCH_ILISTREF_FORALL_TAGS(DEFINE_CHECK) +#undef DEFINE_CHECK + + bool isNone() const { + return tag_ == IListRefTag::None; + } + +#define DEFINE_CASTING(TAG, ...) 
\ + const typename detail::IListRefTagImpl::list_type& \ + to##TAG() const { \ + TORCH_INTERNAL_ASSERT(is##TAG()); \ + return detail::IListRefTagImpl::unwrap(*this); \ + } + TORCH_ILISTREF_FORALL_TAGS(DEFINE_CASTING) +#undef DEFINE_CASTING + + private: + union Payload { + const boxed_type* boxed; + unboxed_type unboxed; + const materialized_type* materialized; + Payload() : boxed(nullptr) {} + }; + + Payload payload_; + IListRefTag tag_; +}; + +} // namespace c10 + +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef_inl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..df320c13d9c2381a373d608a504fa9ed6ee778b2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/IListRef_inl.h @@ -0,0 +1,203 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; +class OptionalTensorRef; +} + + +namespace c10::detail { + +/* + * Specializations of `IListRefTagImplBase` that implement the default + * implementation for `IListRefTag::Unboxed`. + */ +template +class IListRefTagImplBase { + public: + using elem_type = ListElemT; + using list_type = ArrayRef; + + /* + * These `unwrap` static methods unwraps the inner containers out + * of `IListRef` (and `IListRefIterator`). They are required when + * the macro `TORCH_ILISTREF_UNWRAP` is called. 
+ */ + static const list_type& unwrap(const IListRef& ilist) { + return ilist.payload_.unboxed; + } + + static typename list_type::const_iterator& unwrap(IListRefIterator& it) { + return it.payload_.unboxed_iterator; + } + + static const typename list_type::const_iterator& unwrap( + const IListRefIterator& it) { + return it.payload_.unboxed_iterator; + } + + /* + * We have these function (besides the `unwrap`s above) because the + * implementation for both `IListRef::operator[]` and `IListRefIterator::operator*` + * weren't syntatically equal for the existing tags at the time + * (`Unboxed` and `Boxed`). + */ + static IListRefConstRef front(const list_type& lst) { + return lst.front(); + } + + static IListRefConstRef iterator_get( + const typename list_type::const_iterator& it) { + return *it; + } +}; + +/* + * Specializations of `IListRefTagImplBase` that implement the default + * implementation for `IListRefTag::Boxed`. + */ +template +class IListRefTagImplBase { + public: + using elem_type = ListElemT; + using list_type = List; + + static const list_type& unwrap(const IListRef& ilist) { + return *ilist.payload_.boxed; + } + + static typename list_type::const_iterator& unwrap(IListRefIterator& it) { + return it.payload_.boxed_iterator; + } + + static const typename list_type::const_iterator& unwrap( + const IListRefIterator& it) { + return it.payload_.boxed_iterator; + } + + static IListRefConstRef front(const list_type& lst) { + return lst[0]; + } + + static IListRefConstRef iterator_get( + const typename list_type::const_iterator& it) { + return (*it).get().toTensor(); + } +}; + +/* + * Specializations of `IListRefTagImplBase` that implement the default + * implementation for `IListRefTag::Materialized`. 
+ */ +template +class IListRefTagImplBase> { + public: + using elem_type = MaterializedIListRefElem; + using list_type = MaterializedIListRef; + + static const list_type& unwrap(const IListRef& ilist) { + return *ilist.payload_.materialized; + } + + static typename list_type::const_iterator& unwrap(IListRefIterator& it) { + return it.payload_.materialized_iterator; + } + + static const typename list_type::const_iterator& unwrap( + const IListRefIterator& it) { + return it.payload_.materialized_iterator; + } + + static IListRefConstRef front(const list_type& lst) { + return lst[0]; + } + + static IListRefConstRef iterator_get( + const typename list_type::const_iterator& it) { + return *it; + } +}; + +/* + * [Note: ITensorListRef] + * Specializations necessary for `IListRef` type. + * + * Since the default implementations are usually done with supporting + * `Tensor` in mind, we only have to inherit from the base implementations. + */ +template <> +class IListRefTagImpl + : public IListRefTagImplBase {}; + +template <> +class IListRefTagImpl + : public IListRefTagImplBase {}; + +template <> +class IListRefTagImpl + : public IListRefTagImplBase< + IListRefTag::Materialized, + at::Tensor, + MaterializedIListRefElem> {}; + +/* + * [Note: IOptTensorListRef] + * Specializations necessary for `IListRef` type. + * + * We can't get an `at::OptionalTensorRef` directly from an instance of + * `List>` (the type that corresponds to the boxed world). + * + * So, the default implementation won't help us. Thus, we have to implement + * this method ourselves. + */ +template <> +class IListRefTagImpl + : public IListRefTagImplBase {}; + +template <> +class IListRefTagImpl + : public IListRefTagImplBase> { + + public: + /* + * Given an instance of the types corresponding to the `Boxed` tag, we override + * the default implementation, so that we can return a `at::OptionalTensorRef`. 
+ */ + static IListRefConstRef iterator_get( + const typename list_type::const_iterator& it) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdangling-reference") + const auto& ivalue = (*it).get(); + C10_DIAGNOSTIC_POP() + if (!ivalue.isNone()) { + const auto& tensor = ivalue.toTensor(); + return (tensor.defined()) ? tensor : at::OptionalTensorRef{}; + } + return {}; + } +}; + +template <> +class IListRefTagImpl + : public IListRefTagImplBase< + IListRefTag::Materialized, + at::OptionalTensorRef, + MaterializedIListRefElem> {}; + +} // namespace c10::detail + + +namespace at { + +// [Note: ITensorListRef] +using ITensorListRef = c10::IListRef; +using ITensorListRefIterator = c10::IListRefIterator; +using MaterializedITensorListRef = c10::detail::MaterializedIListRef; +// [Note: IOptTensorListRef] +using IOptTensorListRef = c10::IListRef; +using IOptTensorListRefIterator = c10::IListRefIterator; +using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef; + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3efe5e0f7b87b66a1cf22ef1d5fb0f93e2f78494 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h @@ -0,0 +1,111 @@ +#pragma once + +// The legacy mechanism for dispatching operators in ATen is a Type +// object, which is essentially a giant virtual dispatch table +// for every operation we support dynamically dispatching over. +// +// This has been deprecated in favor of ATenDispatch, and in the future, +// c10 dispatcher. +// TODO: Clean up what remains here + +#include + +namespace at { + +// A RAII, thread local (!) guard that will disable dispatch to variable +// handler. 
+// +// NOTE [ Treating Variables as non-Variables in type dispatch ] +// +// What exactly does AutoDispatchBelowAutograd do? The short answer is, it causes +// dispatches on ATen functions to go to the non-variable implementation, +// bypassing autograd handling (and also profiling and tracing). +// +// To understand why this guard exists, it's helpful to understand the history +// behind how Variable was implemented. Previously, Variables were implemented +// as a wrapper on Tensors; so the act of processing a Variable involved +// unwrapping the underlying Tensor, and then calling the underlying base +// operation on /that/ operation +// +// However, after the Variable/Tensor merge, there is no concept of unwrapping +// a tensor anymore. If you just call the operation on the same variable +// again inside your VariableType handler, you'll dispatch back to +// VariableType, which is not what we want. +// +// The solution to the above problem is to add `at::AutoDispatchBelowAutograd`, which +// when enabled will cause `legacyTensorType()` and `getType()` to always return +// non-Variable type, even if the tensor being called on is a variable. + +/* Note [AutoDispatchBelowAutograd] + * AutoDispatchBelowAutograd is **INTERNAL ONLY** that it should be used + * for kernel implementations and customized C++ kernels. + * If you are looking for a guard to run workload in inference mode, please use + * c10::InferenceMode RAII which is user facing API. + * In the past AutoDispatchBelowAutograd(or its old version AutoNonVariableTypeMode) + * was used in the user code for inference-only workload, this was under risk of + * producing wrong results silently in some edge cases. For example: + * ``` + * torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true); + * torch::Tensor out = s * s; + * { + * at::AutoDispatchBelowAutograd guard; + * s.add_(1); // Skips version bump on `s`. + * } + * // WRONG GRADIENT! 
s.grad() are now computed using `s` value after the + * // inplace update. + * out.backward(torch::ones_like(out)); + * ``` + * Users should use `c10::InferenceMode` here so that it'll properly throw an + * error saying "one of the variables needed for gradient computation has be modified." + */ +struct TORCH_API AutoDispatchBelowAutograd { + AutoDispatchBelowAutograd() : + autograd_guard_(c10::autograd_dispatch_keyset) { + } + + // disable all autograd dispatch keys + c10::impl::ExcludeDispatchKeyGuard autograd_guard_; +}; + +// TODO: AutoNonVariableTypeMode should be removed in release 1.10. +struct TORCH_API AutoNonVariableTypeMode { + AutoNonVariableTypeMode(bool enabled = true) : + autograd_guard_(c10::autograd_dispatch_keyset) { + TORCH_WARN_ONCE("AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. " + "For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, " + "If you are looking for a user facing API to enable running your inference-only " + "workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code " + "is under risk of producing silent wrong result in some edge cases. " + "See Note [AutoDispatchBelowAutograd] for more details."); + TORCH_INTERNAL_ASSERT(enabled); + } + + // disable all autograd dispatch keys + c10::impl::ExcludeDispatchKeyGuard autograd_guard_; +}; + +struct TORCH_API AutoDispatchSkipFunctionalize { + AutoDispatchSkipFunctionalize() : + dispatch_key_guard_(c10::DispatchKeySet(c10::DispatchKey::Functionalize)) { + } + c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_; +}; + +/* Note [AutoDispatchBelowADInplaceOrView] + * AutoDispatchBelowADInplaceOrView is equivalent to AutoNonVariableTypeMode + * before we split inplace & view ops out of VariableType kernel. 
+ * Note this guard is used in VariableType kernels for functional ops + * as well as ADInplaceOrView kernels for inplace/view ops to enforce the + * Invariant: + * Once you are in VariableType/ADInplaceOrView kernel for an op, + * you never go back to a kernel on same dispatch key until + * you finish the current op. + */ +struct TORCH_API AutoDispatchBelowADInplaceOrView { + AutoDispatchBelowADInplaceOrView() : + dispatch_key_guard_(c10::autograd_dispatch_keyset_with_ADInplaceOrView) { + } + // disable Autograd & ADInplaceOrView dispatch keys + c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_; +}; +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List.h new file mode 100644 index 0000000000000000000000000000000000000000..4cb22831947f41efc44db148d43a3395f2883843 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List.h @@ -0,0 +1,491 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +class Tensor; +} +namespace c10 { +struct IValue; +template class List; +struct Type; + +namespace detail { + +struct ListImpl final : public c10::intrusive_ptr_target { + using list_type = std::vector; + + explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_); + + list_type list; + + TypePtr elementType; + + intrusive_ptr copy() const { + return make_intrusive(list, elementType); + } + friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs); +}; +} + +namespace impl { + +template class ListIterator; + +template class ListElementReference; + +template +void swap(ListElementReference&& lhs, ListElementReference&& rhs) noexcept; + +template +bool operator==(const ListElementReference& lhs, const T& rhs); + +template +bool operator==(const T& lhs, const ListElementReference& rhs); + +template +struct 
ListElementConstReferenceTraits { + // In the general case, we use IValue::to(). + using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return::type; +}; + +// There is no to() overload for std::optional. +template<> +struct ListElementConstReferenceTraits> { + using const_reference = std::optional>; +}; + +template +class ListElementReference final { +public: + operator std::conditional_t< + std::is_reference_v::type>, + const T&, + T>() const; + + ListElementReference& operator=(T&& new_value) &&; + + ListElementReference& operator=(const T& new_value) &&; + + // assigning another ref to this assigns the underlying value + ListElementReference& operator=(ListElementReference&& rhs) && noexcept; + + const IValue& get() const& { + return *iterator_; + } + + friend void swap(ListElementReference&& lhs, ListElementReference&& rhs) noexcept; + + ListElementReference(const ListElementReference&) = delete; + ListElementReference& operator=(const ListElementReference&) = delete; + ~ListElementReference() = default; + +private: + ListElementReference(Iterator iter) + : iterator_(iter) {} + + // allow moving, but only our friends (i.e. the List class) can move us + ListElementReference(ListElementReference&&) noexcept = default; + ListElementReference& operator=(ListElementReference&& rhs) & noexcept { + iterator_ = std::move(rhs.iterator_); + return *this; + } + + friend class List; + friend class ListIterator; + + Iterator iterator_; +}; + +// this wraps vector::iterator to make sure user code can't rely +// on it being the type of the underlying vector. 
+template +class ListIterator final { + public: + // C++17 friendly std::iterator implementation + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T*; + using reference = ListElementReference; + + explicit ListIterator() = default; + ~ListIterator() = default; + + ListIterator(const ListIterator&) = default; + ListIterator(ListIterator&&) noexcept = default; + ListIterator& operator=(const ListIterator&) = default; + ListIterator& operator=(ListIterator&&) noexcept = default; + + ListIterator& operator++() { + ++iterator_; + return *this; + } + + ListIterator operator++(int) { + ListIterator copy(*this); + ++*this; + return copy; + } + + ListIterator& operator--() { + --iterator_; + return *this; + } + + ListIterator operator--(int) { + ListIterator copy(*this); + --*this; + return copy; + } + + ListIterator& operator+=(typename List::size_type offset) { + iterator_ += offset; + return *this; + } + + ListIterator& operator-=(typename List::size_type offset) { + iterator_ -= offset; + return *this; + } + + ListIterator operator+(typename List::size_type offset) const { + return ListIterator{iterator_ + offset}; + } + + ListIterator operator-(typename List::size_type offset) const { + return ListIterator{iterator_ - offset}; + } + + friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ - rhs.iterator_; + } + + ListElementReference operator*() const { + return {iterator_}; + } + + ListElementReference operator[](typename List::size_type offset) const { + return {iterator_ + offset}; + } + +private: + explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {} + + Iterator iterator_; + + friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ == rhs.iterator_; + } + + friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) { + return !(lhs == rhs); 
+ } + + friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ < rhs.iterator_; + } + + friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ <= rhs.iterator_; + } + + friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ > rhs.iterator_; + } + + friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) { + return lhs.iterator_ >= rhs.iterator_; + } + + friend class ListIterator; + friend class List; +}; + +template List toTypedList(List list); +template List toList(List&& list); +template List toList(const List& list); +const IValue* ptr_to_first_element(const List& list); +} + +/** + * An object of this class stores a list of values of type T. + * + * This is a pointer type. After a copy, both Lists + * will share the same storage: + * + * > List a; + * > List b = a; + * > b.push_back("three"); + * > ASSERT("three" == a.get(0)); + * + * We use this class in the PyTorch kernel API instead of + * std::vector, because that allows us to do optimizations + * and switch out the underlying list implementation without + * breaking backwards compatibility for the kernel API. + */ +template +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class List final { +private: + // This is an intrusive_ptr because List is a pointer type. + // Invariant: This will never be a nullptr, there will always be a valid + // ListImpl. + c10::intrusive_ptr impl_; + + using internal_reference_type = impl::ListElementReference; + using internal_const_reference_type = typename impl::ListElementConstReferenceTraits::const_reference; + +public: + using value_type = T; + using size_type = typename c10::detail::ListImpl::list_type::size_type; + using iterator = impl::ListIterator; + using const_iterator = impl::ListIterator; + using reverse_iterator = impl::ListIterator; + + /** + * Constructs an empty list. 
+ */ + explicit List(); + + /** + * Constructs a list with some initial values. + * Example: + * List a({2, 3, 4}); + */ + List(std::initializer_list initial_values); + explicit List(ArrayRef initial_values); + + /** + * Create a generic list with runtime type information. + * This only works for c10::impl::GenericList and is not part of the public API + * but only supposed to be used internally by PyTorch. + */ + explicit List(TypePtr elementType); + + List(const List&) = default; + List& operator=(const List&) = default; + ~List() = default; + + /** + * Create a new List pointing to a deep copy of the same data. + * The List returned is a new list with separate storage. + * Changes in it are not reflected in the original list or vice versa. + */ + List copy() const; + + /** + * Returns the element at specified location pos, with bounds checking. + * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. + */ + internal_const_reference_type get(size_type pos) const; + + /** + * Moves out the element at the specified location pos and returns it, with bounds checking. + * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. + * The list contains an invalid element at position pos afterwards. Any operations + * on it before re-setting it are invalid. + */ + value_type extract(size_type pos) const; + + /** + * Returns a reference to the element at specified location pos, with bounds checking. + * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. + * + * You cannot store the reference, but you can read it and assign new values to it: + * + * List list = ...; + * list[2] = 5; + * int64_t v = list[1]; + */ + internal_const_reference_type operator[](size_type pos) const; + + internal_reference_type operator[](size_type pos); + + /** + * Assigns a new value to the element at location pos. 
+ */ + void set(size_type pos, const value_type& value) const; + + /** + * Assigns a new value to the element at location pos. + */ + void set(size_type pos, value_type&& value) const; + + /** + * Returns an iterator to the first element of the container. + * If the container is empty, the returned iterator will be equal to end(). + */ + iterator begin() const; + + /** + * Returns an iterator to the element following the last element of the container. + * This element acts as a placeholder; attempting to access it results in undefined behavior. + */ + iterator end() const; + + /** + * Checks if the container has no elements. + */ + bool empty() const; + + /** + * Returns the number of elements in the container + */ + size_type size() const; + + /** + * Increase the capacity of the vector to a value that's greater or equal to new_cap. + */ + void reserve(size_type new_cap) const; + + /** + * Erases all elements from the container. After this call, size() returns zero. + * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated. + */ + void clear() const; + + /** + * Inserts value before pos. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + iterator insert(iterator pos, const T& value) const; + + /** + * Inserts value before pos. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + iterator insert(iterator pos, T&& value) const; + + /** + * Inserts a new element into the container directly before pos. + * The new element is constructed with the given arguments. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + template + iterator emplace(iterator pos, Args&&... 
value) const; + + /** + * Appends the given element value to the end of the container. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void push_back(const T& value) const; + + /** + * Appends the given element value to the end of the container. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void push_back(T&& value) const; + + /** + * Appends the given list to the end of the container. Uses at most one memory allocation. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void append(List lst) const; + + /** + * Appends the given element value to the end of the container. + * The new element is constructed with the given arguments. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + template + void emplace_back(Args&&... args) const; + + /** + * Removes the element at pos. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + iterator erase(iterator pos) const; + + /** + * Removes the elements in the range [first, last). + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + iterator erase(iterator first, iterator last) const; + + /** + * Removes the last element of the container. + * Calling pop_back on an empty container is undefined. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void pop_back() const; + + /** + * Resizes the container to contain count elements. 
+ * If the current size is less than count, additional default-inserted elements are appended. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void resize(size_type count) const; + + /** + * Resizes the container to contain count elements. + * If the current size is less than count, additional copies of value are appended. + * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. + */ + void resize(size_type count, const T& value) const; + + /** + * Value equality comparison. This function implements Python-like semantics for + * equality: two lists with the same identity (e.g. same pointer) trivially + * compare equal, otherwise each element is compared for equality. + */ + template + friend bool operator==(const List& lhs, const List& rhs); + + template + friend bool operator!=(const List& lhs, const List& rhs); + + /** + * Identity comparison. Returns true if and only if `rhs` represents the same + * List object as `this`. + */ + bool is(const List& rhs) const; + + std::vector vec() const; + + /** + * Returns the number of Lists currently pointing to this same list. + * If this is the only instance pointing to this list, returns 1. + */ + // TODO Test use_count + size_t use_count() const; + + TypePtr elementType() const; + + // See [unsafe set type] for why this exists. + void unsafeSetElementType(TypePtr t); + +private: + explicit List(c10::intrusive_ptr&& elements); + explicit List(const c10::intrusive_ptr& elements); + friend struct IValue; + template friend List impl::toTypedList(List); + template friend List impl::toList(List&&); + template friend List impl::toList(const List&); + friend const IValue* impl::ptr_to_first_element(const List& list); +}; + +namespace impl { +// GenericList is how IValue stores lists. It is, however, not part of the +// public API. 
Kernels should use Lists with concrete types instead +// (maybe except for some internal prim ops). +using GenericList = List; + +} +} + +namespace torch { + template using List = c10::List; +} + +#include // IWYU pragma: keep diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List_inl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..96f78faea22d3f57a3f01b8fb46ddaef66e8ca25 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/List_inl.h @@ -0,0 +1,353 @@ +#pragma once + +#include +#include + +namespace c10 { + +template decltype(auto) getTypePtr(); +std::string toString(const Type& type); + +template +List::List(c10::intrusive_ptr&& elements) +: impl_(std::move(elements)) {} + +template +List::List(const c10::intrusive_ptr& elements) +: impl_(elements) {} + +template +List::List() +: List(make_intrusive( + typename c10::detail::ListImpl::list_type(), + getTypePtr())) { + static_assert(!std::is_same_v, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType) instead."); +} + +template +List::List(ArrayRef values) +: List(make_intrusive( + typename c10::detail::ListImpl::list_type(), + getTypePtr())) { + static_assert(!std::is_same_v, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType)."); + impl_->list.reserve(values.size()); + for (const T& element : values) { + impl_->list.push_back(element); + } +} + +template +List::List(std::initializer_list initial_values) +: List(ArrayRef(initial_values)) { + static_assert(!std::is_same_v, "This constructor is not valid for List. 
Please use c10::impl::GenericList(elementType)."); +} + +template +List::List(TypePtr elementType) +: List(make_intrusive( + typename c10::detail::ListImpl::list_type(), + std::move(elementType))) { + static_assert(std::is_same_v || std::is_same_v>, + "This constructor is only valid for c10::impl::GenericList or List."); +} + +namespace impl { +template +List toTypedList(impl::GenericList list) { + // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant + // because upcasting would allow people to add types into the new list that would break the old list. + // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can + // allow upcasting. This can be a perf improvement since we can cast List to List> + // without having to copy it. This is also used to provide backwards compatibility with some old models + // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_ + // as List before we changed that argument to be List>. When deserializing, we + // have list.use_count() == 1 and can deserialize the List directly as List>. + TORCH_CHECK(*list.impl_->elementType == *getTypePtr() + || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr())) + , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr()), ">. 
Types mismatch."); + return List(std::move(list.impl_)); +} + +template +impl::GenericList toList(List&& list) { + return GenericList(std::move(list.impl_)); +} +template +impl::GenericList toList(const List& list) { + return GenericList(list.impl_); +} +} + +template +List List::copy() const { + return List(impl_->copy()); +} + +namespace detail { + template + T list_element_to(T element) { + return element; + } + template + T list_element_to(const IValue& element) { + return element.template to(); + } + template + T list_element_to(IValue&& element) { + return std::move(element).template to(); + } + template + struct ListElementFrom { + static IValue from(const T& element) { + return element; + } + static IValue from(T&& element) { + return std::move(element); + } + }; + template<> + struct ListElementFrom { + static const IValue& from(const IValue& element) { + return element; + } + static IValue&& from(IValue&& element) { + return std::move(element); + } + }; +} + +namespace impl { + +template +ListElementReference::operator std::conditional_t< + std::is_reference_v::type>, + const T&, + T>() const { + return iterator_->template to(); +} + +template +ListElementReference& ListElementReference::operator=(T&& new_value) && { + *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value)); + return *this; +} + +template +ListElementReference& ListElementReference::operator=(const T& new_value) && { + *iterator_ = c10::detail::ListElementFrom::from(new_value); + return *this; +} + +template +ListElementReference& ListElementReference::operator=(ListElementReference&& rhs) && noexcept { + *iterator_ = *rhs.iterator_; + return *this; +} + +template +void swap(ListElementReference&& lhs, ListElementReference&& rhs) noexcept { + std::swap(*lhs.iterator_, *rhs.iterator_); +} + +template +bool operator==(const ListElementReference& lhs, const T& rhs) { + const T& lhs_tmp = lhs; + return lhs_tmp == rhs; +} + +template +inline bool operator==(const T& lhs, const 
ListElementReference& rhs) { + return rhs == lhs; +} + +template +inline typename ListElementConstReferenceTraits::const_reference +list_element_to_const_ref(const IValue& element) { + return element.template to(); +} + +template<> +inline typename ListElementConstReferenceTraits>::const_reference +list_element_to_const_ref>(const IValue& element) { + return element.toOptionalStringRef(); +} + +} // namespace impl + +template +void List::set(size_type pos, const value_type& value) const { + impl_->list.at(pos) = c10::detail::ListElementFrom::from(value); +} + +template +void List::set(size_type pos, value_type&& value) const { + impl_->list.at(pos) = c10::detail::ListElementFrom::from(std::move(value)); +} + +template +typename List::internal_const_reference_type List::get(size_type pos) const { + return operator[](pos); +} + +template +typename List::internal_const_reference_type List::operator[](size_type pos) const { + return c10::impl::list_element_to_const_ref(impl_->list.at(pos)); +} + +template +typename List::internal_reference_type List::operator[](size_type pos) { + static_cast(impl_->list.at(pos)); // Throw the exception if it is out of range. 
+ return {impl_->list.begin() + static_castlist)::difference_type>(pos)}; +} + +template +typename List::value_type List::extract(size_type pos) const { + auto& elem = impl_->list.at(pos); + auto result = c10::detail::list_element_to(std::move(elem)); + // Reset the list element to a T() instead of None to keep it correctly typed + elem = c10::detail::ListElementFrom::from(T{}); + return result; +} + +template +typename List::iterator List::begin() const { + return iterator(impl_->list.begin()); +} + +template +typename List::iterator List::end() const { + return iterator(impl_->list.end()); +} + +template +bool List::empty() const { + return impl_->list.empty(); +} + +template +typename List::size_type List::size() const { + return impl_->list.size(); +} + +template +void List::reserve(size_type new_cap) const { + impl_->list.reserve(new_cap); +} + +template +void List::clear() const { + impl_->list.clear(); +} + +template +typename List::iterator List::insert(iterator pos, const T& value) const { + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(value)) }; +} + +template +typename List::iterator List::insert(iterator pos, T&& value) const { + return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(std::move(value))) }; +} + +template +template +typename List::iterator List::emplace(iterator pos, Args&&... value) const { + // TODO Use list_element_from? + return iterator { impl_->list.emplace(pos.iterator_, std::forward(value)...) 
}; +} + +template +void List::push_back(const T& value) const { + impl_->list.push_back(c10::detail::ListElementFrom::from(value)); +} + +template +void List::push_back(T&& value) const { + impl_->list.push_back(c10::detail::ListElementFrom::from(std::move(value))); +} + +template +void List::append(List b) const { + if (b.use_count() == 1) { + impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end())); + } else { + impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end()); + } +} + +template +template +void List::emplace_back(Args&&... args) const { + // TODO Use list_element_from? + impl_->list.push_back(T(std::forward(args)...)); +} + +template +typename List::iterator List::erase(iterator pos) const { + return iterator { impl_->list.erase(pos.iterator_) }; +} + +template +typename List::iterator List::erase(iterator first, iterator last) const { + return iterator { impl_->list.erase(first.iterator_, last.iterator_) }; +} + +template +void List::pop_back() const { + impl_->list.pop_back(); +} + +template +void List::resize(size_type count) const { + impl_->list.resize(count, T{}); +} + +template +void List::resize(size_type count, const T& value) const { + impl_->list.resize(count, value); +} + +template +bool operator==(const List& lhs, const List& rhs) { + // Lists with the same identity trivially compare equal. + if (lhs.impl_ == rhs.impl_) { + return true; + } + + // Otherwise, just compare values directly. 
+ return *lhs.impl_ == *rhs.impl_; +} + +template +bool operator!=(const List& lhs, const List& rhs) { + return !(lhs == rhs); +} + +template +bool List::is(const List& rhs) const { + return this->impl_ == rhs.impl_; +} + +template +std::vector List::vec() const { + std::vector result(begin(), end()); + return result; +} + +template +size_t List::use_count() const { + return impl_.use_count(); +} + +template +TypePtr List::elementType() const { + return impl_->elementType; +} + +template +void List::unsafeSetElementType(TypePtr t) { + impl_->elementType = std::move(t); +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/MT19937RNGEngine.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/MT19937RNGEngine.h new file mode 100644 index 0000000000000000000000000000000000000000..5bcc4c0c15d967aedc4841052c89d0a18abe6fcd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/MT19937RNGEngine.h @@ -0,0 +1,194 @@ +#pragma once + +#include + +// define constants like M_PI and C keywords for MSVC +#ifdef _MSC_VER +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif +#include +#endif + +#include +#include +#include + +namespace at { + +constexpr int MERSENNE_STATE_N = 624; +constexpr int MERSENNE_STATE_M = 397; +constexpr uint32_t MATRIX_A = 0x9908b0df; +constexpr uint32_t UMASK = 0x80000000; +constexpr uint32_t LMASK = 0x7fffffff; + +/** + * Note [Mt19937 Engine implementation] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Originally implemented in: + * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c + * and modified with C++ constructs. Moreover the state array of the engine + * has been modified to hold 32 bit uints instead of 64 bits. + * + * Note that we reimplemented mt19937 instead of using std::mt19937 because, + * at::mt19937 turns out to be faster in the pytorch codebase. 
PyTorch builds with -O2 + * by default and following are the benchmark numbers (benchmark code can be found at + * https://github.com/syed-ahmed/benchmark-rngs): + * + * with -O2 + * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s + * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s + * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s + * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s + * + * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution, + * however we can't use std::uniform_real_distribution because of this bug: + * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used + * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm + * than what's in pytorch currently and that messes up the tests in tests_distributions.py. + * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower + * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter. + * + * Copyright notice: + * A C-program for MT19937, with initialization improved 2002/2/10. + * Coded by Takuji Nishimura and Makoto Matsumoto. + * This is a faster version by taking Shawn Cokus's optimization, + * Matthe Bellew's simplification, Isaku Wada's real version. + * + * Before using, initialize the state by using init_genrand(seed) + * or init_by_array(init_key, key_length). + * + * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. The names of its contributors may not be used to endorse or promote + * products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * + * Any feedback is very welcome. + * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html + * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space) + */ + +/** + * mt19937_data_pod is used to get POD data in and out + * of mt19937_engine. Used in torch.get_rng_state and + * torch.set_rng_state functions. 
+ */ +struct mt19937_data_pod { + uint64_t seed_; + int left_; + bool seeded_; + uint32_t next_; + std::array state_; +}; + +class mt19937_engine { +public: + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + inline explicit mt19937_engine(uint64_t seed = 5489) { + init_with_uint32(seed); + } + + inline mt19937_data_pod data() const { + return data_; + } + + inline void set_data(const mt19937_data_pod& data) { + data_ = data; + } + + inline uint64_t seed() const { + return data_.seed_; + } + + inline bool is_valid() { + if ((data_.seeded_ == true) + && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N) + && (data_.next_ <= MERSENNE_STATE_N)) { + return true; + } + return false; + } + + inline uint32_t operator()() { + if (--(data_.left_) == 0) { + next_state(); + } + uint32_t y = *(data_.state_.data() + data_.next_++); + y ^= (y >> 11); + y ^= (y << 7) & 0x9d2c5680; + y ^= (y << 15) & 0xefc60000; + y ^= (y >> 18); + + return y; + } + +private: + mt19937_data_pod data_; + + inline void init_with_uint32(uint64_t seed) { + data_.seed_ = seed; + data_.seeded_ = true; + data_.state_[0] = seed & 0xffffffff; + for (const auto j : c10::irange(1, MERSENNE_STATE_N)) { + data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j); + } + data_.left_ = 1; + data_.next_ = 0; + } + + inline uint32_t mix_bits(uint32_t u, uint32_t v) { + return (u & UMASK) | (v & LMASK); + } + + inline uint32_t twist(uint32_t u, uint32_t v) { + return (mix_bits(u,v) >> 1) ^ (v & 1 ? 
MATRIX_A : 0); + } + + inline void next_state() { + uint32_t* p = data_.state_.data(); + data_.left_ = MERSENNE_STATE_N; + data_.next_ = 0; + + for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) { + *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]); + } + + for(int j = MERSENNE_STATE_M; --j; p++) { + *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]); + } + + *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]); + } + +}; + +typedef mt19937_engine mt19937; + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NamedTensor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NamedTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..81998e160185aba0561404e637ff20a2d18d4869 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NamedTensor.h @@ -0,0 +1,143 @@ +#pragma once + +#include +#include + +namespace at { + +class TensorBase; + +// XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen. +// Due to the c10/ATen library split, TensorImpl cannot depend on Dimname, +// so we have a couple of workarounds. +// +// In the long term, we'll move Dimname to c10 and everything in this file +// can be refactored out. The main blocker for that is that "c10::Symbol" +// actually exists outside of c10 and needs to be moved in. + +// TensorImpl has a unique_ptr field. +// XXX: Ideally we would just put std::optional> into TensorImpl. 
+// +// This class has an important invariant: there must be at least ONE +// non-wildcard +struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface { + // This enum is to remind people that the invariant on constructors is that + // the list of dimnames must have at least one non-wildcard + enum HAS_NON_WILDCARD { + HasNonWildcard + }; + + explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names) + : names_(names.vec()) { + check_invariants(); + } + explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector&& names) + : names_(std::move(names)) { + check_invariants(); + } + + std::unique_ptr clone() const override { + return std::make_unique(HasNonWildcard, names_); + } + + DimnameList names() const { return names_; } + + // Used for an assertion in TensorImpl.h + int64_t slow_dim() const override { + return static_cast(names_.size()); + } + + void check_invariants() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); })); + } + + void set_names(HAS_NON_WILDCARD, DimnameList new_names) { + TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); + std::copy(new_names.begin(), new_names.end(), names_.begin()); + check_invariants(); + } + + void set_names(HAS_NON_WILDCARD, std::vector&& new_names) { + TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); + names_ = std::move(new_names); + check_invariants(); + } + + // INVARIANT: at least one Dimname is non-WILDCARD + std::vector names_; +}; + +// When NamesMode is disabled, then all operations ignore tensors' names fields. +// Concretely speaking, all tensors are treated as having nullopt names. +struct TORCH_API NamesMode { + static bool is_enabled(); + static void set_enabled(bool enabled); +}; + + +// A RAII, thread local (!) guard that enables or disables names upon +// construction, and sets it back to the original value upon destruction. 
+struct TORCH_API NoNamesGuard { + NoNamesGuard() : prev_mode(NamesMode::is_enabled()) { + NamesMode::set_enabled(false); + } + NoNamesGuard(const NoNamesGuard&) = delete; + NoNamesGuard(NoNamesGuard&&) = delete; + NoNamesGuard& operator=(const NoNamesGuard&) = delete; + NoNamesGuard& operator=(NoNamesGuard&&) = delete; + ~NoNamesGuard() { + if (initialized) { + reset(); + } + } + void reset() { + TORCH_INTERNAL_ASSERT(initialized); + NamesMode::set_enabled(prev_mode); + } + private: + bool prev_mode; + bool initialized{true}; +}; + +void check_names_valid_for(const TensorBase& tensor, DimnameList names); +void check_names_valid_for(size_t tensor_dim, DimnameList names); + +// Sets the names of `tensor` to be `names`. +TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional names); +TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector&& names, bool validate_names); + +constexpr size_t kMaxNamedTensorDim = 64; + +DimnameList default_names(size_t len); + +namespace impl { + +// Some helper functions on TensorImpl. Useful for working with names in TH. +// XXX: Ideally these would exist as methods on TensorImpl +TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::optional names, bool validate_names); +TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names); + +void check_names_valid_for(TensorImpl* impl, DimnameList names); + +// Returns true if the tensor's names exist and are not all 'None'. +// Returns false if the tensor's names don't exist (were not allocated), +// or if all names are 'None'. +// We treat not-allocated-names the same as allocated names that are all 'None'. +TORCH_API bool has_names(const TensorImpl* impl); + +// Returns the names of the tensor's dimensions. +// Unnamed tensors are treated as having 'None' in all dimension; this method +// would return a DimnameList of all 'None's for an unnamed tensor. 
+TORCH_API DimnameList get_names(const TensorImpl* impl); + +// This is more of an implementation detail; one should use impl::get_names / +// Tensor::names() whenever possible because it provides a cleaner API. +// Returns the names of the tensor if they have been allocated; returns nullopt +// instead if the haven't been. The names of a tensor are not allocated if a +// tensor is constructed with names=None. +TORCH_API std::optional get_opt_names(const TensorImpl* impl); + +} // namespace impl + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..23ae67f25cc17242bad288c60f7bd3e3a2af412a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h @@ -0,0 +1,187 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// The motivating usecase for this is to represent the ragged size structure +// of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This +// allows us to simply return [B, j0, D] if someone queries for the size of our +// tensor. +// +// Morally we define comparison between two nested ints to return true if +// that comparison holds for all corresponding elements of the arrays they +// represent. Comparison between a nested int and a plain int is defined +// similarly. +// +// To simulate this desired behavior but also avoid the O(N) cost of checking, +// we associate each raggedness pattern with an integer "id" that can be used as +// a proxy to evaluate equality. We also constrain the range of values for this +// as to enable inequality checks. +// +// We also support a positive integer scalar "coeff" that is used for computing +// strides. 
For example given, a [B, j0, D] tensor, it can be strided in two +// different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to +// differentiate the two cases. +// +// During tracing the strides of the outputs need to be a function of the size +// and strides of the inputs so it is important that NestedIntSymNode itself is +// able to express this. +class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl { + public: + // CAUTION: you should probably not be constructing these directly; please + // the higher-level API in python instead (TODO: actually introduce that). + explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff) + : val_(val), coeff_(coeff) {} + + bool bool_() override { + return false; + } + + bool is_int() override { + return true; + } + + bool is_float() override { + return false; + } + + bool is_bool() override { + return false; + } + + bool is_nested_int() const override { + return true; + } + + bool has_hint() override { + return true; + } + + c10::SymNode wrap_int(int64_t num) override { + return SymNode(c10::make_intrusive>(num)); + } + + int64_t guard_int(const char* file, int64_t line) override { + TORCH_CHECK(false); + } + + double guard_float(const char* file, int64_t line) override { + TORCH_CHECK(false, "not a float"); + } + + bool guard_bool(const char* file, int64_t line) override { + TORCH_CHECK(false, "not a bool"); + } + + int64_t int_() override { + TORCH_CHECK(false); + } + + std::string str() override { + if (coeff_ == 1) { + return "j" + std::to_string(val_); + } + return std::to_string(coeff_) + "*j" + std::to_string(val_); + } + + // NOTE [ Inequalities with nested int ] + // + // The semantics of nested int when it comes to relations is that it is + // treated as integer known to be within a certain range, + // + // j0 \in [2, int64_t::max] + // + // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False). 
+ // This is a useful default range for the raggedness pattern of a jagged + // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1 + // specialization checks. + // + // [ Indeterminate inequalities error out ] + // + // Given the semantic defined above, certain relations like j0 < 3 are thus + // indeterminable. In our impl today, evaluating such relations error + // + // It may seem convenient to just define indeterminate relations to return + // False, but the implementation we maintain in parallel using sympy does not + // allow this. + // + // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are, + // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This + // would mean that means that if we define the indeterminate j0 >= 3 to be + // False, the also indeterminate j0 < 3 will be evaluated to be True! + // + // [ Coefficient are assumed positive ] + // + // For the purpose of computing inequalities, we consider the coefficient of + // the nested int to be a positive integer. 
+ // + // Thus, no modifications are needed to the logic since + // j0 >= k implies coeff * j0 >= k + // + c10::SymNode eq(const c10::SymNode& other) override; + c10::SymNode ne(const c10::SymNode& other) override; + c10::SymNode ge(const c10::SymNode& other) override; + c10::SymNode gt(const c10::SymNode& other) override; + c10::SymNode lt(const c10::SymNode& other) override; + c10::SymNode le(const c10::SymNode& other) override; + c10::SymNode mul(const c10::SymNode& other) override; + + std::optional nested_int() override { + return val_; + } + + std::optional nested_int_coeff() override { + return coeff_; + } + + bool is_symbolic() override { + return false; + } + + c10::SymNode clone() override; + +#define DEFINE_BINARY_NOT_SUPPORTED(name) \ + c10::SymNode name(const c10::SymNode& other) override { \ + TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \ + } + + DEFINE_BINARY_NOT_SUPPORTED(add) + DEFINE_BINARY_NOT_SUPPORTED(sub) + DEFINE_BINARY_NOT_SUPPORTED(truediv) + DEFINE_BINARY_NOT_SUPPORTED(pow) + DEFINE_BINARY_NOT_SUPPORTED(floordiv) + DEFINE_BINARY_NOT_SUPPORTED(mod) + DEFINE_BINARY_NOT_SUPPORTED(sym_min) + DEFINE_BINARY_NOT_SUPPORTED(sym_max) + DEFINE_BINARY_NOT_SUPPORTED(sym_and) + DEFINE_BINARY_NOT_SUPPORTED(sym_or) + +#undef DEFINE_BINARY_NOT_SUPPORTED + +#define DEFINE_NOT_SUPPORTED(name) \ + c10::SymNode name() override { \ + TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \ + } + + DEFINE_NOT_SUPPORTED(sym_not) + DEFINE_NOT_SUPPORTED(ceil) + DEFINE_NOT_SUPPORTED(floor) + DEFINE_NOT_SUPPORTED(neg) + DEFINE_NOT_SUPPORTED(sym_float) + +#undef DEFINE_NOT_SUPPORTED + + private: + int64_t val_; + int64_t coeff_; +}; + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h new file mode 100644 index 0000000000000000000000000000000000000000..413055d3fad6555582e9b7ccf80e38387a52d535 --- 
/dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h @@ -0,0 +1,240 @@ +#pragma once + +// define constants like M_PI and C keywords for MSVC +#ifdef _MSC_VER +#define _USE_MATH_DEFINES +#include +#endif + + +#ifdef __CUDACC__ +#include +#endif + +#include +#include +#include +#include + +namespace at { + +// typedefs for holding vector data +namespace detail { + +typedef std::array UINT4; +typedef std::array UINT2; +typedef std::array DOUBLE2; +typedef std::array FLOAT2; + +} // namespace detail + +/** + * Note [Philox Engine implementation] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Originally implemented in PyTorch's fusion compiler + * Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf + * for details regarding the engine. + * + * Note that currently this implementation of the philox engine is not used + * anywhere except for tests in cpu_generator_test.cpp. However, this engine + * will replace curandStatePhilox4_32_10_t in the future. + * + * The philox engine takes a seed value, a subsequeunce + * for starting the generation and an offset for the subsequence. + * Think of this engine as an algorithm producing a huge array. We are + * parallelizing this array by partitioning the huge array and assigning + * a thread index to each partition. In other words, each seed value + * (there are 2^64 possible seed values) gives a sub array of size + * 2^128 (each element in that array is a 128 bit number). Reasoning + * behind the array being of size 2^128 is, there are 2^64 possible + * thread index value and there is an array of size 2^64 for each of + * those thread index. Hence 2^64 * 2^64 = 2^128 for each seed value. + * + * In short, this generator can produce 2^64 (seed values) * 2^128 (number + * of elements in an array given by a seed value) = 2^192 values. + * + * Arguments: + * seed: Seed values could be any number from 0 to 2^64-1. 
+ * subsequence: Subsequence is just the cuda thread indexing with: + * - blockIdx.x * blockDim.x + threadIdx.x + * offset: The offset variable in PhiloxEngine decides how many 128-bit + * random numbers to skip (i.e. how many groups of 4, 32-bit numbers to skip) + * and hence really decides the total number of randoms that can be achieved + * for the given subsequence. + */ + +class philox_engine { +public: + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + C10_HOST_DEVICE inline explicit philox_engine(uint64_t seed = 67280421310721, + uint64_t subsequence = 0, + uint64_t offset = 0) { + + reset_state(seed, subsequence); + incr_n(offset); + } + + C10_HOST_DEVICE inline void reset_state(uint64_t seed = 67280421310721, + uint64_t subsequence = 0) { + key_[0] = static_cast(seed); + key_[1] = static_cast(seed >> 32); + counter_ = detail::UINT4{}; + counter_[2] = static_cast(subsequence); + counter_[3] = static_cast(subsequence >> 32); + STATE = 0; + } + + /** + * Set the offset field of Philox Generator to the desired offset. + */ + C10_HOST_DEVICE inline void set_offset(uint64_t offset) { + counter_[0] = static_cast(offset); + counter_[1] = static_cast(offset >> 32); + } + + /** + * Gets the current offset of the Philox Generator. + */ + C10_HOST_DEVICE uint64_t get_offset() const { + uint64_t lo = static_cast(counter_[0]); + uint64_t hi = static_cast(counter_[1]) << 32; + return lo | hi; + } + + /** + * Produces a unique 32-bit pseudo random number on every invocation. Bookeeps state to avoid waste. 
+ */ + C10_HOST_DEVICE inline uint32_t operator()(int32_t n_rounds = 10) { // 10 here to preserve back-compat behavior + if(STATE == 0) { + detail::UINT4 counter = counter_; + detail::UINT2 key = key_; + output_ = rand(counter, key, n_rounds); + incr(); + } + uint32_t ret = output_[static_cast(STATE)]; + STATE = (STATE + 1) & 3; + return ret; + } + + inline float randn(uint32_t n_rounds) { + #ifdef __CUDA_ARCH__ + AT_ASSERT(false, "Unsupported invocation of randn on CUDA"); + #endif + if(STATE == 0) { + detail::UINT4 counter = counter_; + detail::UINT2 key = key_; + output_ = rand(counter, key, n_rounds); + incr(); + } + // TODO(min-jean-cho) change to Polar method, a more efficient version of Box-Muller method + // TODO(voz) We use std:: below, and thus need a separate impl for CUDA. + float u1 = 1 - uint32_to_uniform_float(output_[0]); // uint32_to_uniform_float returns [0,1), we need (0,1] to avoid passing 0 to log. + float u2 = 1 - uint32_to_uniform_float(output_[1]); + return static_cast(std::sqrt(-2.0 * std::log(u1)) * std::cos(2.0 * M_PI * u2)); + } + + /** + * Function that Skips N 128 bit numbers in a subsequence + */ + C10_HOST_DEVICE inline void incr_n(uint64_t n) { + uint32_t nlo = static_cast(n); + uint32_t nhi = static_cast(n >> 32); + counter_[0] += nlo; + // if overflow in x has occurred, carry over to nhi + if (counter_[0] < nlo) { + nhi++; + // if overflow in nhi has occurred during carry over, + // propagate that overflow to y and exit to increment z + // otherwise return + counter_[1] += nhi; + if(nhi != 0) { + if (nhi <= counter_[1]) { + return; + } + } + } else { + // if overflow in y has occurred during addition, + // exit to increment z + // otherwise return + counter_[1] += nhi; + if (nhi <= counter_[1]) { + return; + } + } + if (++counter_[2]) + return; + ++counter_[3]; + } + + /** + * Function that Skips one 128 bit number in a subsequence + */ + C10_HOST_DEVICE inline void incr() { + if (++counter_[0]) + return; + if (++counter_[1]) + 
return; + if (++counter_[2]) { + return; + } + ++counter_[3]; + } + +private: + detail::UINT4 counter_; + detail::UINT4 output_; + detail::UINT2 key_; + uint32_t STATE; + + C10_HOST_DEVICE inline uint32_t mulhilo32(uint32_t a, uint32_t b, + uint32_t *result_high) { + #ifdef __CUDA_ARCH__ + *result_high = __umulhi(a, b); + return a*b; + #else + const uint64_t product = static_cast(a) * b; + *result_high = static_cast(product >> 32); + return static_cast(product); + #endif + } + + C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) { + uint32_t hi0 = 0; + uint32_t hi1 = 0; + uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0); + uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1); + detail::UINT4 ret; + ret[0] = hi1 ^ ctr[1] ^ in_key[0]; + ret[1] = lo1; + ret[2] = hi0 ^ ctr[3] ^ in_key[1]; + ret[3] = lo0; + return ret; + } + + C10_HOST_DEVICE constexpr float uint32_to_uniform_float(uint32_t value) { + // maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + constexpr float scale = 4.6566127342e-10; + return static_cast(value & 0x7FFFFFFF) * scale; + } + + + + C10_HOST_DEVICE inline detail::UINT4 rand(detail::UINT4& counter, detail::UINT2& key, uint32_t n_rounds) { + for (uint32_t round = 0; round < (n_rounds - 1); round++) { + counter = single_round(counter, key); + key[0] += (kPhilox10A); key[1] += (kPhilox10B); + } + return single_round(counter, key); + } + + + static const uint32_t kPhilox10A = 0x9E3779B9; + static const uint32_t kPhilox10B = 0xBB67AE85; + static const uint32_t kPhiloxSA = 0xD2511F53; + static const uint32_t kPhiloxSB = 0xCD9E8D57; +}; + +typedef philox_engine Philox4_32; + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonFallbackKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonFallbackKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1d2b613166d3f308ea3114b0726a37ad113460a8 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonFallbackKernel.h @@ -0,0 +1,35 @@ +#pragma once +#include + + +namespace at::impl { + +struct TORCH_API RestorePythonTLSSnapshot { + RestorePythonTLSSnapshot(); + RestorePythonTLSSnapshot(RestorePythonTLSSnapshot&& other) = delete; + RestorePythonTLSSnapshot(const RestorePythonTLSSnapshot&) = delete; + RestorePythonTLSSnapshot& operator=(const RestorePythonTLSSnapshot&) = delete; + RestorePythonTLSSnapshot& operator=(RestorePythonTLSSnapshot&&) = delete; + ~RestorePythonTLSSnapshot(); + +private: + c10::impl::LocalDispatchKeySet saved_; + c10::impl::ForceDispatchKeyGuard guard_; +}; + + +// RAII guard to make working with the above TLS safer. +struct TORCH_API MaybeSetTLSOnEntryGuard { +public: + MaybeSetTLSOnEntryGuard(); + MaybeSetTLSOnEntryGuard(MaybeSetTLSOnEntryGuard&& other) = delete; + MaybeSetTLSOnEntryGuard(const MaybeSetTLSOnEntryGuard&) = delete; + MaybeSetTLSOnEntryGuard& operator=(const MaybeSetTLSOnEntryGuard&) = delete; + MaybeSetTLSOnEntryGuard& operator=(MaybeSetTLSOnEntryGuard&&) = delete; + ~MaybeSetTLSOnEntryGuard(); + +private: + bool value_set_; +}; + +} // namespace at::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h new file mode 100644 index 0000000000000000000000000000000000000000..bec323c7d25bf0143dae83edc3ac1da3efd8eb48 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +// TODO: this can probably live in c10 + + +namespace at::impl { + +class TORCH_API PythonOpRegistrationTrampoline final { + static std::atomic interpreter_; + +public: + // Returns true if you successfully registered yourself (that means + // you are in the hot seat for doing the operator registrations!) 
+ static bool registerInterpreter(c10::impl::PyInterpreter*); + + // Returns nullptr if no interpreter has been registered yet. + static c10::impl::PyInterpreter* getInterpreter(); +}; + +} // namespace at::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/QuantizerBase.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/QuantizerBase.h new file mode 100644 index 0000000000000000000000000000000000000000..a56ead7a30c69623f4c893a1c237a3f4fb617212 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/QuantizerBase.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include + +namespace at { + +class Tensor; +struct QTensorImpl; +struct Quantizer; +using ConstQuantizerPtr = const c10::intrusive_ptr&; +using QuantizerPtr = c10::intrusive_ptr; + +/** + * Quantizer is the class for storing all the information + * that's necessary to perform quantize and dequantize + * operation. + * + * We might have different types of quantization schemes and this is + * the base class for all quantizers. + * + * QTensorImpl will hold a pointer to Quantizer so that we can support + * different quantization schemes on Tensor. + * + * For example, the most common quantization scheme, Affine Quantization, + * requires scale and zero_point as parameters, we'll store scale and zero_point + * inside the instance and we can use it to quantize a float Tensor or + * dequantize a quantized Tensor. + * + * When you add new types of leaf Quantizer class, please also + * make sure to add a corresponding QScheme enum since + * they should have one to one mapping. + * + * Note about intrusive_ptr: + * Quantized Tensor holds an intrusive_ptr to Quantizer, and multiple Tensor can + * share the same Quantizer. Quantizer should be immutable. 
+ */ +struct TORCH_API Quantizer : public c10::intrusive_ptr_target { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const ScalarType scalar_type_; + explicit Quantizer(ScalarType scalar_type) : scalar_type_(scalar_type) {} + ~Quantizer() override = default; + + // Copied from torch/csrc/jit/ir/scope.h + QuantizerPtr intrusive_from_this() { + c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer + // from a raw `this` pointer + // so we need to bump the refcount + // to account for this ownership + return c10::intrusive_ptr::reclaim(this); + } + + /** + * Each concrete Quantizer type should have a unique QScheme type. + */ + virtual QScheme qscheme() const = 0; + + ScalarType scalar_type() const { + return scalar_type_; + } + + /** + * quantize a float Tensor into a quantized Tensor. + */ + virtual Tensor quantize(const Tensor& t) = 0; + + /** + * dequantize a quantized Tensor into a float Tensor. + */ + virtual Tensor dequantize(const Tensor& t) = 0; + + /** + * dequantize a quantized Tensor into a float Tensor, out= variant + */ + virtual Tensor& dequantize_out(Tensor& out, const Tensor& t) = 0; + + /** + * Compare against `other` for equality. 
+ */ + virtual bool equalTo(QuantizerPtr other) const = 0; +}; + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Range.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Range.h new file mode 100644 index 0000000000000000000000000000000000000000..2bf6b2b73ac4d4c178ac0388e9b45e262e506b86 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Range.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +namespace at { + +struct Range { + Range(int64_t begin, int64_t end) + : begin(begin) + , end(end) {} + + int64_t size() const { return end - begin; } + + Range operator/(int64_t divisor) { + return Range(begin / divisor, end / divisor); + } + + int64_t begin; + int64_t end; +}; + +std::ostream& operator<<(std::ostream& out, const Range& range); + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Reduction.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..340e9f91ae8f7fd10919148d2632149f5642fd7d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Reduction.h @@ -0,0 +1,14 @@ +#pragma once + +namespace at::Reduction { + +// NB: Keep this in sync with Reduction class in torch/nn/_reduction.py +// These constants control the reduction behavior of loss functions. 
+// Ideally, this would be a scoped enum, but jit doesn't support that +enum Reduction { + None, // Do not reduce + Mean, // (Possibly weighted) mean of losses + Sum, // Sum losses + END +}; +} // namespace at::Reduction diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Scalar.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Scalar.h new file mode 100644 index 0000000000000000000000000000000000000000..a14b48f0120cbfc23c45db14dd363b0c88c59a2c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Scalar.h @@ -0,0 +1 @@ +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ScalarType.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ScalarType.h new file mode 100644 index 0000000000000000000000000000000000000000..eb30ee86f737a3544e727c42f72147fdb64a5e3b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ScalarType.h @@ -0,0 +1 @@ +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Tensor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..96ef0ee4d8636c4e2bc4b163dc7db235ab5fac2d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Tensor.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include + +namespace at { +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class TORCH_API OptionalTensorRef { + public: + OptionalTensorRef() = default; + + ~OptionalTensorRef() { + ref_.unsafeReleaseTensorImpl(); + } + + OptionalTensorRef(const TensorBase& src) + : ref_(Tensor::unsafe_borrow_t{}, src) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined()); + } + + OptionalTensorRef(const OptionalTensorRef& rhs) + : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {} + + OptionalTensorRef(OptionalTensorRef&& rhs) = default; + OptionalTensorRef& operator=(OptionalTensorRef rhs) { + std::swap(ref_, rhs.ref_); + 
return *this; + } + + bool has_value() const { + return ref_.defined(); + } + + const Tensor& getTensorRef() const & { + return ref_; + } + + const Tensor& operator*() const & { + return ref_; + } + + const Tensor* operator->() const & { + return &ref_; + } + + operator bool() const { + return ref_.defined(); + } + + private: + Tensor ref_; +}; + +// Use to convert a TensorBase (that may be undefined) to an at::Tensor +// without bumping refcount. +class TORCH_API TensorRef { + public: + ~TensorRef() { + ref_.unsafeReleaseTensorImpl(); + } + + TensorRef(const TensorBase& src) + : ref_(Tensor::unsafe_borrow_t{}, src) {} + TensorRef(TensorRef&& other) = default; + TensorRef(const TensorRef&) = default; + TensorRef& operator=(const TensorRef&) = default; + TensorRef& operator=(TensorRef&&) = default; + + const Tensor& operator*() const & { + return ref_; + } + private: + Tensor ref_; +}; + +template +auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t { + // Return the grad argument in case of a hook with void return type to have an + // std::function with Tensor return type + static_assert(std::is_same_v, + "Expected hook to return void"); + return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) { + TensorRef grad(grad_base); + fn(*grad); + return Tensor(); + }); +} + +template +auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t { + return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) { + TensorRef grad(grad_base); + Tensor ret = fn(*grad); + return TensorBase(std::move(ret)); + }); +} + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorAccessor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..8cf57d2b646fe134d93aadd83bd96380b1242bb9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorAccessor.h @@ -0,0 
+1,275 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor +// is used to enable the __restrict__ keyword/modifier for the data +// passed to cuda. +template +struct DefaultPtrTraits { + typedef T* PtrType; +}; + +#if defined(__CUDACC__) || defined(__HIPCC__) +template +struct RestrictPtrTraits { + typedef T* __restrict__ PtrType; +}; +#endif + +// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors. +// For CUDA tensors it is used in device code (only). This means that we restrict ourselves +// to functions and types available there (e.g. IntArrayRef isn't). + +// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers. +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class TensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + + C10_HOST_DEVICE TensorAccessorBase( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : data_(data_), sizes_(sizes_), strides_(strides_) {} + C10_HOST IntArrayRef sizes() const { + return IntArrayRef(sizes_,N); + } + C10_HOST IntArrayRef strides() const { + return IntArrayRef(strides_,N); + } + C10_HOST_DEVICE index_t stride(index_t i) const { + return strides_[i]; + } + C10_HOST_DEVICE index_t size(index_t i) const { + return sizes_[i]; + } + C10_HOST_DEVICE PtrType data() { + return data_; + } + C10_HOST_DEVICE const PtrType data() const { + return data_; + } +protected: + PtrType data_; + const index_t* sizes_; + const index_t* strides_; +}; + +// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using +// `Tensor.accessor()`. +// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only +// indexing on the device uses `TensorAccessor`s. 
+template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class TensorAccessor : public TensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + + C10_HOST_DEVICE TensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : TensorAccessorBase(data_,sizes_,strides_) {} + + C10_HOST_DEVICE TensorAccessor operator[](index_t i) { + return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); + } + + C10_HOST_DEVICE const TensorAccessor operator[](index_t i) const { + return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); + } +}; + +template class PtrTraits, typename index_t> +class TensorAccessor : public TensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + + C10_HOST_DEVICE TensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : TensorAccessorBase(data_,sizes_,strides_) {} + C10_HOST_DEVICE T & operator[](index_t i) { + // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) + return this->data_[this->strides_[0]*i]; + } + C10_HOST_DEVICE const T & operator[](index_t i) const { + return this->data_[this->strides_[0]*i]; + } +}; + + +// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host +// and as +// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host) +// in order to transfer them on the device when calling kernels. +// On the device, indexing of multidimensional tensors gives to `TensorAccessor`s. +// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__. +// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available +// on the device, so those functions are host only. 
+template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class GenericPackedTensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + C10_HOST GenericPackedTensorAccessorBase( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : data_(data_) { + std::copy(sizes_, sizes_ + N, std::begin(this->sizes_)); + std::copy(strides_, strides_ + N, std::begin(this->strides_)); + } + + // if index_t is not int64_t, we want to have an int64_t constructor + template >> + C10_HOST GenericPackedTensorAccessorBase( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : data_(data_) { + for (const auto i : c10::irange(N)) { + this->sizes_[i] = sizes_[i]; + this->strides_[i] = strides_[i]; + } + } + + C10_HOST_DEVICE index_t stride(index_t i) const { + return strides_[i]; + } + C10_HOST_DEVICE index_t size(index_t i) const { + return sizes_[i]; + } + C10_HOST_DEVICE PtrType data() { + return data_; + } + C10_HOST_DEVICE const PtrType data() const { + return data_; + } +protected: + PtrType data_; + // NOLINTNEXTLINE(*c-arrays*) + index_t sizes_[N]; + // NOLINTNEXTLINE(*c-arrays*) + index_t strides_[N]; + C10_HOST void bounds_check_(index_t i) const { + TORCH_CHECK_INDEX( + 0 <= i && i < index_t{N}, + "Index ", + i, + " is not within bounds of a tensor of dimension ", + N); + } +}; + +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} + + // if index_t is not int64_t, we want to have an int64_t constructor + template >> + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : 
GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} + + C10_DEVICE TensorAccessor operator[](index_t i) { + index_t* new_sizes = this->sizes_ + 1; + index_t* new_strides = this->strides_ + 1; + return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); + } + + C10_DEVICE const TensorAccessor operator[](index_t i) const { + const index_t* new_sizes = this->sizes_ + 1; + const index_t* new_strides = this->strides_ + 1; + return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); + } + + /// Returns a PackedTensorAccessor of the same dimension after transposing the + /// two dimensions given. Does not actually move elements; transposition is + /// made by permuting the size/stride arrays. If the dimensions are not valid, + /// asserts. + C10_HOST GenericPackedTensorAccessor transpose( + index_t dim1, + index_t dim2) const { + this->bounds_check_(dim1); + this->bounds_check_(dim2); + GenericPackedTensorAccessor result( + this->data_, this->sizes_, this->strides_); + std::swap(result.strides_[dim1], result.strides_[dim2]); + std::swap(result.sizes_[dim1], result.sizes_[dim2]); + return result; + } +}; + +template class PtrTraits, typename index_t> +class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { +public: + typedef typename PtrTraits::PtrType PtrType; + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const index_t* sizes_, + const index_t* strides_) + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} + + // if index_t is not int64_t, we want to have an int64_t constructor + template >> + C10_HOST GenericPackedTensorAccessor( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} + + C10_DEVICE T & operator[](index_t i) { + return this->data_[this->strides_[0] * i]; + } + C10_DEVICE const T& operator[](index_t i) const { + return this->data_[this->strides_[0]*i]; + } + + // Same 
as in the general N-dimensional case, but note that in the + // 1-dimensional case the returned PackedTensorAccessor will always be an + // identical copy of the original + C10_HOST GenericPackedTensorAccessor transpose( + index_t dim1, + index_t dim2) const { + this->bounds_check_(dim1); + this->bounds_check_(dim2); + return GenericPackedTensorAccessor( + this->data_, this->sizes_, this->strides_); + } +}; + + +// Can't put this directly into the macro function args because of commas +#define AT_X GenericPackedTensorAccessor + +// Old name for `GenericPackedTensorAccessor` +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X) + +#undef AT_X + +template class PtrTraits = DefaultPtrTraits> +using PackedTensorAccessor32 = GenericPackedTensorAccessor; + +template class PtrTraits = DefaultPtrTraits> +using PackedTensorAccessor64 = GenericPackedTensorAccessor; +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBase.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBase.h new file mode 100644 index 0000000000000000000000000000000000000000..3dc2ec66f8168eb156bcb15e84d2e56adafe2b83 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBase.h @@ -0,0 +1,1056 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { +class Scalar; +} + +namespace torch::autograd { + +struct Node; + +} // namespace torch::autograd + +namespace at { + +class Tensor; +class TensorBase; + +// Convert Tensor to TensorBase without any need to include Tensor.h +TORCH_API const TensorBase& get_tensor_base(const Tensor& t); + +namespace impl { +inline bool variable_excluded_from_dispatch() { +#ifdef C10_MOBILE + // Please read the 
comment in `VariableFallbackKernel.cpp` about the background of this change. + return true; +#else + return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset); +#endif +} + +} + +// NOTE: [Tensor vs. TensorBase] +// +// Tensor, being the central data structure in PyTorch, gets used and +// its header included almost everywhere. Unfortunately this means +// every time an operator signature is updated or changed in +// native_functions.yaml, you (and every other PyTorch developer) need +// to recompile all of ATen and its dependencies. +// +// TensorBase aims to break up these header dependencies, and improve +// incremental build times for all PyTorch developers. TensorBase +// represents a reference counted handle to TensorImpl, exactly the +// same as Tensor. However, TensorBase doesn't have code generated +// methods in its API and thus no dependence on native_functions.yaml. +// +// Usage tips +// ---------- +// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp +// or .cu file to ensure it has no header dependencies on +// native_functions.yaml (direct or indirect). +// - Tensor inherits from TensorBase, so functions taking +// `const TensorBase &` are callable with Tensor as well. +// - TensorBase can be converted to Tensor with `Tensor(tensor_base)`, +// but this requires a reference-count bump. OptionalTensorRef, on +// the other hand, can materialize a `const Tensor &` without +// touching the reference-count. +class TORCH_API TensorBase { + public: + struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; }; + + protected: + // Create a Tensor with a +0 reference count. Special care must be + // taken to avoid decrementing this reference count at destruction + // time. Intended to support MaybeOwnedTraits. 
+ explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs) + : impl_(c10::intrusive_ptr(rhs.impl_.get(), c10::raw::DontIncreaseRefcount{})) {} + friend MaybeOwnedTraits; + + public: + TensorBase() = default; + // This constructor should not be used by end users and is an implementation + // detail invoked by autogenerated code. + explicit TensorBase( + c10::intrusive_ptr tensor_impl) + : impl_(std::move(tensor_impl)) { + if (impl_.get() == nullptr) { + throw std::runtime_error("TensorImpl with nullptr is not supported"); + } + } + TensorBase(const TensorBase&) = default; + TensorBase(TensorBase&&) noexcept = default; + ~TensorBase() noexcept = default; + + public: + // Creates a new wrapper from TensorImpl. Intentionally a free method because + // it should be used with care. Checks necessary invariants + static TensorBase wrap_tensor_impl( + c10::intrusive_ptr tensor_impl) { + TensorBase r(std::move(tensor_impl)); + r.enforce_invariants(); + return r; + } + + int64_t dim() const { + return impl_->dim(); + } + int64_t storage_offset() const { + return impl_->storage_offset(); + } + + TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { + if (is_contiguous(memory_format)) { + return *this; + } else { + return __dispatch_contiguous(memory_format); + } + } + + /// Should be used if *this can reasonably be expected to be contiguous and + /// performance is important. + /// Compared to contiguous, it saves a reference count + /// increment/decrement if *this is already contiguous, at the cost + /// in all cases of an extra pointer of stack usage, an extra branch + /// to access, and an extra branch at destruction time. + c10::MaybeOwned expect_contiguous( + MemoryFormat memory_format=MemoryFormat::Contiguous) const &; + + // Use .contiguous() instead. Trying to borrow from a prvalue + // will only lead to trouble and dangling references. 
+ c10::MaybeOwned expect_contiguous( + MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete; + + const TensorBase& fill_(const c10::Scalar& scalar) const; + const TensorBase& zero_() const; + + TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional memory_format=std::nullopt) const; + + bool is_complex() const { + return at::isComplexType(this->scalar_type()); + } + + bool is_floating_point() const { + return at::isFloatingType(this->scalar_type()); + } + + bool is_signed() const { + return at::isSignedType(this->scalar_type()); + } + + c10::SymInt sym_size(int64_t dim) const { + return impl_->sym_size(dim); + } + + c10::SymInt sym_stride(int64_t dim) const { + const auto sizes = this->sym_strides(); + const auto ndim = static_cast(sizes.size()); + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)]; + + } + + int64_t size(int64_t dim) const { + return impl_->size(dim); + } + + int64_t stride(int64_t dim) const { + const auto strides = this->strides(); + const auto ndim = static_cast(strides.size()); + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)]; + } + + TensorImpl * unsafeGetTensorImpl() const { + return impl_.get(); + } + TensorImpl * unsafeReleaseTensorImpl() { + return impl_.release(); + } + const c10::intrusive_ptr& getIntrusivePtr() const { + return impl_; + } + + c10::intrusive_ptr unsafeReleaseIntrusivePtr() { + return std::move(impl_); + } + + bool defined() const { + return impl_; + } + + void reset() { + impl_.reset(); + } + +#if defined (_MSC_VER) + TensorBase& operator=(const TensorBase& x) & { + impl_ = x.impl_; + return *this; + }; + TensorBase& operator=(TensorBase&& x) & noexcept { + impl_ = std::move(x.impl_); + return *this; + } +#else + TensorBase& 
operator=(const TensorBase& x) & = default; + TensorBase& operator=(TensorBase&& x) & noexcept = default; +#endif + + // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here + TensorBase& operator=(const TensorBase&) && = delete; + TensorBase& operator=(TensorBase&&) && noexcept = delete; + + bool is_same(const TensorBase& other) const noexcept { + return impl_ == other.impl_; + } + size_t use_count() const noexcept { + return impl_.use_count(); + } + size_t weak_use_count() const noexcept { + return impl_.weak_use_count(); + } + + std::string toString() const; + + IntArrayRef sizes() const { + return impl_->sizes(); + } + c10::SymIntArrayRef sym_sizes() const { + return impl_->sym_sizes(); + } + c10::SymIntArrayRef sym_strides() const { + return impl_->sym_strides(); + } + IntArrayRef strides() const { + return impl_->strides(); + } + // See impl::get_opt_names in ATen/NamedTensor.h for docs. + std::optional opt_names() const { + return impl::get_opt_names(unsafeGetTensorImpl()); + } + // See impl::get_names in ATen/NamedTensor.h for docs. + DimnameList names() const { + return impl::get_names(unsafeGetTensorImpl()); + } + int64_t ndimension() const { + return dim(); + } + + bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { + return impl_->is_contiguous(memory_format); + } + + bool is_non_overlapping_and_dense() const { + return impl_->is_non_overlapping_and_dense(); + } + + at::MemoryFormat suggest_memory_format( + bool channels_last_strides_exact_match = false) const { + // Setting channels_last_strides_exact_match to true forces function to + // check 0,1 - sized dimension strides. 
+ if (layout() == at::kStrided) { + if (impl_->is_strides_like_channels_last()) { + if (!channels_last_strides_exact_match || + get_channels_last_strides_2d(sizes()) == strides()) { + return at::MemoryFormat::ChannelsLast; + } + } + else if (impl_->is_strides_like_channels_last_3d()) { + if (!channels_last_strides_exact_match || + get_channels_last_strides_3d(sizes()) == strides()) { + return at::MemoryFormat::ChannelsLast3d; + } + } + } + return at::MemoryFormat::Contiguous; + } + + // Total bytes consumed by the "view" of elements of the array. Does not + // include size of metadata. The number reported here does not necessarily + // correspond to the true physical memory consumed by a tensor; instead, + // it reports the memory the tensor would take *if* it were contiguous. + // Defined to be numel() * itemsize() + size_t nbytes() const { + TORCH_CHECK(layout () != at::kSparse, + "nbytes is not defined for sparse tensors. If you want the size of the constituent " \ + "tensors, add the nbytes of the indices and values. If you want the size of the " \ + "equivalent dense tensor, multiply numel() by element_size()"); + return impl_->numel() * impl_->itemsize(); + } + + c10::SymInt sym_nbytes() const { + TORCH_CHECK(layout () != at::kSparse, + "nbytes is not defined for sparse tensors. If you want the size of the constituent " \ + "tensors, add the nbytes of the indices and values. If you want the size of the " \ + "equivalent dense tensor, multiply numel() by element_size()"); + return impl_->sym_numel() * impl_->itemsize(); + } + + int64_t numel() const { + return impl_->numel(); + } + + c10::SymInt sym_numel() const { + return impl_->sym_numel(); + } + + c10::SymInt sym_storage_offset() const { + return impl_->sym_storage_offset(); + } + + // Length of one array element in bytes. This is the traditional + // Numpy naming. + size_t itemsize() const { + return impl_->itemsize(); + } + + // Same as itemsize(). This is the PyTorch naming. 
+ int64_t element_size() const { + return static_cast(impl_->itemsize()); + } + + DispatchKeySet key_set() const { + return impl_->key_set(); + } + ScalarType scalar_type() const { + return typeMetaToScalarType(impl_->dtype()); + } + bool has_storage() const { + return defined() && impl_->has_storage(); + } + const Storage& storage() const { + return impl_->storage(); + } + bool is_alias_of(const at::TensorBase& other) const{ + return impl_->storage().is_alias_of(other.storage()); + } + + // Move the storage backend to shm based + // to enable memory sharing across processes. + // + // NB1: the ideal behavior of this API still requires further discussion + // but for now we are inclined to keep it consistent with existing THP behavior + // https://github.com/pytorch/pytorch/blob/4dca9bde0552afc67b5b74f4a0696fe6055709c4/torch/storage.py#L196-L212 + // so we don't assert on anything here and rely on caller knowing + // what it's doing. + // + // NB2: this currently provides Linux fd based shm support only + // to simplify the storage lifetime management logic in ATen + // and similarly for now we are not adding support for file system based + // shm support like in THP due to additional GC manager support needed + // to prevent leaks. + // As such, calling this from non supported systems (e.g. Windows) would fail. + void share_memory_() { + at::share_memory_(*this); + } + + inline bool _is_zerotensor() const { + return impl_->_is_zerotensor(); + } + + inline void _set_zero(bool zero) const { + impl_->_set_zero(zero); + } + + inline bool is_conj() const { + return impl_->is_conj(); + } + + // sets the conjugate bit of a tensor. + // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure + // that's what you want. Changing this might lead to incorrect behavior since conjugation is + // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized. 
+ inline void _set_conj(bool conjugate) const { + impl_->_set_conj(conjugate); + } + + inline bool is_neg() const { + return impl_->is_neg(); + } + + // sets the negative bit of a tensor. + // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure + // that's what you want. Changing this might lead to incorrect behavior since we rely on this + // bit to determine if a negation needs to be materialized. + inline void _set_neg(bool negative) const { + impl_->_set_neg(negative); + } + + /// Returns a `Tensor`'s layout. + Layout layout() const { + return impl_->layout(); + } + + /// Returns a `Tensor`'s dtype (`TypeMeta`). + caffe2::TypeMeta dtype() const { + return impl_->dtype(); + } + + /// Returns a `Tensor`'s device. + inline Device device() const { + return impl_->device(); + } + + /// Returns a `Tensor`'s device index. + DeviceIndex get_device() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->get_device(); + } + + /// Returns if a `Tensor` has CPU backend. + bool is_cpu() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_cpu(); + } + + /// Returns if a `Tensor` has CUDA backend. + bool is_cuda() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_cuda(); + } + + /// Returns if a `Tensor` has IPU backend. + bool is_ipu() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_ipu(); + } + + /// Returns if a `Tensor` has XPU backend. + bool is_xpu() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_xpu(); + } + + /// Returns if a `Tensor` has XLA backend. + bool is_xla() const { + return impl_->is_xla(); + } + + /// Returns if a `Tensor` has MTIA backend. + bool is_mtia() const { + return impl_->is_mtia(); + } + + /// Returns if a `Tensor` has HPU backend. 
+ bool is_hpu() const { + return impl_->is_hpu(); + } + + /// Returns if a `Tensor` has Lazy backend. + bool is_lazy() const { + return impl_->is_lazy(); + } + + /// Returns if a `Tensor` has HIP backend. + bool is_hip() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_hip(); + } + + /// Returns if a `Tensor` has VE backend. + bool is_ve() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_ve(); + } + + /// Returns if a `Tensor` has PrivateUse1 backend. + bool is_privateuseone() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_privateuseone(); + } + + /// Returns if a `Tensor` has sparse backend. + bool is_sparse() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_sparse(); + } + + /// Returns is a `Tensor` has a sparse CSR backend. + bool is_sparse_csr() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_sparse_csr(); + } + + /// Returns if a `Tensor` is mkldnn tensor. + bool is_mkldnn() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_mkldnn(); + } + + /// Returns if a `Tensor` is mps tensor. + bool is_mps() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_mps(); + } + + /// Returns if a `Tensor` is maia tensor. + bool is_maia() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_maia(); + } + + /// Returns if a `Tensor` is vulkan tensor. + bool is_vulkan() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_vulkan(); + } + + /// Returns if a `Tensor` is metal tensor. + bool is_metal() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_metal(); + } + + /// Returns if a `Tensor` has quantized backend. 
+ bool is_quantized() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_quantized(); + } + + /// Returns if a `Tensor` is a meta tensor. Meta tensors can + /// also have other designations. + bool is_meta() const { + return impl_->is_meta(); + } + + /// Returns if a `Tensor` is an inference tensor. + bool is_inference() const { + return impl_->is_inference(); + } + + // Returns if a `Tensor` is a NestedTensor. + bool is_nested() const { + return impl_->is_nested(); + } + + /// If a tensor is a quantized tensor, returns its quantizer + /// TODO: it's not in native_functions.yaml yet as it's not exposed to python + QuantizerPtr quantizer() const; + + /// Returns if a `Tensor` has any dimension names + bool has_names() const { + // If a user is using unnamed tensors, then we can short-circuit right here. + // Otherwise, impl::has_names attempts to retrieve names. + if (!impl_->has_named_tensor_meta()) { + return false; + } + return impl::has_names(unsafeGetTensorImpl()); + } + + /// Returns a `Tensor`'s dimension names data structure + const NamedTensorMeta* get_named_tensor_meta() const { + return static_cast(impl_->named_tensor_meta()); + } + + NamedTensorMeta* get_named_tensor_meta() { + return static_cast(impl_->named_tensor_meta()); + } + + /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in + /// TensorOptions.h. + TensorOptions options() const { + return TensorOptions().dtype(dtype()) + .device(device()) + .layout(layout()); + } + + const void* const_data_ptr() const { + return this->unsafeGetTensorImpl()->data(); + } + + void* mutable_data_ptr() const { + return this->unsafeGetTensorImpl()->mutable_data(); + } + + // TODO(#97856) Make this return a const pointer. This currently + // returns a non-const pointer because of the large + // number of clients that we still want to audit before + // migrating to mutable_data_ptr(). 
+ void* data_ptr() const { + return mutable_data_ptr(); + } + + template , int> = 0> + const T* const_data_ptr() const; + + template , int> = 0> + const std::remove_const_t* const_data_ptr() const; + + template + T* mutable_data_ptr() const; + + // Legacy interface during the migration to indicate that a callsite + // has not been audited for mutability. + // + // Do not add new uses of this, use const_data_ptr() if possible, + // mutable_data_ptr() otherwise. + // + // TODO(#97856) Make this return a const pointer. This is currently + // const because of the vast number of clients that + // rely on this. + template + T* data_ptr() const; + + // Purposely not defined here to avoid inlining + void print() const; + + // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and + // dimension. + template + TensorAccessor accessor() const& { + static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); + TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); + T* ptr = nullptr; + if constexpr (std::is_const_v) { + ptr = const_data_ptr(); + } else { + ptr = mutable_data_ptr(); + } + return TensorAccessor(ptr,sizes().data(),strides().data()); + } + template + TensorAccessor accessor() && = delete; + + // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and + // dimension. You can optionally specify RestrictPtrTraits as a template parameter to + // cast the data pointer to a __restrict__ pointer. + // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor + // as an argument. 
+ template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() const& { + static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); + TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); + T* ptr = nullptr; + if constexpr (std::is_const_v) { + ptr = const_data_ptr(); + } else { + ptr = mutable_data_ptr(); + } + return GenericPackedTensorAccessor(static_cast::PtrType>(ptr),sizes().data(),strides().data()); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() const& { + TORCH_CHECK( + impl_->numel() <= + static_cast(std::numeric_limits::max()), + "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64"); + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() && = delete; + + // ~~~~~ Autograd API ~~~~~ + + /// \fn bool is_leaf() const; + /// + /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention. + /// + /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were + /// created by the user. This means that they are not the result of an operation and so + /// `grad_fn()` is `nullptr`. + /// + /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`. + /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`. 
+ /// + /// Example: + /// @code + /// auto a = torch::rand(10, torch::requires_grad()); + /// std::cout << a.is_leaf() << std::endl; // prints `true` + /// + /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA); + /// std::cout << b.is_leaf() << std::endl; // prints `false` + /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor + /// + /// auto c = torch::rand(10, torch::requires_grad()) + 2; + /// std::cout << c.is_leaf() << std::endl; // prints `false` + /// // c was created by the addition operation + /// + /// auto d = torch::rand(10).cuda(); + /// std::cout << d.is_leaf() << std::endl; // prints `true` + /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) + /// + /// auto e = torch::rand(10).cuda().requires_grad_(); + /// std::cout << e.is_leaf() << std::endl; // prints `true` + /// // e requires gradients and has no operations creating it + /// + /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true)); + /// std::cout << f.is_leaf() << std::endl; // prints `true` + /// // f requires grad, has no operation creating it + /// @endcode + + /// \fn void backward(const Tensor & gradient={}, std::optional retain_graph=std::nullopt, bool create_graph=false, std::optional inputs=std::nullopt) const; + /// + /// Computes the gradient of current tensor with respect to graph leaves. + /// + /// The graph is differentiated using the chain rule. If the tensor is + /// non-scalar (i.e. its data has more than one element) and requires + /// gradient, the function additionally requires specifying ``gradient``. + /// It should be a tensor of matching type and location, that contains + /// the gradient of the differentiated function w.r.t. this Tensor. + /// + /// This function accumulates gradients in the leaves - you might need to + /// zero them before calling it. + /// + /// \param gradient Gradient w.r.t. the + /// tensor. 
If it is a tensor, it will be automatically converted + /// to a Tensor that does not require grad unless ``create_graph`` is True. + /// None values can be specified for scalar Tensors or ones that + /// don't require grad. If a None value would be acceptable then + /// this argument is optional. + /// \param retain_graph If ``false``, the graph used to compute + /// the grads will be freed. Note that in nearly all cases setting + /// this option to True is not needed and often can be worked around + /// in a much more efficient way. Defaults to the value of + /// ``create_graph``. + /// \param create_graph If ``true``, graph of the derivative will + /// be constructed, allowing to compute higher order derivative + /// products. Defaults to ``false``. + /// \param inputs Inputs w.r.t. which the gradient will be accumulated into + /// ``at::Tensor::grad``. All other Tensors will be ignored. If not + /// provided, the gradient is accumulated into all the leaf Tensors + /// that were used to compute the current tensor. + /// When inputs are provided and a given input is not a leaf, + /// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients). + /// It is an implementation detail on which the user should not rely. + /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. + + /// \fn Tensor detach() const; + /// + /// Returns a new Tensor, detached from the current graph. + /// The result will never require gradient. + + /// \fn Tensor & detach_() const; + /// + /// Detaches the Tensor from the graph that created it, making it a leaf. + /// Views cannot be detached in-place. + + /// \fn void retain_grad() const; + /// + /// Enables this Tensor to have their :attr:`grad` populated during + /// :func:`backward`. This is a no-op for leaf tensors. 
+ + /// \fn bool retains_grad() const; + /// + /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be + /// populated during :func:`backward`, ``false`` otherwise. + + const TensorBase& set_requires_grad(bool requires_grad) const { + impl_->set_requires_grad(requires_grad); + return *this; + } + bool requires_grad() const { + return impl_->requires_grad(); + } + + // The Forward AD API functions below are low level and are not to be used by end + // users who should use the API provided in torch/csrc/autograd.h + + /// This function returns the forward gradient for this Tensor at the given level. + const Tensor& _fw_grad(uint64_t level) const { + return impl_->_fw_grad(level, *this); + } + + /// This function can be used to set the value of the forward grad. + /// Note that the given new_grad might not be used directly if it has different + /// metadata (size/stride/storage offset) compared to this Tensor. In that case, + /// new_grad content will be copied into a new Tensor + void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const { + impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op); + } + + /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended + /// to be used from functions that need to access the `Variable`'s equivalent `Tensor` + /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`). + /// + /// One notable difference with the legacy `.data()` function is that changes to the + /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset) + /// will not update the original `Variable`, due to the fact that this function + /// shallow-copies the `Variable`'s underlying TensorImpl. 
+ at::TensorBase tensor_data() const; + + /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data` + /// in Python, which create a new `Variable` that shares the same storage and + /// tensor metadata with the original `Variable`, but with a completely new + /// autograd history. + /// + /// NOTE: If we change the tensor metadata (e.g. sizes / strides / + /// storage / storage_offset) of a variable created from `var.variable_data()`, those + /// changes will not update the original variable `var`. In `.variable_data()`, we set + /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal, + /// in order to prevent users from changing metadata of `var.variable_data()` + /// and expecting the original variable `var` to also be updated. + at::TensorBase variable_data() const; + + // Gradient Node and Edges + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + /// Gets the gradient function of the `Variable`. If this is a leaf variable, + /// the pointer returned will be null. + /// + /// For View Variables: + /// Gets the up-to-date grad_fn. If the shared data or base was modified, we + /// re-create the grad_fn to express the up-to-date view relationship between + /// this and the base Variable. + const std::shared_ptr& grad_fn() const; + + // Hooks + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + using hook_return_void_t = std::enable_if_t>, unsigned>; + template + using hook_return_var_t = std::enable_if_t, TensorBase>, unsigned>; + + /// Registers a backward hook. + /// + /// The hook will be called every time a gradient with respect to the Tensor is computed. 
+ /// The hook should have one of the following signature: + /// ``` + /// hook(TensorBase grad) -> TensorBase + /// ``` + /// ``` + /// hook(TensorBase grad) -> void + /// ``` + /// The hook should not modify its argument, but it can optionally return a new gradient + /// which will be used in place of `grad`. + /// + /// This function returns the index of the hook in the list which can be used to remove hook. + /// + /// Example: + /// @code + /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad()); + /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient + /// v.backward(torch::tensor({1., 2., 3.})); + /// // This prints: + /// // ``` + /// // 2 + /// // 4 + /// // 6 + /// // [ CPUFloatType{3} ] + /// // ``` + /// std::cout << v.grad() << std::endl; + /// v.remove_hook(h); // removes the hook + /// @endcode + template + hook_return_void_t register_hook(T&& hook) const; + template + hook_return_var_t register_hook(T&& hook) const; + +protected: + unsigned _register_hook(std::function hook) const; + +public: + + /// Remove hook at given position + void remove_hook(unsigned pos) const; + + // Variable methods + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + bool is_leaf() const; + + int64_t output_nr() const; + + void set_data(const TensorBase & new_data) const; + + TensorBase data() const; + + int64_t _version() const; + + void retain_grad() const; + + bool retains_grad() const; + + const TensorBase& requires_grad_(bool _requires_grad=true) const; + + // View Variables + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + /// Returns true if this `Variable` is a view of another `Variable`. + bool is_view() const; + + /// Returns the `Variable` that this `Variable` is a view of. If this + /// `Variable` is not a view, throw a `std::runtime_error`. 
+ const TensorBase& _base() const; + + // Miscellaneous + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + const std::string& name() const; + +protected: + void enforce_invariants(); + c10::intrusive_ptr impl_; + +private: + TensorBase __dispatch_contiguous(c10::MemoryFormat) const; +}; + +inline DeviceIndex get_device(const TensorBase& self) { + return self.get_device(); +} + +template +auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t { + // Return the grad argument in case of a hook with void return type to have an + // std::function with Tensor return type + static_assert(std::is_same_v, + "Expected hook to return void"); + return _register_hook([fn=std::forward(hook)](const TensorBase& grad) { + fn(grad); + return TensorBase(); + }); +} + +template +auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t { + return _register_hook(std::forward(hook)); +} + +namespace detail { +// Helper creator for Tensor class which doesn't requires the users to pass +// in an intrusive_ptr instead it just converts the argument passed to +// requested intrusive_ptr type. +template +TensorBase make_tensor_base(Args&&... args) { + return TensorBase(c10::make_intrusive(std::forward(args)...)); +} + +} // namespace detail + +inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) { + return legacyExtractDispatchKey(t.key_set()); +} + +} // namespace at + +namespace c10 { +template <> +struct MaybeOwnedTraits { + using owned_type = at::TensorBase; + using borrow_type = at::TensorBase; + + static borrow_type createBorrow(const owned_type& from) { + // NOTE: this can be implemented without the special + // unsafe_borrow_t Tensor constructor as + // + // return borrow_type(c10::intrusive_ptr::reclaim(from.unsafeGetTensorImpl())); + // + // but that hurts inlining due to the nullptr check in the + // Tensor(c10::intrusive_ptr<...>) constructor. 
We already know + // that from.impl_ isn't null because from is a valid Tensor, so + // we needn't do the check again. (using __builtin_assume can + // avoid this, but wouldn't be portable to MSVC.) + return borrow_type(borrow_type::unsafe_borrow_t{}, from); + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.unsafeReleaseTensorImpl(); + // See above note: this can be implemented with public API + // similarly to createBorrow(), but that would hurt inlining. + lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0. + } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { + return true; + } +}; + +template <> +struct ExclusivelyOwnedTraits : public c10::ExclusivelyOwnedTensorTraits {}; +} // namespace c10 + +namespace at { + +inline c10::MaybeOwned borrow_from_optional_tensor( + const std::optional& opt) { + return opt.has_value() + ? 
c10::MaybeOwned::borrowed(*opt) + : c10::MaybeOwned::owned(std::in_place); +} + +inline c10::MaybeOwned TensorBase::expect_contiguous(MemoryFormat memory_format) const & { + if (is_contiguous(memory_format)) { + return c10::MaybeOwned::borrowed(*this); + } else { + return c10::MaybeOwned::owned(__dispatch_contiguous(memory_format)); + } +} + +namespace symint { + +template +using enable_if_symint = std::enable_if_t>; +template +using enable_if_int = std::enable_if_t>; + +template > +c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); } +template > +IntArrayRef sizes(const TensorBase& t) { return t.sizes(); } + +template > +c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); } +template > +int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); } + +template > +c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); } +template > +IntArrayRef strides(const TensorBase& t) { return t.strides(); } + +template > +c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); } +template > +int64_t numel(const TensorBase& t) { return t.numel(); } + +} // namespace symint + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBody.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBody.h new file mode 100644 index 0000000000000000000000000000000000000000..84a21cd70bc7f777fc26780339f9074800003569 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TensorBody.h @@ -0,0 +1,5761 @@ +#pragma once + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. 
TensorBase] +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include + +namespace c10{ +template class List; +template class IListRef; +} +namespace at { +struct Generator; +struct Type; +class DeprecatedTypeProperties; +class Tensor; +} // namespace at +namespace at { +namespace indexing { +struct TensorIndex; +} // namespace indexing +} // namespace at + +namespace torch { namespace autograd { + +struct Node; + +}} // namespace torch::autograd + +namespace at { + +class OptionalTensorRef; +class TensorRef; +class Tensor; +using TensorList = ArrayRef; +using ITensorList = c10::IListRef; + +using Stream = c10::Stream; + +// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which +// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr. +// +// For example: +// +// void func(Tensor a) { +// Tensor b = a; +// ... +// } +// +// In this example, when we say Tensor b = a, we are creating a new object that points to the +// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the +// destructor decrements the reference count by calling release() on the TensorImpl it points to. +// The existing constructors, operator overloads, etc. take care to implement the correct semantics. +// +// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and +// special care must be taken to handle this. +class TORCH_API Tensor: public TensorBase { + protected: + // Create a Tensor with a +0 reference count. Special care must be + // taken to avoid decrementing this reference count at destruction + // time. Intended to support MaybeOwnedTraits. 
+ explicit Tensor(unsafe_borrow_t, const TensorBase& rhs): TensorBase(unsafe_borrow_t{}, rhs) {} + friend MaybeOwnedTraits; + friend OptionalTensorRef; + friend TensorRef; + + public: + Tensor() = default; + // This constructor should not be used by end users and is an implementation + // detail invoked by autogenerated code. + explicit Tensor( + c10::intrusive_ptr tensor_impl) + : TensorBase(std::move(tensor_impl)) {} + Tensor(const Tensor &tensor) = default; + Tensor(Tensor &&tensor) = default; + + // Implicitly move-constructible from TensorBase, but must be explicit to increase refcount + explicit Tensor(const TensorBase &base): TensorBase(base) {} + /*implicit*/ Tensor(TensorBase &&base): TensorBase(std::move(base)) {} + + // Creates a new wrapper from TensorImpl. Intentionally a free method because + // it should be used with care. Checks necessary invariants + static Tensor wrap_tensor_impl( + c10::intrusive_ptr tensor_impl) { + return TensorBase::wrap_tensor_impl(std::move(tensor_impl)); + } + + Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { + return TensorBase::contiguous(memory_format); + } + + Tensor conj() const { + if (!this->is_complex()) { + return *this; + } + + switch (this->layout()) { + case at::kSparse: + case at::kSparseCsr: + case at::kSparseCsc: + case at::kSparseBsr: + case at::kSparseBsc: + return this->conj_physical(); + default: + return this->_conj(); + } + } + + // Aliased by Dimname overloads, so need explicit using + using TensorBase::size; + using TensorBase::sym_size; + using TensorBase::stride; + + /// Should be used if *this can reasonably be expected to be contiguous and + /// performance is important. + /// Compared to contiguous, it saves a reference count + /// increment/decrement if *this is already contiguous, at the cost + /// in all cases of an extra pointer of stack usage, an extra branch + /// to access, and an extra branch at destruction time. 
+ c10::MaybeOwned expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const &; + + // Use .contiguous() instead. Trying to borrow from a prvalue Tensor + // will only lead to trouble and dangling references. + c10::MaybeOwned expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete; + + // The following overloads are very intruiging. Consider the following + // program: + // + // x[1] = 3; + // + // We would expect that the first entry of x is written to 3. But how can we + // actually achieve this? x[1] evaluates to a tensor... + // + // The answer is, using a ref-qualifier. x[1] is an rvalue, which cannot be + // (profitably) assigned to in the traditional sense, so we overload + // assignment to mean, "Actually, copy 3 into the tensor data." This is done + // with an rvalue-reference ref-qualified overload (the methods with && at the + // end of their type.) + // + // There's one more fly in the ointment: We also want + // + // Tensor x = y; + // + // to work, and we want it NOT to copy. So we need a traditional operator= + // overload. But we MUST specify a mutable lvalue ref-qualifier, to + // disambiguate the traditional overload from the rvalue-reference + // ref-qualified overload. Otherwise, it will be ambiguous, because + // a non ref-qualified method is eligible for all situations. + + // Unfortunately, we have to write these constructors out manually + // to work around an MSVC bug: + // error C2580: 'at::Tensor &at::Tensor::operator =(const at::Tensor &) &': + // multiple versions of a defaulted special member functions are not allowed + // Tensor& operator=(const Tensor&) & = default; + // Tensor& operator=(Tensor&&) & = default; + + // Also MSVC will wrongly issue the following warning with the aforementioned fix + // warning C4522: 'at::Tensor': multiple assignment operators specified + // Let's just skip the warning. 
+ // + // TODO: temporarily disabled + + Tensor& operator=(const TensorBase& x) & noexcept { + impl_ = x.getIntrusivePtr(); + return *this; + } + Tensor& operator=(TensorBase&& x) & noexcept { + impl_ = x.unsafeReleaseIntrusivePtr(); + return *this; + } + + Tensor& operator=(const Tensor &x) & noexcept { + return operator=(static_cast(x)); + } + Tensor& operator=(Tensor &&x) & noexcept { + return operator=(static_cast(x)); + } + + Tensor& operator=(const Scalar &v) && { + return fill_(v); + } + Tensor& operator=(const Tensor &rhs) && { + return copy_(rhs); + } + Tensor& operator=(Tensor&& rhs) && { + return copy_(rhs); + } + + C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().") + DeprecatedTypeProperties & type() const { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + dispatchKeyToBackend(legacyExtractDispatchKey(key_set())), + scalar_type()); + } + + Tensor toType(ScalarType t) const { + return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false); + } + + // TODO: Deprecate me + Tensor toBackend(Backend b) const { + return to(options().device(backendToDeviceType(b)).layout(layout_from_backend(b)), /*non_blocking*/ false, /*copy*/ false); + } + + C10_DEPRECATED_MESSAGE("Tensor.is_variable() is deprecated; everything is a variable now. (If you want to assert that variable has been appropriately handled already, use at::impl::variable_excluded_from_dispatch())") + bool is_variable() const noexcept { + return !at::impl::variable_excluded_from_dispatch(); + } + + template + C10_DEPRECATED_MESSAGE("Tensor.data() is deprecated. 
Please use Tensor.data_ptr() instead.") + T * data() const { + return data_ptr(); + } + + template + T item() const; + + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() const & { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() && = delete; + + Tensor operator~() const { + return bitwise_not(); + } + Tensor operator-() const { + return neg(); + } + Tensor& operator+=(const Tensor & other) { + return add_(other); + } + Tensor& operator+=(const Scalar & other) { + return add_(other); + } + Tensor& operator-=(const Tensor & other) { + return sub_(other); + } + Tensor& operator-=(const Scalar & other) { + return sub_(other); + } + Tensor& operator*=(const Tensor & other) { + return mul_(other); + } + Tensor& operator*=(const Scalar & other) { + return mul_(other); + } + Tensor& operator/=(const Tensor & other) { + return div_(other); + } + Tensor& operator/=(const Scalar & other) { + return div_(other); + } + Tensor& operator&=(const Tensor & other) { + return bitwise_and_(other); + } + Tensor& operator|=(const Tensor & other) { + return bitwise_or_(other); + } + Tensor& operator^=(const Tensor & other) { + return bitwise_xor_(other); + } + Tensor operator[](const Scalar & index) const { + if (!index.isIntegral(false)) { + TORCH_CHECK_INDEX(false, "Can only index tensors with integral scalars"); + } + return this->operator[](index.toLong()); + } + Tensor operator[](const Tensor & index) const { + // These properties are checked in the Scalar constructor, but we already + // check them here to provide more useful diagnostics for the user. 
+ if (!index.defined()) { + TORCH_CHECK_INDEX(false, "Can only index with tensors that are defined"); + } + if (index.dim() != 0) { + TORCH_CHECK_INDEX(false, + "Can only index with tensors that are scalars (zero-dim)"); + } + // The Scalar(Tensor) constructor is explicit, so we need to call it. + return this->operator[](index.item()); + } + Tensor operator[](int64_t index) const { + return select(0, index); + } + + Tensor index(ArrayRef indices) const; + Tensor index(std::initializer_list indices) const; + + Tensor & index_put_(ArrayRef indices, Tensor const & rhs); + Tensor & index_put_(ArrayRef indices, const Scalar& v); + Tensor & index_put_(std::initializer_list indices, Tensor const & rhs); + Tensor & index_put_(std::initializer_list indices, const Scalar& v); + + Tensor cpu() const { + return to(options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); + } + + // TODO: The Python version also accepts arguments + Tensor cuda() const { + return to(options().device(c10::DeviceType::CUDA), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor hip() const { + return to(options().device(c10::DeviceType::HIP), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor ve() const { + return to(options().device(c10::DeviceType::VE), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor vulkan() const { + return to(options().device(c10::DeviceType::Vulkan), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor metal() const { + return to(options().device(c10::DeviceType::Metal), /*non_blocking*/ false, /*copy*/ false); + } + + Tensor meta() const { + return to(options().device(c10::DeviceType::Meta), /*non_blocking*/ false, /*copy*/ false); + } + + // ~~~~~ Autograd API ~~~~~ + + /// \fn bool is_leaf() const; + /// + /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention. 
+ /// + /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were + /// created by the user. This means that they are not the result of an operation and so + /// `grad_fn()` is `nullptr`. + /// + /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`. + /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`. + /// + /// Example: + /// @code + /// auto a = torch::rand(10, torch::requires_grad()); + /// std::cout << a.is_leaf() << std::endl; // prints `true` + /// + /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA); + /// std::cout << b.is_leaf() << std::endl; // prints `false` + /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor + /// + /// auto c = torch::rand(10, torch::requires_grad()) + 2; + /// std::cout << c.is_leaf() << std::endl; // prints `false` + /// // c was created by the addition operation + /// + /// auto d = torch::rand(10).cuda(); + /// std::cout << d.is_leaf() << std::endl; // prints `true` + /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) + /// + /// auto e = torch::rand(10).cuda().requires_grad_(); + /// std::cout << e.is_leaf() << std::endl; // prints `true` + /// // e requires gradients and has no operations creating it + /// + /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true)); + /// std::cout << f.is_leaf() << std::endl; // prints `true` + /// // f requires grad, has no operation creating it + /// @endcode + + /// \fn void backward(const Tensor & gradient={}, std::optional retain_graph=std::nullopt, bool create_graph=false, std::optional inputs=std::nullopt) const; + /// + /// Computes the gradient of current tensor with respect to graph leaves. + /// + /// The graph is differentiated using the chain rule. If the tensor is + /// non-scalar (i.e. 
its data has more than one element) and requires + /// gradient, the function additionally requires specifying ``gradient``. + /// It should be a tensor of matching type and location, that contains + /// the gradient of the differentiated function w.r.t. this Tensor. + /// + /// This function accumulates gradients in the leaves - you might need to + /// zero them before calling it. + /// + /// \param gradient Gradient w.r.t. the + /// tensor. If it is a tensor, it will be automatically converted + /// to a Tensor that does not require grad unless ``create_graph`` is True. + /// None values can be specified for scalar Tensors or ones that + /// don't require grad. If a None value would be acceptable then + /// this argument is optional. + /// \param retain_graph If ``false``, the graph used to compute + /// the grads will be freed. Note that in nearly all cases setting + /// this option to True is not needed and often can be worked around + /// in a much more efficient way. Defaults to the value of + /// ``create_graph``. + /// \param create_graph If ``true``, graph of the derivative will + /// be constructed, allowing to compute higher order derivative + /// products. Defaults to ``false``. + /// \param inputs Inputs w.r.t. which the gradient will be accumulated into + /// ``at::Tensor::grad``. All other Tensors will be ignored. If not + /// provided, the gradient is accumulated into all the leaf Tensors + /// that were used to compute the current tensor. + /// When inputs are provided and a given input is not a leaf, + /// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients). + /// It is an implementation detail on which the user should not rely. + /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. 
+ void backward(const Tensor & gradient={}, std::optional retain_graph=std::nullopt, bool create_graph=false, std::optional inputs=std::nullopt) const { + // NB: Adding this wrapper to _backward here because we'd like our + // 'backwards' api to accept the 'inputs' argument optionally. Since code gen + // currently does not support optional of TensorList our approach is to replace + // backward in native_functions.yaml with _backward and call it here instead. + if (inputs.has_value()) { + TORCH_CHECK(inputs.value().size() > 0, "'inputs' argument to backward cannot be empty") + this->_backward(inputs.value(), gradient, retain_graph, create_graph); + } else { + this->_backward({}, gradient, retain_graph, create_graph); + } + } + + /// \fn Tensor detach() const; + /// + /// Returns a new Tensor, detached from the current graph. + /// The result will never require gradient. + + /// \fn Tensor & detach_() const; + /// + /// Detaches the Tensor from the graph that created it, making it a leaf. + /// Views cannot be detached in-place. + + /// \fn void retain_grad() const; + /// + /// Enables this Tensor to have their :attr:`grad` populated during + /// :func:`backward`. This is a no-op for leaf tensors. + + /// \fn bool retains_grad() const; + /// + /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be + /// populated during :func:`backward`, ``false`` otherwise. + + const Tensor& set_requires_grad(bool requires_grad) const { + TensorBase::set_requires_grad(requires_grad); + return *this; + } + + /// Return a mutable reference to the gradient. This is conventionally + /// used as `t.grad() = x` to set a gradient to a completely new tensor. + /// Note that this function work with a non-const Tensor and is not + /// thread safe. 
+ Tensor& mutable_grad() const { + return impl_->mutable_grad(); + } + + /// This function returns an undefined tensor by default and returns a defined tensor + /// the first time a call to `backward()` computes gradients for this Tensor. + /// The attribute will then contain the gradients computed and future calls + /// to `backward()` will accumulate (add) gradients into it. + const Tensor& grad() const { + const Tensor& maybe_grad = impl_->grad(); + if (!is_leaf() && !retains_grad() && !maybe_grad.defined()) { + TORCH_WARN( + "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad " + "attribute won't be populated during autograd.backward(). If you indeed want the .grad " + "field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. " + "If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor " + "instead. See github.com/pytorch/pytorch/pull/30531 for more informations."); + } + return maybe_grad; + } + + // The Forward AD API functions below are low level and are not to be used by end + // users who should use the API provided in torch/csrc/autograd.h + + /// This function returns the forward gradient for this Tensor at the given level. + const Tensor& _fw_grad(uint64_t level) const { + return impl_->_fw_grad(level, *this); + } + + /// This function can be used to set the value of the forward grad. + /// Note that the given new_grad might not be used directly if it has different + /// metadata (size/stride/storage offset) compared to this Tensor. In that case, + /// new_grad content will be copied into a new Tensor + void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const { + impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op); + } + + + // STOP. Thinking of adding a method here, which only makes use + // of other ATen methods? Define it in native_functions.yaml. 
+ + //example + //Tensor * add(Tensor & b); + void __dispatch__backward(at::TensorList inputs, const ::std::optional & gradient={}, ::std::optional retain_graph=::std::nullopt, bool create_graph=false) const; + void __dispatch_set_data(const at::Tensor & new_data) const; + at::Tensor __dispatch_data() const; + bool __dispatch_is_leaf() const; + int64_t __dispatch_output_nr() const; + int64_t __dispatch__version() const; + at::Tensor & __dispatch_requires_grad_(bool requires_grad=true) const; + void __dispatch_retain_grad() const; + bool __dispatch_retains_grad() const; + at::Tensor _fw_primal(int64_t level) const; + at::Tensor & rename_(::std::optional names) const; + at::Tensor rename(::std::optional names) const; + at::Tensor align_to(at::DimnameList names) const; + at::Tensor align_to(at::DimnameList order, int64_t ellipsis_idx) const; + at::Tensor align_as(const at::Tensor & other) const; + at::Tensor refine_names(at::DimnameList names) const; + at::Tensor abs() const; + at::Tensor & abs_() const; + at::Tensor absolute() const; + at::Tensor & absolute_() const; + at::Tensor angle() const; + at::Tensor sgn() const; + at::Tensor & sgn_() const; + at::Tensor chalf(::std::optional memory_format=::std::nullopt) const; + at::Tensor _conj() const; + at::Tensor __dispatch_conj() const; + at::Tensor _conj_physical() const; + at::Tensor conj_physical() const; + at::Tensor & conj_physical_() const; + at::Tensor resolve_conj() const; + at::Tensor resolve_neg() const; + at::Tensor _neg_view() const; + at::Tensor acos() const; + at::Tensor & acos_() const; + at::Tensor arccos() const; + at::Tensor & arccos_() const; + at::Tensor add(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor & add_(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor add(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor & add_(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor addmv(const at::Tensor & mat, 
const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor & addmv_(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor addr(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor & addr_(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor _is_all_true() const; + at::Tensor _is_any_true() const; + at::Tensor all(int64_t dim, bool keepdim=false) const; + at::Tensor all(at::OptionalIntArrayRef dim, bool keepdim=false) const; + at::Tensor all(at::Dimname dim, bool keepdim=false) const; + bool allclose(const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const; + at::Tensor any(int64_t dim, bool keepdim=false) const; + at::Tensor any(at::OptionalIntArrayRef dim, bool keepdim=false) const; + at::Tensor any(at::Dimname dim, bool keepdim=false) const; + at::Tensor argmax(::std::optional dim=::std::nullopt, bool keepdim=false) const; + at::Tensor argmin(::std::optional dim=::std::nullopt, bool keepdim=false) const; + at::Tensor acosh() const; + at::Tensor & acosh_() const; + at::Tensor arccosh() const; + at::Tensor & arccosh_() const; + at::Tensor asinh() const; + at::Tensor & asinh_() const; + at::Tensor arcsinh() const; + at::Tensor & arcsinh_() const; + at::Tensor atanh() const; + at::Tensor & atanh_() const; + at::Tensor arctanh() const; + at::Tensor & arctanh_() const; + at::Tensor as_strided(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + at::Tensor as_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + const at::Tensor & as_strided_(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + const at::Tensor & 
as_strided__symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + at::Tensor asin() const; + at::Tensor & asin_() const; + at::Tensor arcsin() const; + at::Tensor & arcsin_() const; + at::Tensor atan() const; + at::Tensor & atan_() const; + at::Tensor arctan() const; + at::Tensor & arctan_() const; + at::Tensor baddbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor & baddbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor bernoulli(::std::optional generator=::std::nullopt) const; + at::Tensor & bernoulli_(const at::Tensor & p, ::std::optional generator=::std::nullopt) const; + at::Tensor & bernoulli_(double p=0.5, ::std::optional generator=::std::nullopt) const; + at::Tensor bernoulli(double p, ::std::optional generator=::std::nullopt) const; + at::Tensor bincount(const ::std::optional & weights={}, int64_t minlength=0) const; + at::Tensor bincount_symint(const ::std::optional & weights={}, c10::SymInt minlength=0) const; + at::Tensor bitwise_not() const; + at::Tensor & bitwise_not_() const; + at::Tensor copysign(const at::Tensor & other) const; + at::Tensor & copysign_(const at::Tensor & other) const; + at::Tensor copysign(const at::Scalar & other) const; + at::Tensor & copysign_(const at::Scalar & other) const; + at::Tensor _lazy_clone() const; + at::Tensor logical_not() const; + at::Tensor & logical_not_() const; + at::Tensor logical_xor(const at::Tensor & other) const; + at::Tensor & logical_xor_(const at::Tensor & other) const; + at::Tensor logical_and(const at::Tensor & other) const; + at::Tensor & logical_and_(const at::Tensor & other) const; + at::Tensor logical_or(const at::Tensor & other) const; + at::Tensor & logical_or_(const at::Tensor & other) const; + at::Tensor bmm(const at::Tensor & mat2) const; + at::Tensor 
broadcast_to(at::IntArrayRef size) const; + at::Tensor broadcast_to_symint(c10::SymIntArrayRef size) const; + at::Tensor ceil() const; + at::Tensor & ceil_() const; + ::std::vector unsafe_chunk(int64_t chunks, int64_t dim=0) const; + ::std::vector chunk(int64_t chunks, int64_t dim=0) const; + ::std::vector tensor_split(int64_t sections, int64_t dim=0) const; + ::std::vector tensor_split_symint(c10::SymInt sections, int64_t dim=0) const; + ::std::vector tensor_split(at::IntArrayRef indices, int64_t dim=0) const; + ::std::vector tensor_split_symint(c10::SymIntArrayRef indices, int64_t dim=0) const; + ::std::vector tensor_split(const at::Tensor & tensor_indices_or_sections, int64_t dim=0) const; + at::Tensor clamp(const ::std::optional & min, const ::std::optional & max=::std::nullopt) const; + at::Tensor clamp(const ::std::optional & min={}, const ::std::optional & max={}) const; + at::Tensor & clamp_(const ::std::optional & min, const ::std::optional & max=::std::nullopt) const; + at::Tensor & clamp_(const ::std::optional & min={}, const ::std::optional & max={}) const; + at::Tensor clamp_max(const at::Scalar & max) const; + at::Tensor clamp_max(const at::Tensor & max) const; + at::Tensor & clamp_max_(const at::Scalar & max) const; + at::Tensor & clamp_max_(const at::Tensor & max) const; + at::Tensor clamp_min(const at::Scalar & min) const; + at::Tensor clamp_min(const at::Tensor & min) const; + at::Tensor & clamp_min_(const at::Scalar & min) const; + at::Tensor & clamp_min_(const at::Tensor & min) const; + at::Tensor clip(const ::std::optional & min, const ::std::optional & max=::std::nullopt) const; + at::Tensor clip(const ::std::optional & min={}, const ::std::optional & max={}) const; + at::Tensor & clip_(const ::std::optional & min, const ::std::optional & max=::std::nullopt) const; + at::Tensor & clip_(const ::std::optional & min={}, const ::std::optional & max={}) const; + at::Tensor __dispatch_contiguous(at::MemoryFormat 
memory_format=c10::MemoryFormat::Contiguous) const; + at::Tensor & copy_(const at::Tensor & src, bool non_blocking=false) const; + at::Tensor cos() const; + at::Tensor & cos_() const; + at::Tensor cosh() const; + at::Tensor & cosh_() const; + at::Tensor count_nonzero(at::IntArrayRef dim) const; + at::Tensor count_nonzero(::std::optional dim=::std::nullopt) const; + at::Tensor cov(int64_t correction=1, const ::std::optional & fweights={}, const ::std::optional & aweights={}) const; + at::Tensor corrcoef() const; + ::std::tuple cummax(int64_t dim) const; + ::std::tuple cummax(at::Dimname dim) const; + ::std::tuple cummin(int64_t dim) const; + ::std::tuple cummin(at::Dimname dim) const; + at::Tensor cumprod(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor & cumprod_(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor cumprod(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor & cumprod_(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor cumsum(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor & cumsum_(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor cumsum(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor & cumsum_(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor diag_embed(int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) const; + at::Tensor diagflat(int64_t offset=0) const; + at::Tensor diagonal(int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const; + at::Tensor diagonal(at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset=0) const; + at::Tensor & fill_diagonal_(const at::Scalar & fill_value, bool wrap=false) const; + at::Tensor diff(int64_t n=1, int64_t dim=-1, const ::std::optional & prepend={}, const ::std::optional & append={}) const; + at::Tensor div(const at::Tensor & other) const; + at::Tensor & div_(const at::Tensor & other) const; + at::Tensor div(const 
at::Tensor & other, ::std::optional rounding_mode) const; + at::Tensor & div_(const at::Tensor & other, ::std::optional rounding_mode) const; + at::Tensor div(const at::Scalar & other) const; + at::Tensor & div_(const at::Scalar & other) const; + at::Tensor div(const at::Scalar & other, ::std::optional rounding_mode) const; + at::Tensor & div_(const at::Scalar & other, ::std::optional rounding_mode) const; + at::Tensor divide(const at::Tensor & other) const; + at::Tensor & divide_(const at::Tensor & other) const; + at::Tensor divide(const at::Scalar & other) const; + at::Tensor & divide_(const at::Scalar & other) const; + at::Tensor divide(const at::Tensor & other, ::std::optional rounding_mode) const; + at::Tensor & divide_(const at::Tensor & other, ::std::optional rounding_mode) const; + at::Tensor divide(const at::Scalar & other, ::std::optional rounding_mode) const; + at::Tensor & divide_(const at::Scalar & other, ::std::optional rounding_mode) const; + at::Tensor true_divide(const at::Tensor & other) const; + at::Tensor & true_divide_(const at::Tensor & other) const; + at::Tensor true_divide(const at::Scalar & other) const; + at::Tensor & true_divide_(const at::Scalar & other) const; + at::Tensor dot(const at::Tensor & tensor) const; + at::Tensor vdot(const at::Tensor & other) const; + at::Tensor new_empty(at::IntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_empty(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_empty_symint(c10::SymIntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_empty_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options={}) const; + at::Tensor new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, 
::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options={}) const; + at::Tensor new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_full(at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) const; + at::Tensor new_full(at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) const; + at::Tensor new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_zeros(at::IntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_zeros(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_zeros_symint(c10::SymIntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_zeros_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_ones(at::IntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_ones(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const; + at::Tensor new_ones_symint(c10::SymIntArrayRef size, at::TensorOptions options={}) const; + at::Tensor new_ones_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional 
pin_memory) const; + const at::Tensor & resize_(at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) const; + const at::Tensor & resize__symint(c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) const; + at::Tensor erf() const; + at::Tensor & erf_() const; + at::Tensor erfc() const; + at::Tensor & erfc_() const; + at::Tensor exp() const; + at::Tensor & exp_() const; + at::Tensor exp2() const; + at::Tensor & exp2_() const; + at::Tensor expm1() const; + at::Tensor & expm1_() const; + at::Tensor expand(at::IntArrayRef size, bool implicit=false) const; + at::Tensor expand_symint(c10::SymIntArrayRef size, bool implicit=false) const; + at::Tensor expand_as(const at::Tensor & other) const; + at::Tensor flatten(int64_t start_dim=0, int64_t end_dim=-1) const; + at::Tensor flatten(int64_t start_dim, int64_t end_dim, at::Dimname out_dim) const; + at::Tensor flatten(at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) const; + at::Tensor flatten(at::DimnameList dims, at::Dimname out_dim) const; + at::Tensor unflatten(int64_t dim, at::IntArrayRef sizes) const; + at::Tensor unflatten_symint(int64_t dim, c10::SymIntArrayRef sizes) const; + at::Tensor unflatten(at::Dimname dim, at::IntArrayRef sizes, at::DimnameList names) const; + at::Tensor unflatten_symint(at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) const; + at::Tensor & fill_(const at::Scalar & value) const; + at::Tensor & fill_(const at::Tensor & value) const; + at::Tensor floor() const; + at::Tensor & floor_() const; + at::Tensor floor_divide(const at::Tensor & other) const; + at::Tensor & floor_divide_(const at::Tensor & other) const; + at::Tensor floor_divide(const at::Scalar & other) const; + at::Tensor & floor_divide_(const at::Scalar & other) const; + at::Tensor frac() const; + at::Tensor & frac_() const; + at::Tensor gcd(const at::Tensor & other) const; + at::Tensor & gcd_(const at::Tensor & other) const; + at::Tensor lcm(const at::Tensor & other) 
const; + at::Tensor & lcm_(const at::Tensor & other) const; + at::Tensor index(const c10::List<::std::optional> & indices) const; + at::Tensor & index_copy_(int64_t dim, const at::Tensor & index, const at::Tensor & source) const; + at::Tensor index_copy(int64_t dim, const at::Tensor & index, const at::Tensor & source) const; + at::Tensor & index_copy_(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const; + at::Tensor index_copy(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const; + at::Tensor & index_put_(const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) const; + at::Tensor index_put(const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) const; + at::Tensor isclose(const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const; + at::Tensor isnan() const; + bool is_distributed() const; + bool __dispatch_is_floating_point() const; + bool __dispatch_is_complex() const; + bool __dispatch_is_conj() const; + bool __dispatch__is_zerotensor() const; + bool __dispatch_is_neg() const; + at::Tensor isreal() const; + bool is_nonzero() const; + bool is_same_size(const at::Tensor & other) const; + bool __dispatch_is_signed() const; + bool __dispatch_is_inference() const; + at::Tensor kron(const at::Tensor & other) const; + ::std::tuple kthvalue(int64_t k, int64_t dim=-1, bool keepdim=false) const; + ::std::tuple kthvalue_symint(c10::SymInt k, int64_t dim=-1, bool keepdim=false) const; + ::std::tuple kthvalue(int64_t k, at::Dimname dim, bool keepdim=false) const; + ::std::tuple kthvalue_symint(c10::SymInt k, at::Dimname dim, bool keepdim=false) const; + at::Tensor nan_to_num(::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional neginf=::std::nullopt) const; + at::Tensor & nan_to_num_(::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional 
neginf=::std::nullopt) const; + at::Tensor ldexp(const at::Tensor & other) const; + at::Tensor & ldexp_(const at::Tensor & other) const; + at::Tensor log() const; + at::Tensor & log_() const; + at::Tensor log10() const; + at::Tensor & log10_() const; + at::Tensor log1p() const; + at::Tensor & log1p_() const; + at::Tensor log2() const; + at::Tensor & log2_() const; + at::Tensor logaddexp(const at::Tensor & other) const; + at::Tensor logaddexp2(const at::Tensor & other) const; + at::Tensor xlogy(const at::Tensor & other) const; + at::Tensor xlogy(const at::Scalar & other) const; + at::Tensor & xlogy_(const at::Tensor & other) const; + at::Tensor & xlogy_(const at::Scalar & other) const; + at::Tensor log_softmax(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor log_softmax(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor logcumsumexp(int64_t dim) const; + at::Tensor logcumsumexp(at::Dimname dim) const; + at::Tensor logsumexp(at::IntArrayRef dim, bool keepdim=false) const; + at::Tensor logsumexp(at::DimnameList dim, bool keepdim=false) const; + at::Tensor matmul(const at::Tensor & other) const; + at::Tensor matrix_power(int64_t n) const; + at::Tensor matrix_exp() const; + ::std::tuple aminmax(::std::optional dim=::std::nullopt, bool keepdim=false) const; + ::std::tuple max(int64_t dim, bool keepdim=false) const; + ::std::tuple max(at::Dimname dim, bool keepdim=false) const; + at::Tensor amax(at::IntArrayRef dim={}, bool keepdim=false) const; + at::Tensor mean(::std::optional dtype=::std::nullopt) const; + at::Tensor mean(at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor mean(at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor nanmean(at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor median() const; + ::std::tuple median(int64_t dim, bool 
keepdim=false) const; + ::std::tuple median(at::Dimname dim, bool keepdim=false) const; + at::Tensor nanmedian() const; + ::std::tuple nanmedian(int64_t dim, bool keepdim=false) const; + ::std::tuple nanmedian(at::Dimname dim, bool keepdim=false) const; + ::std::tuple min(int64_t dim, bool keepdim=false) const; + ::std::tuple min(at::Dimname dim, bool keepdim=false) const; + at::Tensor amin(at::IntArrayRef dim={}, bool keepdim=false) const; + at::Tensor mm(const at::Tensor & mat2) const; + ::std::tuple mode(int64_t dim=-1, bool keepdim=false) const; + ::std::tuple mode(at::Dimname dim, bool keepdim=false) const; + at::Tensor mul(const at::Tensor & other) const; + at::Tensor & mul_(const at::Tensor & other) const; + at::Tensor mul(const at::Scalar & other) const; + at::Tensor & mul_(const at::Scalar & other) const; + at::Tensor multiply(const at::Tensor & other) const; + at::Tensor & multiply_(const at::Tensor & other) const; + at::Tensor multiply(const at::Scalar & other) const; + at::Tensor & multiply_(const at::Scalar & other) const; + at::Tensor mv(const at::Tensor & vec) const; + at::Tensor mvlgamma(int64_t p) const; + at::Tensor & mvlgamma_(int64_t p) const; + at::Tensor narrow_copy(int64_t dim, int64_t start, int64_t length) const; + at::Tensor narrow_copy_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const; + at::Tensor narrow(int64_t dim, int64_t start, int64_t length) const; + at::Tensor narrow_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const; + at::Tensor narrow(int64_t dim, const at::Tensor & start, int64_t length) const; + at::Tensor narrow_symint(int64_t dim, const at::Tensor & start, c10::SymInt length) const; + at::Tensor permute(at::IntArrayRef dims) const; + at::Tensor movedim(at::IntArrayRef source, at::IntArrayRef destination) const; + at::Tensor movedim(int64_t source, int64_t destination) const; + at::Tensor moveaxis(at::IntArrayRef source, at::IntArrayRef destination) const; + at::Tensor moveaxis(int64_t source, 
int64_t destination) const; + at::Tensor numpy_T() const; + at::Tensor matrix_H() const; + at::Tensor mT() const; + at::Tensor mH() const; + at::Tensor adjoint() const; + bool is_pinned(::std::optional device=::std::nullopt) const; + at::Tensor pin_memory(::std::optional device=::std::nullopt) const; + at::Tensor pinverse(double rcond=1e-15) const; + at::Tensor rad2deg() const; + at::Tensor & rad2deg_() const; + at::Tensor deg2rad() const; + at::Tensor & deg2rad_() const; + at::Tensor ravel() const; + at::Tensor reciprocal() const; + at::Tensor & reciprocal_() const; + at::Tensor neg() const; + at::Tensor & neg_() const; + at::Tensor negative() const; + at::Tensor & negative_() const; + at::Tensor repeat(at::IntArrayRef repeats) const; + at::Tensor repeat_symint(c10::SymIntArrayRef repeats) const; + at::Tensor repeat_interleave(const at::Tensor & repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) const; + at::Tensor repeat_interleave_symint(const at::Tensor & repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) const; + at::Tensor repeat_interleave(int64_t repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) const; + at::Tensor repeat_interleave_symint(c10::SymInt repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) const; + at::Tensor reshape(at::IntArrayRef shape) const; + at::Tensor reshape_symint(c10::SymIntArrayRef shape) const; + at::Tensor _reshape_alias(at::IntArrayRef size, at::IntArrayRef stride) const; + at::Tensor _reshape_alias_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const; + at::Tensor reshape_as(const at::Tensor & other) const; + at::Tensor round() const; + at::Tensor & round_() const; + at::Tensor round(int64_t decimals) const; + at::Tensor & round_(int64_t decimals) const; + at::Tensor relu() const; + at::Tensor & relu_() const; + at::Tensor prelu(const at::Tensor & weight) const; + 
at::Tensor hardshrink(const at::Scalar & lambd=0.5) const; + at::Tensor hardshrink_backward(const at::Tensor & grad_out, const at::Scalar & lambd) const; + at::Tensor rsqrt() const; + at::Tensor & rsqrt_() const; + at::Tensor select(at::Dimname dim, int64_t index) const; + at::Tensor select(int64_t dim, int64_t index) const; + at::Tensor select_symint(int64_t dim, c10::SymInt index) const; + at::Tensor sigmoid() const; + at::Tensor & sigmoid_() const; + at::Tensor logit(::std::optional eps=::std::nullopt) const; + at::Tensor & logit_(::std::optional eps=::std::nullopt) const; + at::Tensor sin() const; + at::Tensor & sin_() const; + at::Tensor sinc() const; + at::Tensor & sinc_() const; + at::Tensor sinh() const; + at::Tensor & sinh_() const; + at::Tensor detach() const; + at::Tensor & detach_() const; + int64_t size(at::Dimname dim) const; + at::Tensor slice(int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) const; + at::Tensor slice_symint(int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) const; + at::Tensor slice_inverse(const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) const; + at::Tensor slice_inverse_symint(const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) const; + at::Tensor slice_scatter(const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) const; + at::Tensor slice_scatter_symint(const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) const; + at::Tensor select_scatter(const at::Tensor & src, int64_t dim, int64_t index) const; + at::Tensor select_scatter_symint(const at::Tensor & src, int64_t dim, c10::SymInt index) const; + at::Tensor 
diagonal_scatter(const at::Tensor & src, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const; + at::Tensor as_strided_scatter(const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + at::Tensor as_strided_scatter_symint(const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) const; + at::Tensor smm(const at::Tensor & mat2) const; + at::Tensor softmax(int64_t dim, ::std::optional dtype=::std::nullopt) const; + at::Tensor softmax(at::Dimname dim, ::std::optional dtype=::std::nullopt) const; + ::std::vector unsafe_split(int64_t split_size, int64_t dim=0) const; + ::std::vector unsafe_split_symint(c10::SymInt split_size, int64_t dim=0) const; + ::std::vector split(int64_t split_size, int64_t dim=0) const; + ::std::vector split_symint(c10::SymInt split_size, int64_t dim=0) const; + ::std::vector split(at::IntArrayRef split_size, int64_t dim=0) const; + ::std::vector split_symint(c10::SymIntArrayRef split_size, int64_t dim=0) const; + ::std::vector unsafe_split_with_sizes(at::IntArrayRef split_sizes, int64_t dim=0) const; + ::std::vector unsafe_split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim=0) const; + ::std::vector split_with_sizes(at::IntArrayRef split_sizes, int64_t dim=0) const; + ::std::vector split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim=0) const; + ::std::vector hsplit(int64_t sections) const; + ::std::vector hsplit(at::IntArrayRef indices) const; + ::std::vector vsplit(int64_t sections) const; + ::std::vector vsplit(at::IntArrayRef indices) const; + ::std::vector dsplit(int64_t sections) const; + ::std::vector dsplit(at::IntArrayRef indices) const; + at::Tensor squeeze() const; + at::Tensor squeeze(int64_t dim) const; + at::Tensor squeeze(at::Dimname dim) const; + at::Tensor squeeze(at::IntArrayRef dim) const; + at::Tensor & squeeze_() const; + at::Tensor & squeeze_(int64_t dim) 
const; + at::Tensor & squeeze_(at::IntArrayRef dim) const; + at::Tensor & squeeze_(at::Dimname dim) const; + at::Tensor sspaddmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor stft(int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided=::std::nullopt, ::std::optional return_complex=::std::nullopt, ::std::optional align_to_window=::std::nullopt) const; + at::Tensor stft(int64_t n_fft, ::std::optional hop_length=::std::nullopt, ::std::optional win_length=::std::nullopt, const ::std::optional & window={}, bool center=true, c10::string_view pad_mode="reflect", bool normalized=false, ::std::optional onesided=::std::nullopt, ::std::optional return_complex=::std::nullopt, ::std::optional align_to_window=::std::nullopt) const; + at::Tensor istft(int64_t n_fft, ::std::optional hop_length=::std::nullopt, ::std::optional win_length=::std::nullopt, const ::std::optional & window={}, bool center=true, bool normalized=false, ::std::optional onesided=::std::nullopt, ::std::optional length=::std::nullopt, bool return_complex=false) const; + int64_t stride(at::Dimname dim) const; + at::Tensor sum(::std::optional dtype=::std::nullopt) const; + at::Tensor sum(at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor sum(at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor nansum(at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor sum_to_size(at::IntArrayRef size) const; + at::Tensor sum_to_size_symint(c10::SymIntArrayRef size) const; + at::Tensor sqrt() const; + at::Tensor & sqrt_() const; + at::Tensor square() const; + at::Tensor & square_() const; + at::Tensor std(bool unbiased) const; + at::Tensor std(at::OptionalIntArrayRef dim, bool unbiased, bool 
keepdim=false) const; + at::Tensor std(at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) const; + at::Tensor std(at::DimnameList dim, bool unbiased, bool keepdim=false) const; + at::Tensor std(at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) const; + at::Tensor prod(::std::optional dtype=::std::nullopt) const; + at::Tensor prod(int64_t dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor prod(at::Dimname dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) const; + at::Tensor t() const; + at::Tensor & t_() const; + at::Tensor tan() const; + at::Tensor & tan_() const; + at::Tensor tanh() const; + at::Tensor & tanh_() const; + at::Tensor tile(at::IntArrayRef dims) const; + at::Tensor tile_symint(c10::SymIntArrayRef dims) const; + at::Tensor transpose(int64_t dim0, int64_t dim1) const; + at::Tensor transpose(at::Dimname dim0, at::Dimname dim1) const; + at::Tensor & transpose_(int64_t dim0, int64_t dim1) const; + at::Tensor flip(at::IntArrayRef dims) const; + at::Tensor fliplr() const; + at::Tensor flipud() const; + at::Tensor roll(at::IntArrayRef shifts, at::IntArrayRef dims={}) const; + at::Tensor roll_symint(c10::SymIntArrayRef shifts, at::IntArrayRef dims={}) const; + at::Tensor rot90(int64_t k=1, at::IntArrayRef dims={0,1}) const; + at::Tensor _nested_tensor_size() const; + at::Tensor _nested_tensor_strides() const; + at::Tensor _nested_tensor_storage_offsets() const; + at::Tensor trunc() const; + at::Tensor & trunc_() const; + at::Tensor fix() const; + at::Tensor & fix_() const; + at::Tensor type_as(const at::Tensor & other) const; + at::Tensor unsqueeze(int64_t dim) const; + at::Tensor & unsqueeze_(int64_t dim) const; + at::Tensor var(bool unbiased) const; + at::Tensor var(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) const; + at::Tensor var(at::OptionalIntArrayRef dim=::std::nullopt, const 
::std::optional & correction=::std::nullopt, bool keepdim=false) const; + at::Tensor var(at::DimnameList dim, bool unbiased, bool keepdim=false) const; + at::Tensor var(at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) const; + at::Tensor view_as(const at::Tensor & other) const; + at::Tensor where(const at::Tensor & condition, const at::Tensor & other) const; + at::Tensor where(const at::Tensor & condition, const at::Scalar & other) const; + at::Tensor norm(const ::std::optional & p, at::ScalarType dtype) const; + at::Tensor norm(const at::Scalar & p=2) const; + at::Tensor norm(const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) const; + at::Tensor norm(const ::std::optional & p, at::IntArrayRef dim, bool keepdim=false) const; + at::Tensor norm(const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) const; + at::Tensor norm(const ::std::optional & p, at::DimnameList dim, bool keepdim=false) const; + ::std::tuple frexp() const; + at::Tensor clone(::std::optional memory_format=::std::nullopt) const; + at::Tensor positive() const; + const at::Tensor & resize_as_(const at::Tensor & the_template, ::std::optional memory_format=::std::nullopt) const; + const at::Tensor & resize_as_sparse_(const at::Tensor & the_template) const; + at::Tensor & zero_() const; + at::Tensor sub(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor & sub_(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor sub(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor & sub_(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor subtract(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor & subtract_(const at::Tensor & other, const at::Scalar & alpha=1) const; + at::Tensor subtract(const at::Scalar & other, const at::Scalar & alpha=1) const; + at::Tensor & subtract_(const at::Scalar & other, const 
at::Scalar & alpha=1) const; + at::Tensor heaviside(const at::Tensor & values) const; + at::Tensor & heaviside_(const at::Tensor & values) const; + at::Tensor addmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor & addmm_(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor _addmm_activation(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1, bool use_gelu=false) const; + const at::Tensor & sparse_resize_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const; + const at::Tensor & sparse_resize_and_clear_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const; + at::Tensor sparse_mask(const at::Tensor & mask) const; + at::Tensor _sparse_mask_projection(const at::Tensor & mask, bool accumulate_matches=false) const; + at::Tensor to_dense(::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) const; + at::Tensor _to_dense(::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) const; + int64_t sparse_dim() const; + int64_t _dimI() const; + int64_t dense_dim() const; + int64_t _dimV() const; + int64_t _nnz() const; + at::Tensor coalesce() const; + bool is_coalesced() const; + at::Tensor _indices() const; + at::Tensor _values() const; + at::Tensor & _coalesced_(bool coalesced) const; + at::Tensor indices() const; + at::Tensor values() const; + at::Tensor crow_indices() const; + at::Tensor col_indices() const; + at::Tensor ccol_indices() const; + at::Tensor row_indices() const; + ::std::vector unbind(int64_t dim=0) const; + ::std::vector unbind(at::Dimname dim) const; + at::Tensor to_sparse(int64_t sparse_dim) const; + at::Tensor _to_sparse(int64_t sparse_dim) const; + at::Tensor to_sparse(::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional 
dense_dim=::std::nullopt) const; + at::Tensor _to_sparse(::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_sparse_csr(::std::optional dense_dim=::std::nullopt) const; + at::Tensor _to_sparse_csr(::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_sparse_csc(::std::optional dense_dim=::std::nullopt) const; + at::Tensor _to_sparse_csc(::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor _to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor _to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) const; + at::Tensor to_mkldnn(::std::optional dtype=::std::nullopt) const; + at::Tensor dequantize() const; + double q_scale() const; + int64_t q_zero_point() const; + at::Tensor q_per_channel_scales() const; + at::Tensor q_per_channel_zero_points() const; + int64_t q_per_channel_axis() const; + at::Tensor int_repr() const; + at::QScheme qscheme() const; + at::Tensor _autocast_to_reduced_precision(bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) const; + at::Tensor _autocast_to_full_precision(bool cuda_enabled, bool cpu_enabled) const; + at::Tensor to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) const; + at::Tensor to(::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format) const; + at::Tensor to(at::Device device, at::ScalarType dtype, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) const; + at::Tensor to(at::ScalarType 
dtype, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) const; + at::Tensor to(const at::Tensor & other, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) const; + at::Scalar item() const; + at::Tensor & set_(at::Storage source) const; + at::Tensor & set_(at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) const; + at::Tensor & set__symint(at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) const; + at::Tensor & set_(const at::Tensor & source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) const; + at::Tensor & set__symint(const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) const; + at::Tensor & set_(const at::Tensor & source) const; + at::Tensor & set_() const; + bool is_set_to(const at::Tensor & tensor) const; + at::Tensor & masked_fill_(const at::Tensor & mask, const at::Scalar & value) const; + at::Tensor masked_fill(const at::Tensor & mask, const at::Scalar & value) const; + at::Tensor & masked_fill_(const at::Tensor & mask, const at::Tensor & value) const; + at::Tensor masked_fill(const at::Tensor & mask, const at::Tensor & value) const; + at::Tensor & masked_scatter_(const at::Tensor & mask, const at::Tensor & source) const; + at::Tensor masked_scatter(const at::Tensor & mask, const at::Tensor & source) const; + at::Tensor view(at::IntArrayRef size) const; + at::Tensor view_symint(c10::SymIntArrayRef size) const; + at::Tensor view(at::ScalarType dtype) const; + at::Tensor & put_(const at::Tensor & index, const at::Tensor & source, bool accumulate=false) const; + at::Tensor put(const at::Tensor & index, const at::Tensor & source, bool accumulate=false) const; + at::Tensor & index_add_(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) const; + at::Tensor 
index_add(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) const; + at::Tensor index_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) const; + at::Tensor & index_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) const; + at::Tensor index_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) const; + at::Tensor & index_fill_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor index_fill(int64_t dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor & index_fill_(int64_t dim, const at::Tensor & index, const at::Tensor & value) const; + at::Tensor index_fill(int64_t dim, const at::Tensor & index, const at::Tensor & value) const; + at::Tensor & index_fill_(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor & index_fill_(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const; + at::Tensor index_fill(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor index_fill(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const; + at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const; + at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const; + at::Tensor scatter(int64_t dim, const at::Tensor & 
index, const at::Scalar & value, c10::string_view reduce) const; + at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const; + at::Tensor scatter(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor scatter(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const; + at::Tensor scatter_add(int64_t dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor & scatter_add_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor scatter_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const; + at::Tensor scatter_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) const; + at::Tensor & scatter_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) const; + at::Tensor & eq_(const at::Scalar & other) const; + at::Tensor & eq_(const at::Tensor & other) const; + at::Tensor bitwise_and(const at::Scalar & other) const; + at::Tensor bitwise_and(const at::Tensor & other) const; + at::Tensor & bitwise_and_(const at::Scalar & other) const; + at::Tensor & bitwise_and_(const at::Tensor & other) const; + at::Tensor __and__(const at::Scalar & other) const; + at::Tensor __and__(const at::Tensor & other) const; + at::Tensor & __iand__(const at::Scalar & other) const; + at::Tensor & __iand__(const at::Tensor & other) const; + at::Tensor bitwise_or(const at::Scalar & other) const; + at::Tensor bitwise_or(const at::Tensor & other) const; + at::Tensor & bitwise_or_(const at::Scalar & other) const; + at::Tensor & bitwise_or_(const at::Tensor & other) const; + at::Tensor __or__(const at::Scalar & other) const; + at::Tensor __or__(const at::Tensor & other) const; + at::Tensor & __ior__(const at::Scalar & other) const; + at::Tensor & __ior__(const at::Tensor & other) const; + at::Tensor 
bitwise_xor(const at::Scalar & other) const; + at::Tensor bitwise_xor(const at::Tensor & other) const; + at::Tensor & bitwise_xor_(const at::Scalar & other) const; + at::Tensor & bitwise_xor_(const at::Tensor & other) const; + at::Tensor __xor__(const at::Scalar & other) const; + at::Tensor __xor__(const at::Tensor & other) const; + at::Tensor & __ixor__(const at::Scalar & other) const; + at::Tensor & __ixor__(const at::Tensor & other) const; + at::Tensor __lshift__(const at::Scalar & other) const; + at::Tensor __lshift__(const at::Tensor & other) const; + at::Tensor & __ilshift__(const at::Scalar & other) const; + at::Tensor & __ilshift__(const at::Tensor & other) const; + at::Tensor bitwise_left_shift(const at::Tensor & other) const; + at::Tensor & bitwise_left_shift_(const at::Tensor & other) const; + at::Tensor bitwise_left_shift(const at::Scalar & other) const; + at::Tensor & bitwise_left_shift_(const at::Scalar & other) const; + at::Tensor __rshift__(const at::Scalar & other) const; + at::Tensor __rshift__(const at::Tensor & other) const; + at::Tensor & __irshift__(const at::Scalar & other) const; + at::Tensor & __irshift__(const at::Tensor & other) const; + at::Tensor bitwise_right_shift(const at::Tensor & other) const; + at::Tensor & bitwise_right_shift_(const at::Tensor & other) const; + at::Tensor bitwise_right_shift(const at::Scalar & other) const; + at::Tensor & bitwise_right_shift_(const at::Scalar & other) const; + at::Tensor & tril_(int64_t diagonal=0) const; + at::Tensor & triu_(int64_t diagonal=0) const; + at::Tensor & digamma_() const; + at::Tensor & lerp_(const at::Tensor & end, const at::Scalar & weight) const; + at::Tensor & lerp_(const at::Tensor & end, const at::Tensor & weight) const; + at::Tensor & addbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const; + at::Tensor addbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & 
alpha=1) const; + at::Tensor & random_(int64_t from, ::std::optional to, ::std::optional generator=::std::nullopt) const; + at::Tensor & random_(int64_t to, ::std::optional generator=::std::nullopt) const; + at::Tensor & random_(::std::optional generator=::std::nullopt) const; + at::Tensor & uniform_(double from=0, double to=1, ::std::optional generator=::std::nullopt) const; + at::Tensor & cauchy_(double median=0, double sigma=1, ::std::optional generator=::std::nullopt) const; + at::Tensor & log_normal_(double mean=1, double std=2, ::std::optional generator=::std::nullopt) const; + at::Tensor & exponential_(double lambd=1, ::std::optional generator=::std::nullopt) const; + at::Tensor & geometric_(double p, ::std::optional generator=::std::nullopt) const; + at::Tensor diag(int64_t diagonal=0) const; + at::Tensor cross(const at::Tensor & other, ::std::optional dim=::std::nullopt) const; + at::Tensor triu(int64_t diagonal=0) const; + at::Tensor tril(int64_t diagonal=0) const; + at::Tensor trace() const; + at::Tensor ne(const at::Scalar & other) const; + at::Tensor ne(const at::Tensor & other) const; + at::Tensor & ne_(const at::Scalar & other) const; + at::Tensor & ne_(const at::Tensor & other) const; + at::Tensor not_equal(const at::Scalar & other) const; + at::Tensor not_equal(const at::Tensor & other) const; + at::Tensor & not_equal_(const at::Scalar & other) const; + at::Tensor & not_equal_(const at::Tensor & other) const; + at::Tensor eq(const at::Scalar & other) const; + at::Tensor eq(const at::Tensor & other) const; + at::Tensor ge(const at::Scalar & other) const; + at::Tensor ge(const at::Tensor & other) const; + at::Tensor & ge_(const at::Scalar & other) const; + at::Tensor & ge_(const at::Tensor & other) const; + at::Tensor greater_equal(const at::Scalar & other) const; + at::Tensor greater_equal(const at::Tensor & other) const; + at::Tensor & greater_equal_(const at::Scalar & other) const; + at::Tensor & greater_equal_(const at::Tensor & other) const; + 
at::Tensor le(const at::Scalar & other) const; + at::Tensor le(const at::Tensor & other) const; + at::Tensor & le_(const at::Scalar & other) const; + at::Tensor & le_(const at::Tensor & other) const; + at::Tensor less_equal(const at::Scalar & other) const; + at::Tensor less_equal(const at::Tensor & other) const; + at::Tensor & less_equal_(const at::Scalar & other) const; + at::Tensor & less_equal_(const at::Tensor & other) const; + at::Tensor gt(const at::Scalar & other) const; + at::Tensor gt(const at::Tensor & other) const; + at::Tensor & gt_(const at::Scalar & other) const; + at::Tensor & gt_(const at::Tensor & other) const; + at::Tensor greater(const at::Scalar & other) const; + at::Tensor greater(const at::Tensor & other) const; + at::Tensor & greater_(const at::Scalar & other) const; + at::Tensor & greater_(const at::Tensor & other) const; + at::Tensor lt(const at::Scalar & other) const; + at::Tensor lt(const at::Tensor & other) const; + at::Tensor & lt_(const at::Scalar & other) const; + at::Tensor & lt_(const at::Tensor & other) const; + at::Tensor less(const at::Scalar & other) const; + at::Tensor less(const at::Tensor & other) const; + at::Tensor & less_(const at::Scalar & other) const; + at::Tensor & less_(const at::Tensor & other) const; + at::Tensor take(const at::Tensor & index) const; + at::Tensor take_along_dim(const at::Tensor & indices, ::std::optional dim=::std::nullopt) const; + at::Tensor index_select(int64_t dim, const at::Tensor & index) const; + at::Tensor index_select(at::Dimname dim, const at::Tensor & index) const; + at::Tensor masked_select(const at::Tensor & mask) const; + at::Tensor nonzero() const; + at::Tensor nonzero_static(int64_t size, int64_t fill_value=-1) const; + at::Tensor nonzero_static_symint(c10::SymInt size, int64_t fill_value=-1) const; + ::std::vector nonzero_numpy() const; + at::Tensor argwhere() const; + at::Tensor gather(int64_t dim, const at::Tensor & index, bool sparse_grad=false) const; + at::Tensor 
gather(at::Dimname dim, const at::Tensor & index, bool sparse_grad=false) const; + at::Tensor addcmul(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const; + at::Tensor & addcmul_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const; + at::Tensor addcdiv(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const; + at::Tensor & addcdiv_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const; + ::std::tuple triangular_solve(const at::Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const; + ::std::tuple svd(bool some=true, bool compute_uv=true) const; + at::Tensor swapaxes(int64_t axis0, int64_t axis1) const; + at::Tensor & swapaxes_(int64_t axis0, int64_t axis1) const; + at::Tensor swapdims(int64_t dim0, int64_t dim1) const; + at::Tensor & swapdims_(int64_t dim0, int64_t dim1) const; + at::Tensor cholesky(bool upper=false) const; + at::Tensor cholesky_solve(const at::Tensor & input2, bool upper=false) const; + at::Tensor cholesky_inverse(bool upper=false) const; + ::std::tuple qr(bool some=true) const; + ::std::tuple geqrf() const; + at::Tensor orgqr(const at::Tensor & input2) const; + at::Tensor ormqr(const at::Tensor & input2, const at::Tensor & input3, bool left=true, bool transpose=false) const; + at::Tensor lu_solve(const at::Tensor & LU_data, const at::Tensor & LU_pivots) const; + at::Tensor multinomial(int64_t num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) const; + at::Tensor multinomial_symint(c10::SymInt num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) const; + at::Tensor & lgamma_() const; + at::Tensor lgamma() const; + at::Tensor digamma() const; + at::Tensor polygamma(int64_t n) const; + at::Tensor & polygamma_(int64_t n) const; + at::Tensor erfinv() const; + at::Tensor & erfinv_() const; + at::Tensor i0() const; + at::Tensor & 
i0_() const; + at::Tensor sign() const; + at::Tensor & sign_() const; + at::Tensor signbit() const; + at::Tensor dist(const at::Tensor & other, const at::Scalar & p=2) const; + at::Tensor & atan2_(const at::Tensor & other) const; + at::Tensor atan2(const at::Tensor & other) const; + at::Tensor arctan2(const at::Tensor & other) const; + at::Tensor & arctan2_(const at::Tensor & other) const; + at::Tensor lerp(const at::Tensor & end, const at::Scalar & weight) const; + at::Tensor lerp(const at::Tensor & end, const at::Tensor & weight) const; + at::Tensor histc(int64_t bins=100, const at::Scalar & min=0, const at::Scalar & max=0) const; + ::std::tuple histogram(const at::Tensor & bins, const ::std::optional & weight={}, bool density=false) const; + ::std::tuple histogram(int64_t bins=100, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) const; + at::Tensor fmod(const at::Scalar & other) const; + at::Tensor & fmod_(const at::Scalar & other) const; + at::Tensor fmod(const at::Tensor & other) const; + at::Tensor & fmod_(const at::Tensor & other) const; + at::Tensor hypot(const at::Tensor & other) const; + at::Tensor & hypot_(const at::Tensor & other) const; + at::Tensor igamma(const at::Tensor & other) const; + at::Tensor & igamma_(const at::Tensor & other) const; + at::Tensor igammac(const at::Tensor & other) const; + at::Tensor & igammac_(const at::Tensor & other) const; + at::Tensor nextafter(const at::Tensor & other) const; + at::Tensor & nextafter_(const at::Tensor & other) const; + at::Tensor remainder(const at::Scalar & other) const; + at::Tensor & remainder_(const at::Scalar & other) const; + at::Tensor remainder(const at::Tensor & other) const; + at::Tensor & remainder_(const at::Tensor & other) const; + at::Tensor min() const; + at::Tensor fmin(const at::Tensor & other) const; + at::Tensor max() const; + at::Tensor fmax(const at::Tensor & other) const; + at::Tensor maximum(const at::Tensor & other) const; + at::Tensor 
max(const at::Tensor & other) const; + at::Tensor minimum(const at::Tensor & other) const; + at::Tensor min(const at::Tensor & other) const; + at::Tensor quantile(const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const; + at::Tensor quantile(double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const; + at::Tensor nanquantile(const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const; + at::Tensor nanquantile(double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const; + ::std::tuple sort(int64_t dim=-1, bool descending=false) const; + ::std::tuple sort(::std::optional stable, int64_t dim=-1, bool descending=false) const; + ::std::tuple sort(at::Dimname dim, bool descending=false) const; + ::std::tuple sort(::std::optional stable, at::Dimname dim, bool descending=false) const; + at::Tensor msort() const; + at::Tensor argsort(int64_t dim=-1, bool descending=false) const; + at::Tensor argsort(bool stable, int64_t dim=-1, bool descending=false) const; + at::Tensor argsort(at::Dimname dim, bool descending=false) const; + ::std::tuple topk(int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) const; + ::std::tuple topk_symint(c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) const; + at::Tensor all() const; + at::Tensor any() const; + at::Tensor renorm(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const; + at::Tensor & renorm_(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const; + at::Tensor unfold(int64_t dimension, int64_t size, int64_t step) const; + bool equal(const at::Tensor & other) const; + at::Tensor pow(const at::Tensor & exponent) const; + at::Tensor pow(const at::Scalar & exponent) const; + at::Tensor & pow_(const at::Scalar & exponent) const; + at::Tensor & 
pow_(const at::Tensor & exponent) const; + at::Tensor float_power(const at::Tensor & exponent) const; + at::Tensor float_power(const at::Scalar & exponent) const; + at::Tensor & float_power_(const at::Scalar & exponent) const; + at::Tensor & float_power_(const at::Tensor & exponent) const; + at::Tensor & normal_(double mean=0, double std=1, ::std::optional generator=::std::nullopt) const; + at::Tensor alias() const; + at::Tensor isfinite() const; + at::Tensor isinf() const; + void record_stream(at::Stream s) const; + at::Tensor isposinf() const; + at::Tensor isneginf() const; + at::Tensor det() const; + ::std::tuple slogdet() const; + at::Tensor logdet() const; + at::Tensor inverse() const; + at::Tensor inner(const at::Tensor & other) const; + at::Tensor outer(const at::Tensor & vec2) const; + at::Tensor ger(const at::Tensor & vec2) const; + at::Tensor to_padded_tensor(double padding, at::OptionalIntArrayRef output_size=::std::nullopt) const; + at::Tensor to_padded_tensor_symint(double padding, at::OptionalSymIntArrayRef output_size=::std::nullopt) const; + + // Special C++ only overloads for std()-like functions (See gh-40287) + // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef + // So, for example std(0) would select the std(unbiased=False) overload + + Tensor var(int dim) const { + return var(IntArrayRef{dim}); + } + + Tensor std(int dim) const { + return std(IntArrayRef{dim}); + } + + // We changed .dtype() to return a TypeMeta in #12766. Ideally, we want the + // at::kDouble and its friends to be TypeMeta's, but that hasn't happened yet. + // Before that change, we make this method to maintain BC for C++ usage like + // `x.to(y.dtype)`. + // TODO: remove following two after at::kDouble and its friends are TypeMeta's. 
+ inline Tensor to(caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const { + return this->to(/*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy); + } + inline Tensor to(Device device, caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const { + return this->to(device, /*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy); + } + + template + decltype(auto) m(F func, Args&&... params) const { + return func(*this, std::forward(params)...); + } + + /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended + /// to be used from functions that need to access the `Variable`'s equivalent `Tensor` + /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`). + /// + /// One notable difference with the legacy `.data()` function is that changes to the + /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset) + /// will not update the original `Variable`, due to the fact that this function + /// shallow-copies the `Variable`'s underlying TensorImpl. + at::Tensor tensor_data() const { + return TensorBase::tensor_data(); + } + + /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data` + /// in Python, which create a new `Variable` that shares the same storage and + /// tensor metadata with the original `Variable`, but with a completely new + /// autograd history. + /// + /// NOTE: If we change the tensor metadata (e.g. sizes / strides / + /// storage / storage_offset) of a variable created from `var.variable_data()`, those + /// changes will not update the original variable `var`. In `.variable_data()`, we set + /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal, + /// in order to prevent users from changing metadata of `var.variable_data()` + /// and expecting the original variable `var` to also be updated. 
+ at::Tensor variable_data() const { + return TensorBase::variable_data(); + } + + // Hooks + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + using hook_return_void_t = std::enable_if_t>::value, unsigned>; + template + using hook_return_var_t = std::enable_if_t, Tensor>, unsigned>; + + /// Registers a backward hook. + /// + /// The hook will be called every time a gradient with respect to the Tensor is computed. + /// The hook should have one of the following signature: + /// ``` + /// hook(Tensor grad) -> Tensor + /// ``` + /// ``` + /// hook(Tensor grad) -> void + /// ``` + /// The hook should not modify its argument, but it can optionally return a new gradient + /// which will be used in place of `grad`. + /// + /// This function returns the index of the hook in the list which can be used to remove hook. + /// + /// Example: + /// @code + /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad()); + /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient + /// v.backward(torch::tensor({1., 2., 3.})); + /// // This prints: + /// // ``` + /// // 2 + /// // 4 + /// // 6 + /// // [ CPUFloatType{3} ] + /// // ``` + /// std::cout << v.grad() << std::endl; + /// v.remove_hook(h); // removes the hook + /// @endcode + template + hook_return_void_t register_hook(T&& hook) const; + template + hook_return_var_t register_hook(T&& hook) const; + + // Variable methods + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Tensor data() const { + return TensorBase::data(); + } + + void _backward(TensorList inputs, const std::optional& gradient, std::optional keep_graph, bool create_graph) const; + + const Tensor& requires_grad_(bool _requires_grad=true) const { + TensorBase::requires_grad_(_requires_grad); + return *this; + } +}; + +namespace detail { +// Helper creator for Tensor class which doesn't requires the users to pass +// in an intrusive_ptr 
instead it just converts the argument passed to +// requested intrusive_ptr type. +template +Tensor make_tensor(Args&&... args) { + return Tensor(c10::make_intrusive(std::forward(args)...)); +} + +} // namespace detail + +} // namespace at + + +namespace at { + +// aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> () +inline void Tensor::__dispatch__backward(at::TensorList inputs, const ::std::optional & gradient, ::std::optional retain_graph, bool create_graph) const { + return at::_ops::_backward::call(const_cast(*this), inputs, gradient, retain_graph, create_graph); +} + +// aten::set_data(Tensor(a!) self, Tensor new_data) -> () +inline void Tensor::__dispatch_set_data(const at::Tensor & new_data) const { + return at::_ops::set_data::call(const_cast(*this), new_data); +} + +// aten::data(Tensor self) -> Tensor +inline at::Tensor Tensor::__dispatch_data() const { + return at::_ops::data::call(const_cast(*this)); +} + +// aten::is_leaf(Tensor self) -> bool +inline bool Tensor::__dispatch_is_leaf() const { + return at::_ops::is_leaf::call(const_cast(*this)); +} + +// aten::output_nr(Tensor self) -> int +inline int64_t Tensor::__dispatch_output_nr() const { + return at::_ops::output_nr::call(const_cast(*this)); +} + +// aten::_version(Tensor self) -> int +inline int64_t Tensor::__dispatch__version() const { + return at::_ops::_version::call(const_cast(*this)); +} + +// aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!) +inline at::Tensor & Tensor::__dispatch_requires_grad_(bool requires_grad) const { + return at::_ops::requires_grad_::call(const_cast(*this), requires_grad); +} + +// aten::retain_grad(Tensor(a!) 
self) -> () +inline void Tensor::__dispatch_retain_grad() const { + return at::_ops::retain_grad::call(const_cast(*this)); +} + +// aten::retains_grad(Tensor self) -> bool +inline bool Tensor::__dispatch_retains_grad() const { + return at::_ops::retains_grad::call(const_cast(*this)); +} + +// aten::_fw_primal(Tensor(a) self, int level) -> Tensor(a) +inline at::Tensor Tensor::_fw_primal(int64_t level) const { + return at::_ops::_fw_primal::call(const_cast(*this), level); +} + +// aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!) +inline at::Tensor & Tensor::rename_(::std::optional names) const { + return at::_ops::rename_::call(const_cast(*this), names); +} + +// aten::rename(Tensor(a) self, Dimname[]? names) -> Tensor(a) +inline at::Tensor Tensor::rename(::std::optional names) const { + return at::_ops::rename::call(const_cast(*this), names); +} + +// aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a) +inline at::Tensor Tensor::align_to(at::DimnameList names) const { + return at::_ops::align_to::call(const_cast(*this), names); +} + +// aten::align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a) +inline at::Tensor Tensor::align_to(at::DimnameList order, int64_t ellipsis_idx) const { + return at::_ops::align_to_ellipsis_idx::call(const_cast(*this), order, ellipsis_idx); +} + +// aten::align_as(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::align_as(const at::Tensor & other) const { + return at::_ops::align_as::call(const_cast(*this), other); +} + +// aten::refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a) +inline at::Tensor Tensor::refine_names(at::DimnameList names) const { + return at::_ops::refine_names::call(const_cast(*this), names); +} + +// aten::abs(Tensor self) -> Tensor +inline at::Tensor Tensor::abs() const { + return at::_ops::abs::call(const_cast(*this)); +} + +// aten::abs_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::abs_() const { + return at::_ops::abs_::call(const_cast(*this)); +} + +// aten::absolute(Tensor self) -> Tensor +inline at::Tensor Tensor::absolute() const { + return at::_ops::absolute::call(const_cast(*this)); +} + +// aten::absolute_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::absolute_() const { + return at::_ops::absolute_::call(const_cast(*this)); +} + +// aten::angle(Tensor self) -> Tensor +inline at::Tensor Tensor::angle() const { + return at::_ops::angle::call(const_cast(*this)); +} + +// aten::sgn(Tensor self) -> Tensor +inline at::Tensor Tensor::sgn() const { + return at::_ops::sgn::call(const_cast(*this)); +} + +// aten::sgn_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sgn_() const { + return at::_ops::sgn_::call(const_cast(*this)); +} + +// aten::chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor +inline at::Tensor Tensor::chalf(::std::optional memory_format) const { + return at::_ops::chalf::call(const_cast(*this), memory_format); +} + +// aten::_conj(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::_conj() const { + return at::_ops::_conj::call(const_cast(*this)); +} + +// aten::conj(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::__dispatch_conj() const { + return at::_ops::conj::call(const_cast(*this)); +} + +// aten::_conj_physical(Tensor self) -> Tensor +inline at::Tensor Tensor::_conj_physical() const { + return at::_ops::_conj_physical::call(const_cast(*this)); +} + +// aten::conj_physical(Tensor self) -> Tensor +inline at::Tensor Tensor::conj_physical() const { + return at::_ops::conj_physical::call(const_cast(*this)); +} + +// aten::conj_physical_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::conj_physical_() const { + return at::_ops::conj_physical_::call(const_cast(*this)); +} + +// aten::resolve_conj(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::resolve_conj() const { + return at::_ops::resolve_conj::call(const_cast(*this)); +} + +// aten::resolve_neg(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::resolve_neg() const { + return at::_ops::resolve_neg::call(const_cast(*this)); +} + +// aten::_neg_view(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::_neg_view() const { + return at::_ops::_neg_view::call(const_cast(*this)); +} + +// aten::acos(Tensor self) -> Tensor +inline at::Tensor Tensor::acos() const { + return at::_ops::acos::call(const_cast(*this)); +} + +// aten::acos_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::acos_() const { + return at::_ops::acos_::call(const_cast(*this)); +} + +// aten::arccos(Tensor self) -> Tensor +inline at::Tensor Tensor::arccos() const { + return at::_ops::arccos::call(const_cast(*this)); +} + +// aten::arccos_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arccos_() const { + return at::_ops::arccos_::call(const_cast(*this)); +} + +// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::add(const at::Tensor & other, const at::Scalar & alpha) const { + return at::_ops::add_Tensor::call(const_cast(*this), other, alpha); +} + +// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::add_(const at::Tensor & other, const at::Scalar & alpha) const { + return at::_ops::add__Tensor::call(const_cast(*this), other, alpha); +} + +// aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::add(const at::Scalar & other, const at::Scalar & alpha) const { + return at::_ops::add_Scalar::call(const_cast(*this), other, alpha); +} + +// aten::add_.Scalar(Tensor(a!) 
self, Scalar other, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::add_(const at::Scalar & other, const at::Scalar & alpha) const { + return at::_ops::add__Scalar::call(const_cast(*this), other, alpha); +} + +// aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::addmv(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addmv::call(const_cast(*this), mat, vec, beta, alpha); +} + +// aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::addmv_(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addmv_::call(const_cast(*this), mat, vec, beta, alpha); +} + +// aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::addr(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addr::call(const_cast(*this), vec1, vec2, beta, alpha); +} + +// aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) 
+inline at::Tensor & Tensor::addr_(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addr_::call(const_cast(*this), vec1, vec2, beta, alpha); +} + +// aten::_is_all_true(Tensor self) -> Tensor +inline at::Tensor Tensor::_is_all_true() const { + return at::_ops::_is_all_true::call(const_cast(*this)); +} + +// aten::_is_any_true(Tensor self) -> Tensor +inline at::Tensor Tensor::_is_any_true() const { + return at::_ops::_is_any_true::call(const_cast(*this)); +} + +// aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::all(int64_t dim, bool keepdim) const { + return at::_ops::all_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::all(at::OptionalIntArrayRef dim, bool keepdim) const { + return at::_ops::all_dims::call(const_cast(*this), dim, keepdim); +} + +// aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::all(at::Dimname dim, bool keepdim) const { + return at::_ops::all_dimname::call(const_cast(*this), dim, keepdim); +} + +// aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool +inline bool Tensor::allclose(const at::Tensor & other, double rtol, double atol, bool equal_nan) const { + return at::_ops::allclose::call(const_cast(*this), other, rtol, atol, equal_nan); +} + +// aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::any(int64_t dim, bool keepdim) const { + return at::_ops::any_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::any.dims(Tensor self, int[]? 
dim=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::any(at::OptionalIntArrayRef dim, bool keepdim) const { + return at::_ops::any_dims::call(const_cast(*this), dim, keepdim); +} + +// aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::any(at::Dimname dim, bool keepdim) const { + return at::_ops::any_dimname::call(const_cast(*this), dim, keepdim); +} + +// aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::argmax(::std::optional dim, bool keepdim) const { + return at::_ops::argmax::call(const_cast(*this), dim, keepdim); +} + +// aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::argmin(::std::optional dim, bool keepdim) const { + return at::_ops::argmin::call(const_cast(*this), dim, keepdim); +} + +// aten::acosh(Tensor self) -> Tensor +inline at::Tensor Tensor::acosh() const { + return at::_ops::acosh::call(const_cast(*this)); +} + +// aten::acosh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::acosh_() const { + return at::_ops::acosh_::call(const_cast(*this)); +} + +// aten::arccosh(Tensor self) -> Tensor +inline at::Tensor Tensor::arccosh() const { + return at::_ops::arccosh::call(const_cast(*this)); +} + +// aten::arccosh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arccosh_() const { + return at::_ops::arccosh_::call(const_cast(*this)); +} + +// aten::asinh(Tensor self) -> Tensor +inline at::Tensor Tensor::asinh() const { + return at::_ops::asinh::call(const_cast(*this)); +} + +// aten::asinh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::asinh_() const { + return at::_ops::asinh_::call(const_cast(*this)); +} + +// aten::arcsinh(Tensor self) -> Tensor +inline at::Tensor Tensor::arcsinh() const { + return at::_ops::arcsinh::call(const_cast(*this)); +} + +// aten::arcsinh_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::arcsinh_() const { + return at::_ops::arcsinh_::call(const_cast(*this)); +} + +// aten::atanh(Tensor self) -> Tensor +inline at::Tensor Tensor::atanh() const { + return at::_ops::atanh::call(const_cast(*this)); +} + +// aten::atanh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::atanh_() const { + return at::_ops::atanh_::call(const_cast(*this)); +} + +// aten::arctanh(Tensor self) -> Tensor +inline at::Tensor Tensor::arctanh() const { + return at::_ops::arctanh::call(const_cast(*this)); +} + +// aten::arctanh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arctanh_() const { + return at::_ops::arctanh_::call(const_cast(*this)); +} + +// aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) +inline at::Tensor Tensor::as_strided(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} + +// aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) +inline at::Tensor Tensor::as_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided::call(const_cast(*this), size, stride, storage_offset); +} + +// aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) +inline const at::Tensor & Tensor::as_strided_(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided_::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? 
::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} + +// aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) +inline const at::Tensor & Tensor::as_strided__symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided_::call(const_cast(*this), size, stride, storage_offset); +} + +// aten::asin(Tensor self) -> Tensor +inline at::Tensor Tensor::asin() const { + return at::_ops::asin::call(const_cast(*this)); +} + +// aten::asin_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::asin_() const { + return at::_ops::asin_::call(const_cast(*this)); +} + +// aten::arcsin(Tensor self) -> Tensor +inline at::Tensor Tensor::arcsin() const { + return at::_ops::arcsin::call(const_cast(*this)); +} + +// aten::arcsin_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arcsin_() const { + return at::_ops::arcsin_::call(const_cast(*this)); +} + +// aten::atan(Tensor self) -> Tensor +inline at::Tensor Tensor::atan() const { + return at::_ops::atan::call(const_cast(*this)); +} + +// aten::atan_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::atan_() const { + return at::_ops::atan_::call(const_cast(*this)); +} + +// aten::arctan(Tensor self) -> Tensor +inline at::Tensor Tensor::arctan() const { + return at::_ops::arctan::call(const_cast(*this)); +} + +// aten::arctan_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::arctan_() const { + return at::_ops::arctan_::call(const_cast(*this)); +} + +// aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::baddbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::baddbmm::call(const_cast(*this), batch1, batch2, beta, alpha); +} + +// aten::baddbmm_(Tensor(a!) 
self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::baddbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::baddbmm_::call(const_cast(*this), batch1, batch2, beta, alpha); +} + +// aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor +inline at::Tensor Tensor::bernoulli(::std::optional generator) const { + return at::_ops::bernoulli::call(const_cast(*this), generator); +} + +// aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::bernoulli_(const at::Tensor & p, ::std::optional generator) const { + return at::_ops::bernoulli__Tensor::call(const_cast(*this), p, generator); +} + +// aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::bernoulli_(double p, ::std::optional generator) const { + return at::_ops::bernoulli__float::call(const_cast(*this), p, generator); +} + +// aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor +inline at::Tensor Tensor::bernoulli(double p, ::std::optional generator) const { + return at::_ops::bernoulli_p::call(const_cast(*this), p, generator); +} + +// aten::bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor +inline at::Tensor Tensor::bincount(const ::std::optional & weights, int64_t minlength) const { + return at::_ops::bincount::call(const_cast(*this), weights, minlength); +} + +// aten::bincount(Tensor self, Tensor? 
weights=None, SymInt minlength=0) -> Tensor +inline at::Tensor Tensor::bincount_symint(const ::std::optional & weights, c10::SymInt minlength) const { + return at::_ops::bincount::call(const_cast(*this), weights, minlength); +} + +// aten::bitwise_not(Tensor self) -> Tensor +inline at::Tensor Tensor::bitwise_not() const { + return at::_ops::bitwise_not::call(const_cast(*this)); +} + +// aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_not_() const { + return at::_ops::bitwise_not_::call(const_cast(*this)); +} + +// aten::copysign.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::copysign(const at::Tensor & other) const { + return at::_ops::copysign_Tensor::call(const_cast(*this), other); +} + +// aten::copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::copysign_(const at::Tensor & other) const { + return at::_ops::copysign__Tensor::call(const_cast(*this), other); +} + +// aten::copysign.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::copysign(const at::Scalar & other) const { + return at::_ops::copysign_Scalar::call(const_cast(*this), other); +} + +// aten::copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::copysign_(const at::Scalar & other) const { + return at::_ops::copysign__Scalar::call(const_cast(*this), other); +} + +// aten::_lazy_clone(Tensor self) -> Tensor +inline at::Tensor Tensor::_lazy_clone() const { + return at::_ops::_lazy_clone::call(const_cast(*this)); +} + +// aten::logical_not(Tensor self) -> Tensor +inline at::Tensor Tensor::logical_not() const { + return at::_ops::logical_not::call(const_cast(*this)); +} + +// aten::logical_not_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::logical_not_() const { + return at::_ops::logical_not_::call(const_cast(*this)); +} + +// aten::logical_xor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logical_xor(const at::Tensor & other) const { + return at::_ops::logical_xor::call(const_cast(*this), other); +} + +// aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::logical_xor_(const at::Tensor & other) const { + return at::_ops::logical_xor_::call(const_cast(*this), other); +} + +// aten::logical_and(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logical_and(const at::Tensor & other) const { + return at::_ops::logical_and::call(const_cast(*this), other); +} + +// aten::logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::logical_and_(const at::Tensor & other) const { + return at::_ops::logical_and_::call(const_cast(*this), other); +} + +// aten::logical_or(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logical_or(const at::Tensor & other) const { + return at::_ops::logical_or::call(const_cast(*this), other); +} + +// aten::logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::logical_or_(const at::Tensor & other) const { + return at::_ops::logical_or_::call(const_cast(*this), other); +} + +// aten::bmm(Tensor self, Tensor mat2) -> Tensor +inline at::Tensor Tensor::bmm(const at::Tensor & mat2) const { + return at::_ops::bmm::call(const_cast(*this), mat2); +} + +// aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) +inline at::Tensor Tensor::broadcast_to(at::IntArrayRef size) const { + return at::_ops::broadcast_to::call(const_cast(*this), c10::fromIntArrayRefSlow(size)); +} + +// aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) +inline at::Tensor Tensor::broadcast_to_symint(c10::SymIntArrayRef size) const { + return at::_ops::broadcast_to::call(const_cast(*this), size); +} + +// aten::ceil(Tensor self) -> Tensor +inline at::Tensor Tensor::ceil() const { + return at::_ops::ceil::call(const_cast(*this)); +} + +// aten::ceil_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::ceil_() const { + return at::_ops::ceil_::call(const_cast(*this)); +} + +// aten::unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_chunk(int64_t chunks, int64_t dim) const { + return at::_ops::unsafe_chunk::call(const_cast(*this), chunks, dim); +} + +// aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::chunk(int64_t chunks, int64_t dim) const { + return at::_ops::chunk::call(const_cast(*this), chunks, dim); +} + +// aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split(int64_t sections, int64_t dim) const { + return at::_ops::tensor_split_sections::call(const_cast(*this), sections, dim); +} + +// aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split_symint(c10::SymInt sections, int64_t dim) const { + return 
at::_ops::tensor_split_sections::call(const_cast(*this), sections, dim); +} + +// aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split(at::IntArrayRef indices, int64_t dim) const { + return at::_ops::tensor_split_indices::call(const_cast(*this), c10::fromIntArrayRefSlow(indices), dim); +} + +// aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split_symint(c10::SymIntArrayRef indices, int64_t dim) const { + return at::_ops::tensor_split_indices::call(const_cast(*this), indices, dim); +} + +// aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::tensor_split(const at::Tensor & tensor_indices_or_sections, int64_t dim) const { + return at::_ops::tensor_split_tensor_indices_or_sections::call(const_cast(*this), tensor_indices_or_sections, dim); +} + +// aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor +inline at::Tensor Tensor::clamp(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clamp::call(const_cast(*this), min, max); +} + +// aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor +inline at::Tensor Tensor::clamp(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clamp_Tensor::call(const_cast(*this), min, max); +} + +// aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clamp_::call(const_cast(*this), min, max); +} + +// aten::clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) 
+inline at::Tensor & Tensor::clamp_(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clamp__Tensor::call(const_cast(*this), min, max); +} + +// aten::clamp_max(Tensor self, Scalar max) -> Tensor +inline at::Tensor Tensor::clamp_max(const at::Scalar & max) const { + return at::_ops::clamp_max::call(const_cast(*this), max); +} + +// aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor +inline at::Tensor Tensor::clamp_max(const at::Tensor & max) const { + return at::_ops::clamp_max_Tensor::call(const_cast(*this), max); +} + +// aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_max_(const at::Scalar & max) const { + return at::_ops::clamp_max_::call(const_cast(*this), max); +} + +// aten::clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_max_(const at::Tensor & max) const { + return at::_ops::clamp_max__Tensor::call(const_cast(*this), max); +} + +// aten::clamp_min(Tensor self, Scalar min) -> Tensor +inline at::Tensor Tensor::clamp_min(const at::Scalar & min) const { + return at::_ops::clamp_min::call(const_cast(*this), min); +} + +// aten::clamp_min.Tensor(Tensor self, Tensor min) -> Tensor +inline at::Tensor Tensor::clamp_min(const at::Tensor & min) const { + return at::_ops::clamp_min_Tensor::call(const_cast(*this), min); +} + +// aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_min_(const at::Scalar & min) const { + return at::_ops::clamp_min_::call(const_cast(*this), min); +} + +// aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!) +inline at::Tensor & Tensor::clamp_min_(const at::Tensor & min) const { + return at::_ops::clamp_min__Tensor::call(const_cast(*this), min); +} + +// aten::clip(Tensor self, Scalar? min=None, Scalar? 
max=None) -> Tensor +inline at::Tensor Tensor::clip(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clip::call(const_cast(*this), min, max); +} + +// aten::clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor +inline at::Tensor Tensor::clip(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clip_Tensor::call(const_cast(*this), min, max); +} + +// aten::clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) +inline at::Tensor & Tensor::clip_(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clip_::call(const_cast(*this), min, max); +} + +// aten::clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) +inline at::Tensor & Tensor::clip_(const ::std::optional & min, const ::std::optional & max) const { + return at::_ops::clip__Tensor::call(const_cast(*this), min, max); +} + +// aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) +inline at::Tensor Tensor::__dispatch_contiguous(at::MemoryFormat memory_format) const { + return at::_ops::contiguous::call(const_cast(*this), memory_format); +} + +// aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) +inline at::Tensor & Tensor::copy_(const at::Tensor & src, bool non_blocking) const { + return at::_ops::copy_::call(const_cast(*this), src, non_blocking); +} + +// aten::cos(Tensor self) -> Tensor +inline at::Tensor Tensor::cos() const { + return at::_ops::cos::call(const_cast(*this)); +} + +// aten::cos_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::cos_() const { + return at::_ops::cos_::call(const_cast(*this)); +} + +// aten::cosh(Tensor self) -> Tensor +inline at::Tensor Tensor::cosh() const { + return at::_ops::cosh::call(const_cast(*this)); +} + +// aten::cosh_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::cosh_() const { + return at::_ops::cosh_::call(const_cast(*this)); +} + +// aten::count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor +inline at::Tensor Tensor::count_nonzero(at::IntArrayRef dim) const { + return at::_ops::count_nonzero_dim_IntList::call(const_cast(*this), dim); +} + +// aten::count_nonzero(Tensor self, int? dim=None) -> Tensor +inline at::Tensor Tensor::count_nonzero(::std::optional dim) const { + return at::_ops::count_nonzero::call(const_cast(*this), dim); +} + +// aten::cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor +inline at::Tensor Tensor::cov(int64_t correction, const ::std::optional & fweights, const ::std::optional & aweights) const { + return at::_ops::cov::call(const_cast(*this), correction, fweights, aweights); +} + +// aten::corrcoef(Tensor self) -> Tensor +inline at::Tensor Tensor::corrcoef() const { + return at::_ops::corrcoef::call(const_cast(*this)); +} + +// aten::cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::cummax(int64_t dim) const { + return at::_ops::cummax::call(const_cast(*this), dim); +} + +// aten::cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::cummax(at::Dimname dim) const { + return at::_ops::cummax_dimname::call(const_cast(*this), dim); +} + +// aten::cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::cummin(int64_t dim) const { + return at::_ops::cummin::call(const_cast(*this), dim); +} + +// aten::cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::cummin(at::Dimname dim) const { + return at::_ops::cummin_dimname::call(const_cast(*this), dim); +} + +// aten::cumprod(Tensor self, int dim, *, ScalarType? 
dtype=None) -> Tensor +inline at::Tensor Tensor::cumprod(int64_t dim, ::std::optional dtype) const { + return at::_ops::cumprod::call(const_cast(*this), dim, dtype); +} + +// aten::cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) +inline at::Tensor & Tensor::cumprod_(int64_t dim, ::std::optional dtype) const { + return at::_ops::cumprod_::call(const_cast(*this), dim, dtype); +} + +// aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::cumprod(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::cumprod_dimname::call(const_cast(*this), dim, dtype); +} + +// aten::cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) +inline at::Tensor & Tensor::cumprod_(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::cumprod__dimname::call(const_cast(*this), dim, dtype); +} + +// aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::cumsum(int64_t dim, ::std::optional dtype) const { + return at::_ops::cumsum::call(const_cast(*this), dim, dtype); +} + +// aten::cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) +inline at::Tensor & Tensor::cumsum_(int64_t dim, ::std::optional dtype) const { + return at::_ops::cumsum_::call(const_cast(*this), dim, dtype); +} + +// aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::cumsum(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::cumsum_dimname::call(const_cast(*this), dim, dtype); +} + +// aten::cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) 
+inline at::Tensor & Tensor::cumsum_(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::cumsum__dimname::call(const_cast(*this), dim, dtype); +} + +// aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor +inline at::Tensor Tensor::diag_embed(int64_t offset, int64_t dim1, int64_t dim2) const { + return at::_ops::diag_embed::call(const_cast(*this), offset, dim1, dim2); +} + +// aten::diagflat(Tensor self, int offset=0) -> Tensor +inline at::Tensor Tensor::diagflat(int64_t offset) const { + return at::_ops::diagflat::call(const_cast(*this), offset); +} + +// aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) +inline at::Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const { + return at::_ops::diagonal::call(const_cast(*this), offset, dim1, dim2); +} + +// aten::diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a) +inline at::Tensor Tensor::diagonal(at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset) const { + return at::_ops::diagonal_Dimname::call(const_cast(*this), outdim, dim1, dim2, offset); +} + +// aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!) +inline at::Tensor & Tensor::fill_diagonal_(const at::Scalar & fill_value, bool wrap) const { + return at::_ops::fill_diagonal_::call(const_cast(*this), fill_value, wrap); +} + +// aten::diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor +inline at::Tensor Tensor::diff(int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append) const { + return at::_ops::diff::call(const_cast(*this), n, dim, prepend, append); +} + +// aten::div.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::div(const at::Tensor & other) const { + return at::_ops::div_Tensor::call(const_cast(*this), other); +} + +// aten::div_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::div_(const at::Tensor & other) const { + return at::_ops::div__Tensor::call(const_cast(*this), other); +} + +// aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor +inline at::Tensor Tensor::div(const at::Tensor & other, ::std::optional rounding_mode) const { + return at::_ops::div_Tensor_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) +inline at::Tensor & Tensor::div_(const at::Tensor & other, ::std::optional rounding_mode) const { + return at::_ops::div__Tensor_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::div.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::div(const at::Scalar & other) const { + return at::_ops::div_Scalar::call(const_cast(*this), other); +} + +// aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::div_(const at::Scalar & other) const { + return at::_ops::div__Scalar::call(const_cast(*this), other); +} + +// aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor +inline at::Tensor Tensor::div(const at::Scalar & other, ::std::optional rounding_mode) const { + return at::_ops::div_Scalar_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) +inline at::Tensor & Tensor::div_(const at::Scalar & other, ::std::optional rounding_mode) const { + return at::_ops::div__Scalar_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::divide.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::divide(const at::Tensor & other) const { + return at::_ops::divide_Tensor::call(const_cast(*this), other); +} + +// aten::divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::divide_(const at::Tensor & other) const { + return at::_ops::divide__Tensor::call(const_cast(*this), other); +} + +// aten::divide.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::divide(const at::Scalar & other) const { + return at::_ops::divide_Scalar::call(const_cast(*this), other); +} + +// aten::divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::divide_(const at::Scalar & other) const { + return at::_ops::divide__Scalar::call(const_cast(*this), other); +} + +// aten::divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor +inline at::Tensor Tensor::divide(const at::Tensor & other, ::std::optional rounding_mode) const { + return at::_ops::divide_Tensor_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) +inline at::Tensor & Tensor::divide_(const at::Tensor & other, ::std::optional rounding_mode) const { + return at::_ops::divide__Tensor_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor +inline at::Tensor Tensor::divide(const at::Scalar & other, ::std::optional rounding_mode) const { + return at::_ops::divide_Scalar_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) +inline at::Tensor & Tensor::divide_(const at::Scalar & other, ::std::optional rounding_mode) const { + return at::_ops::divide__Scalar_mode::call(const_cast(*this), other, rounding_mode); +} + +// aten::true_divide.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::true_divide(const at::Tensor & other) const { + return at::_ops::true_divide_Tensor::call(const_cast(*this), other); +} + +// aten::true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::true_divide_(const at::Tensor & other) const { + return at::_ops::true_divide__Tensor::call(const_cast(*this), other); +} + +// aten::true_divide.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::true_divide(const at::Scalar & other) const { + return at::_ops::true_divide_Scalar::call(const_cast(*this), other); +} + +// aten::true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::true_divide_(const at::Scalar & other) const { + return at::_ops::true_divide__Scalar::call(const_cast(*this), other); +} + +// aten::dot(Tensor self, Tensor tensor) -> Tensor +inline at::Tensor Tensor::dot(const at::Tensor & tensor) const { + return at::_ops::dot::call(const_cast(*this), tensor); +} + +// aten::vdot(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::vdot(const at::Tensor & other) const { + return at::_ops::vdot::call(const_cast(*this), other); +} + +// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty(at::IntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_empty::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_empty::call(const_cast(*this), c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); +} + +// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_symint(c10::SymIntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_empty::call(const_cast(*this), size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_empty::call(const_cast(*this), size, dtype, layout, device, pin_memory); +} + +// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options) const { + return at::_ops::new_empty_strided::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_empty_strided::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype, layout, device, pin_memory); +} + +// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options) const { + return at::_ops::new_empty_strided::call(const_cast(*this), size, stride, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_empty_strided::call(const_cast(*this), size, stride, dtype, layout, device, pin_memory); +} + +// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_full(at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options) const { + return at::_ops::new_full::call(const_cast(*this), c10::fromIntArrayRefSlow(size), fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_full(at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_full::call(const_cast(*this), c10::fromIntArrayRefSlow(size), fill_value, dtype, layout, device, pin_memory); +} + +// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options) const { + return at::_ops::new_full::call(const_cast(*this), size, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_full::call(const_cast(*this), size, fill_value, dtype, layout, device, pin_memory); +} + +// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_zeros(at::IntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_zeros::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_zeros(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_zeros::call(const_cast(*this), c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); +} + +// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_zeros_symint(c10::SymIntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_zeros::call(const_cast(*this), size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_zeros_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_zeros::call(const_cast(*this), size, dtype, layout, device, pin_memory); +} + +// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_ones(at::IntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_ones::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_ones(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_ones::call(const_cast(*this), c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); +} + +// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_ones_symint(c10::SymIntArrayRef size, at::TensorOptions options) const { + return at::_ops::new_ones::call(const_cast(*this), size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); +} + +// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +inline at::Tensor Tensor::new_ones_symint(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) const { + return at::_ops::new_ones::call(const_cast(*this), size, dtype, layout, device, pin_memory); +} + +// aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) +inline const at::Tensor & Tensor::resize_(at::IntArrayRef size, ::std::optional memory_format) const { + return at::_ops::resize_::call(const_cast(*this), c10::fromIntArrayRefSlow(size), memory_format); +} + +// aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) +inline const at::Tensor & Tensor::resize__symint(c10::SymIntArrayRef size, ::std::optional memory_format) const { + return at::_ops::resize_::call(const_cast(*this), size, memory_format); +} + +// aten::erf(Tensor self) -> Tensor +inline at::Tensor Tensor::erf() const { + return at::_ops::erf::call(const_cast(*this)); +} + +// aten::erf_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::erf_() const { + return at::_ops::erf_::call(const_cast(*this)); +} + +// aten::erfc(Tensor self) -> Tensor +inline at::Tensor Tensor::erfc() const { + return at::_ops::erfc::call(const_cast(*this)); +} + +// aten::erfc_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::erfc_() const { + return at::_ops::erfc_::call(const_cast(*this)); +} + +// aten::exp(Tensor self) -> Tensor +inline at::Tensor Tensor::exp() const { + return at::_ops::exp::call(const_cast(*this)); +} + +// aten::exp_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::exp_() const { + return at::_ops::exp_::call(const_cast(*this)); +} + +// aten::exp2(Tensor self) -> Tensor +inline at::Tensor Tensor::exp2() const { + return at::_ops::exp2::call(const_cast(*this)); +} + +// aten::exp2_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::exp2_() const { + return at::_ops::exp2_::call(const_cast(*this)); +} + +// aten::expm1(Tensor self) -> Tensor +inline at::Tensor Tensor::expm1() const { + return at::_ops::expm1::call(const_cast(*this)); +} + +// aten::expm1_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::expm1_() const { + return at::_ops::expm1_::call(const_cast(*this)); +} + +// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) +inline at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit) const { + return at::_ops::expand::call(const_cast(*this), c10::fromIntArrayRefSlow(size), implicit); +} + +// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) +inline at::Tensor Tensor::expand_symint(c10::SymIntArrayRef size, bool implicit) const { + return at::_ops::expand::call(const_cast(*this), size, implicit); +} + +// aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a) +inline at::Tensor Tensor::expand_as(const at::Tensor & other) const { + return at::_ops::expand_as::call(const_cast(*this), other); +} + +// aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a) +inline at::Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const { + return at::_ops::flatten_using_ints::call(const_cast(*this), start_dim, end_dim); +} + +// aten::flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, 
Dimname out_dim) -> Tensor(a) +inline at::Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim, at::Dimname out_dim) const { + return at::_ops::flatten_named_out_dim::call(const_cast(*this), start_dim, end_dim, out_dim); +} + +// aten::flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a) +inline at::Tensor Tensor::flatten(at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) const { + return at::_ops::flatten_using_names::call(const_cast(*this), start_dim, end_dim, out_dim); +} + +// aten::flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a) +inline at::Tensor Tensor::flatten(at::DimnameList dims, at::Dimname out_dim) const { + return at::_ops::flatten_DimnameList::call(const_cast(*this), dims, out_dim); +} + +// aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) +inline at::Tensor Tensor::unflatten(int64_t dim, at::IntArrayRef sizes) const { + return at::_ops::unflatten_int::call(const_cast(*this), dim, c10::fromIntArrayRefSlow(sizes)); +} + +// aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) +inline at::Tensor Tensor::unflatten_symint(int64_t dim, c10::SymIntArrayRef sizes) const { + return at::_ops::unflatten_int::call(const_cast(*this), dim, sizes); +} + +// aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) +inline at::Tensor Tensor::unflatten(at::Dimname dim, at::IntArrayRef sizes, at::DimnameList names) const { + return at::_ops::unflatten_Dimname::call(const_cast(*this), dim, c10::fromIntArrayRefSlow(sizes), names); +} + +// aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) +inline at::Tensor Tensor::unflatten_symint(at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) const { + return at::_ops::unflatten_Dimname::call(const_cast(*this), dim, sizes, names); +} + +// aten::fill_.Scalar(Tensor(a!) 
self, Scalar value) -> Tensor(a!) +inline at::Tensor & Tensor::fill_(const at::Scalar & value) const { + return at::_ops::fill__Scalar::call(const_cast(*this), value); +} + +// aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) +inline at::Tensor & Tensor::fill_(const at::Tensor & value) const { + return at::_ops::fill__Tensor::call(const_cast(*this), value); +} + +// aten::floor(Tensor self) -> Tensor +inline at::Tensor Tensor::floor() const { + return at::_ops::floor::call(const_cast(*this)); +} + +// aten::floor_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::floor_() const { + return at::_ops::floor_::call(const_cast(*this)); +} + +// aten::floor_divide(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::floor_divide(const at::Tensor & other) const { + return at::_ops::floor_divide::call(const_cast(*this), other); +} + +// aten::floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::floor_divide_(const at::Tensor & other) const { + return at::_ops::floor_divide__Tensor::call(const_cast(*this), other); +} + +// aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::floor_divide(const at::Scalar & other) const { + return at::_ops::floor_divide_Scalar::call(const_cast(*this), other); +} + +// aten::floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::floor_divide_(const at::Scalar & other) const { + return at::_ops::floor_divide__Scalar::call(const_cast(*this), other); +} + +// aten::frac(Tensor self) -> Tensor +inline at::Tensor Tensor::frac() const { + return at::_ops::frac::call(const_cast(*this)); +} + +// aten::frac_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::frac_() const { + return at::_ops::frac_::call(const_cast(*this)); +} + +// aten::gcd(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::gcd(const at::Tensor & other) const { + return at::_ops::gcd::call(const_cast(*this), other); +} + +// aten::gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::gcd_(const at::Tensor & other) const { + return at::_ops::gcd_::call(const_cast(*this), other); +} + +// aten::lcm(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::lcm(const at::Tensor & other) const { + return at::_ops::lcm::call(const_cast(*this), other); +} + +// aten::lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::lcm_(const at::Tensor & other) const { + return at::_ops::lcm_::call(const_cast(*this), other); +} + +// aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor +inline at::Tensor Tensor::index(const c10::List<::std::optional> & indices) const { + return at::_ops::index_Tensor::call(const_cast(*this), indices); +} + +// aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) +inline at::Tensor & Tensor::index_copy_(int64_t dim, const at::Tensor & index, const at::Tensor & source) const { + return at::_ops::index_copy_::call(const_cast(*this), dim, index, source); +} + +// aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor +inline at::Tensor Tensor::index_copy(int64_t dim, const at::Tensor & index, const at::Tensor & source) const { + return at::_ops::index_copy::call(const_cast(*this), dim, index, source); +} + +// aten::index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!) 
+inline at::Tensor & Tensor::index_copy_(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const { + return at::_ops::index_copy__dimname::call(const_cast(*this), dim, index, source); +} + +// aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor +inline at::Tensor Tensor::index_copy(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const { + return at::_ops::index_copy_dimname::call(const_cast(*this), dim, index, source); +} + +// aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) +inline at::Tensor & Tensor::index_put_(const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) const { + return at::_ops::index_put_::call(const_cast(*this), indices, values, accumulate); +} + +// aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor +inline at::Tensor Tensor::index_put(const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) const { + return at::_ops::index_put::call(const_cast(*this), indices, values, accumulate); +} + +// aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor +inline at::Tensor Tensor::isclose(const at::Tensor & other, double rtol, double atol, bool equal_nan) const { + return at::_ops::isclose::call(const_cast(*this), other, rtol, atol, equal_nan); +} + +// aten::isnan(Tensor self) -> Tensor +inline at::Tensor Tensor::isnan() const { + return at::_ops::isnan::call(const_cast(*this)); +} + +// aten::is_distributed(Tensor self) -> bool +inline bool Tensor::is_distributed() const { + return at::_ops::is_distributed::call(const_cast(*this)); +} + +// aten::is_floating_point(Tensor self) -> bool +inline bool Tensor::__dispatch_is_floating_point() const { + return at::_ops::is_floating_point::call(const_cast(*this)); +} + +// aten::is_complex(Tensor self) -> bool +inline 
bool Tensor::__dispatch_is_complex() const { + return at::_ops::is_complex::call(const_cast(*this)); +} + +// aten::is_conj(Tensor self) -> bool +inline bool Tensor::__dispatch_is_conj() const { + return at::_ops::is_conj::call(const_cast(*this)); +} + +// aten::_is_zerotensor(Tensor self) -> bool +inline bool Tensor::__dispatch__is_zerotensor() const { + return at::_ops::_is_zerotensor::call(const_cast(*this)); +} + +// aten::is_neg(Tensor self) -> bool +inline bool Tensor::__dispatch_is_neg() const { + return at::_ops::is_neg::call(const_cast(*this)); +} + +// aten::isreal(Tensor self) -> Tensor +inline at::Tensor Tensor::isreal() const { + return at::_ops::isreal::call(const_cast(*this)); +} + +// aten::is_nonzero(Tensor self) -> bool +inline bool Tensor::is_nonzero() const { + return at::_ops::is_nonzero::call(const_cast(*this)); +} + +// aten::is_same_size(Tensor self, Tensor other) -> bool +inline bool Tensor::is_same_size(const at::Tensor & other) const { + return at::_ops::is_same_size::call(const_cast(*this), other); +} + +// aten::is_signed(Tensor self) -> bool +inline bool Tensor::__dispatch_is_signed() const { + return at::_ops::is_signed::call(const_cast(*this)); +} + +// aten::is_inference(Tensor self) -> bool +inline bool Tensor::__dispatch_is_inference() const { + return at::_ops::is_inference::call(const_cast(*this)); +} + +// aten::kron(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::kron(const at::Tensor & other) const { + return at::_ops::kron::call(const_cast(*this), other); +} + +// aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::kthvalue(int64_t k, int64_t dim, bool keepdim) const { + return at::_ops::kthvalue::call(const_cast(*this), k, dim, keepdim); +} + +// aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::kthvalue_symint(c10::SymInt k, int64_t dim, bool 
keepdim) const { + return at::_ops::kthvalue::call(const_cast(*this), k, dim, keepdim); +} + +// aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::kthvalue(int64_t k, at::Dimname dim, bool keepdim) const { + return at::_ops::kthvalue_dimname::call(const_cast(*this), k, dim, keepdim); +} + +// aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::kthvalue_symint(c10::SymInt k, at::Dimname dim, bool keepdim) const { + return at::_ops::kthvalue_dimname::call(const_cast(*this), k, dim, keepdim); +} + +// aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor +inline at::Tensor Tensor::nan_to_num(::std::optional nan, ::std::optional posinf, ::std::optional neginf) const { + return at::_ops::nan_to_num::call(const_cast(*this), nan, posinf, neginf); +} + +// aten::nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!) +inline at::Tensor & Tensor::nan_to_num_(::std::optional nan, ::std::optional posinf, ::std::optional neginf) const { + return at::_ops::nan_to_num_::call(const_cast(*this), nan, posinf, neginf); +} + +// aten::ldexp.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::ldexp(const at::Tensor & other) const { + return at::_ops::ldexp_Tensor::call(const_cast(*this), other); +} + +// aten::ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::ldexp_(const at::Tensor & other) const { + return at::_ops::ldexp_::call(const_cast(*this), other); +} + +// aten::log(Tensor self) -> Tensor +inline at::Tensor Tensor::log() const { + return at::_ops::log::call(const_cast(*this)); +} + +// aten::log_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::log_() const { + return at::_ops::log_::call(const_cast(*this)); +} + +// aten::log10(Tensor self) -> Tensor +inline at::Tensor Tensor::log10() const { + return at::_ops::log10::call(const_cast(*this)); +} + +// aten::log10_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::log10_() const { + return at::_ops::log10_::call(const_cast(*this)); +} + +// aten::log1p(Tensor self) -> Tensor +inline at::Tensor Tensor::log1p() const { + return at::_ops::log1p::call(const_cast(*this)); +} + +// aten::log1p_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::log1p_() const { + return at::_ops::log1p_::call(const_cast(*this)); +} + +// aten::log2(Tensor self) -> Tensor +inline at::Tensor Tensor::log2() const { + return at::_ops::log2::call(const_cast(*this)); +} + +// aten::log2_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::log2_() const { + return at::_ops::log2_::call(const_cast(*this)); +} + +// aten::logaddexp(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logaddexp(const at::Tensor & other) const { + return at::_ops::logaddexp::call(const_cast(*this), other); +} + +// aten::logaddexp2(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::logaddexp2(const at::Tensor & other) const { + return at::_ops::logaddexp2::call(const_cast(*this), other); +} + +// aten::xlogy.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::xlogy(const at::Tensor & other) const { + return at::_ops::xlogy_Tensor::call(const_cast(*this), other); +} + +// aten::xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::xlogy(const at::Scalar & other) const { + return at::_ops::xlogy_Scalar_Other::call(const_cast(*this), other); +} + +// aten::xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::xlogy_(const at::Tensor & other) const { + return at::_ops::xlogy__Tensor::call(const_cast(*this), other); +} + +// aten::xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::xlogy_(const at::Scalar & other) const { + return at::_ops::xlogy__Scalar_Other::call(const_cast(*this), other); +} + +// aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::log_softmax(int64_t dim, ::std::optional dtype) const { + return at::_ops::log_softmax_int::call(const_cast(*this), dim, dtype); +} + +// aten::log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::log_softmax(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::log_softmax_Dimname::call(const_cast(*this), dim, dtype); +} + +// aten::logcumsumexp(Tensor self, int dim) -> Tensor +inline at::Tensor Tensor::logcumsumexp(int64_t dim) const { + return at::_ops::logcumsumexp::call(const_cast(*this), dim); +} + +// aten::logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor +inline at::Tensor Tensor::logcumsumexp(at::Dimname dim) const { + return at::_ops::logcumsumexp_dimname::call(const_cast(*this), dim); +} + +// aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::logsumexp(at::IntArrayRef dim, bool keepdim) const { + return at::_ops::logsumexp::call(const_cast(*this), dim, keepdim); +} + +// aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::logsumexp(at::DimnameList dim, bool keepdim) const { + return at::_ops::logsumexp_names::call(const_cast(*this), dim, keepdim); +} + +// aten::matmul(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::matmul(const at::Tensor & other) const { + return at::_ops::matmul::call(const_cast(*this), other); +} + +// aten::matrix_power(Tensor self, int n) -> Tensor +inline at::Tensor 
Tensor::matrix_power(int64_t n) const { + return at::_ops::matrix_power::call(const_cast(*this), n); +} + +// aten::matrix_exp(Tensor self) -> Tensor +inline at::Tensor Tensor::matrix_exp() const { + return at::_ops::matrix_exp::call(const_cast(*this)); +} + +// aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max) +inline ::std::tuple Tensor::aminmax(::std::optional dim, bool keepdim) const { + return at::_ops::aminmax::call(const_cast(*this), dim, keepdim); +} + +// aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::max(int64_t dim, bool keepdim) const { + return at::_ops::max_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::max(at::Dimname dim, bool keepdim) const { + return at::_ops::max_names_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor +inline at::Tensor Tensor::amax(at::IntArrayRef dim, bool keepdim) const { + return at::_ops::amax::call(const_cast(*this), dim, keepdim); +} + +// aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::mean(::std::optional dtype) const { + return at::_ops::mean::call(const_cast(*this), dtype); +} + +// aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::mean(at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::mean_dim::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor +inline at::Tensor Tensor::mean(at::DimnameList dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::mean_names_dim::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::nanmean(at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::nanmean::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::median(Tensor self) -> Tensor +inline at::Tensor Tensor::median() const { + return at::_ops::median::call(const_cast(*this)); +} + +// aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::median(int64_t dim, bool keepdim) const { + return at::_ops::median_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::median(at::Dimname dim, bool keepdim) const { + return at::_ops::median_names_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::nanmedian(Tensor self) -> Tensor +inline at::Tensor Tensor::nanmedian() const { + return at::_ops::nanmedian::call(const_cast(*this)); +} + +// aten::nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::nanmedian(int64_t dim, bool keepdim) const { + return at::_ops::nanmedian_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::nanmedian(at::Dimname dim, bool keepdim) const { + return at::_ops::nanmedian_names_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::min(int64_t dim, bool keepdim) const { + return 
at::_ops::min_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::min(at::Dimname dim, bool keepdim) const { + return at::_ops::min_names_dim::call(const_cast(*this), dim, keepdim); +} + +// aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor +inline at::Tensor Tensor::amin(at::IntArrayRef dim, bool keepdim) const { + return at::_ops::amin::call(const_cast(*this), dim, keepdim); +} + +// aten::mm(Tensor self, Tensor mat2) -> Tensor +inline at::Tensor Tensor::mm(const at::Tensor & mat2) const { + return at::_ops::mm::call(const_cast(*this), mat2); +} + +// aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::mode(int64_t dim, bool keepdim) const { + return at::_ops::mode::call(const_cast(*this), dim, keepdim); +} + +// aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::mode(at::Dimname dim, bool keepdim) const { + return at::_ops::mode_dimname::call(const_cast(*this), dim, keepdim); +} + +// aten::mul.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::mul(const at::Tensor & other) const { + return at::_ops::mul_Tensor::call(const_cast(*this), other); +} + +// aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::mul_(const at::Tensor & other) const { + return at::_ops::mul__Tensor::call(const_cast(*this), other); +} + +// aten::mul.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::mul(const at::Scalar & other) const { + return at::_ops::mul_Scalar::call(const_cast(*this), other); +} + +// aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+inline at::Tensor & Tensor::mul_(const at::Scalar & other) const { + return at::_ops::mul__Scalar::call(const_cast(*this), other); +} + +// aten::multiply.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::multiply(const at::Tensor & other) const { + return at::_ops::multiply_Tensor::call(const_cast(*this), other); +} + +// aten::multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::multiply_(const at::Tensor & other) const { + return at::_ops::multiply__Tensor::call(const_cast(*this), other); +} + +// aten::multiply.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::multiply(const at::Scalar & other) const { + return at::_ops::multiply_Scalar::call(const_cast(*this), other); +} + +// aten::multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::multiply_(const at::Scalar & other) const { + return at::_ops::multiply__Scalar::call(const_cast(*this), other); +} + +// aten::mv(Tensor self, Tensor vec) -> Tensor +inline at::Tensor Tensor::mv(const at::Tensor & vec) const { + return at::_ops::mv::call(const_cast(*this), vec); +} + +// aten::mvlgamma(Tensor self, int p) -> Tensor +inline at::Tensor Tensor::mvlgamma(int64_t p) const { + return at::_ops::mvlgamma::call(const_cast(*this), p); +} + +// aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!) 
+inline at::Tensor & Tensor::mvlgamma_(int64_t p) const { + return at::_ops::mvlgamma_::call(const_cast(*this), p); +} + +// aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor +inline at::Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) const { + return at::_ops::narrow_copy::call(const_cast(*this), dim, start, length); +} + +// aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor +inline at::Tensor Tensor::narrow_copy_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const { + return at::_ops::narrow_copy::call(const_cast(*this), dim, start, length); +} + +// aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) +inline at::Tensor Tensor::narrow(int64_t dim, int64_t start, int64_t length) const { + return at::_ops::narrow::call(const_cast(*this), dim, start, length); +} + +// aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) +inline at::Tensor Tensor::narrow_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const { + return at::_ops::narrow::call(const_cast(*this), dim, start, length); +} + +// aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a) +inline at::Tensor Tensor::narrow(int64_t dim, const at::Tensor & start, int64_t length) const { + return at::_ops::narrow_Tensor::call(const_cast(*this), dim, start, length); +} + +// aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a) +inline at::Tensor Tensor::narrow_symint(int64_t dim, const at::Tensor & start, c10::SymInt length) const { + return at::_ops::narrow_Tensor::call(const_cast(*this), dim, start, length); +} + +// aten::permute(Tensor(a) self, int[] dims) -> Tensor(a) +inline at::Tensor Tensor::permute(at::IntArrayRef dims) const { + return at::_ops::permute::call(const_cast(*this), dims); +} + +// aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) +inline 
at::Tensor Tensor::movedim(at::IntArrayRef source, at::IntArrayRef destination) const { + return at::_ops::movedim_intlist::call(const_cast(*this), source, destination); +} + +// aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a) +inline at::Tensor Tensor::movedim(int64_t source, int64_t destination) const { + return at::_ops::movedim_int::call(const_cast(*this), source, destination); +} + +// aten::moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) +inline at::Tensor Tensor::moveaxis(at::IntArrayRef source, at::IntArrayRef destination) const { + return at::_ops::moveaxis_intlist::call(const_cast(*this), source, destination); +} + +// aten::moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a) +inline at::Tensor Tensor::moveaxis(int64_t source, int64_t destination) const { + return at::_ops::moveaxis_int::call(const_cast(*this), source, destination); +} + +// aten::numpy_T(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::numpy_T() const { + return at::_ops::numpy_T::call(const_cast(*this)); +} + +// aten::matrix_H(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::matrix_H() const { + return at::_ops::matrix_H::call(const_cast(*this)); +} + +// aten::mT(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::mT() const { + return at::_ops::mT::call(const_cast(*this)); +} + +// aten::mH(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::mH() const { + return at::_ops::mH::call(const_cast(*this)); +} + +// aten::adjoint(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::adjoint() const { + return at::_ops::adjoint::call(const_cast(*this)); +} + +// aten::is_pinned(Tensor self, Device? device=None) -> bool +inline bool Tensor::is_pinned(::std::optional device) const { + return at::_ops::is_pinned::call(const_cast(*this), device); +} + +// aten::pin_memory(Tensor(a) self, Device? 
device=None) -> Tensor(a) +inline at::Tensor Tensor::pin_memory(::std::optional device) const { + return at::_ops::pin_memory::call(const_cast(*this), device); +} + +// aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor +inline at::Tensor Tensor::pinverse(double rcond) const { + return at::_ops::pinverse::call(const_cast(*this), rcond); +} + +// aten::rad2deg(Tensor self) -> Tensor +inline at::Tensor Tensor::rad2deg() const { + return at::_ops::rad2deg::call(const_cast(*this)); +} + +// aten::rad2deg_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::rad2deg_() const { + return at::_ops::rad2deg_::call(const_cast(*this)); +} + +// aten::deg2rad(Tensor self) -> Tensor +inline at::Tensor Tensor::deg2rad() const { + return at::_ops::deg2rad::call(const_cast(*this)); +} + +// aten::deg2rad_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::deg2rad_() const { + return at::_ops::deg2rad_::call(const_cast(*this)); +} + +// aten::ravel(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::ravel() const { + return at::_ops::ravel::call(const_cast(*this)); +} + +// aten::reciprocal(Tensor self) -> Tensor +inline at::Tensor Tensor::reciprocal() const { + return at::_ops::reciprocal::call(const_cast(*this)); +} + +// aten::reciprocal_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::reciprocal_() const { + return at::_ops::reciprocal_::call(const_cast(*this)); +} + +// aten::neg(Tensor self) -> Tensor +inline at::Tensor Tensor::neg() const { + return at::_ops::neg::call(const_cast(*this)); +} + +// aten::neg_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::neg_() const { + return at::_ops::neg_::call(const_cast(*this)); +} + +// aten::negative(Tensor self) -> Tensor +inline at::Tensor Tensor::negative() const { + return at::_ops::negative::call(const_cast(*this)); +} + +// aten::negative_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::negative_() const { + return at::_ops::negative_::call(const_cast(*this)); +} + +// aten::repeat(Tensor self, SymInt[] repeats) -> Tensor +inline at::Tensor Tensor::repeat(at::IntArrayRef repeats) const { + return at::_ops::repeat::call(const_cast(*this), c10::fromIntArrayRefSlow(repeats)); +} + +// aten::repeat(Tensor self, SymInt[] repeats) -> Tensor +inline at::Tensor Tensor::repeat_symint(c10::SymIntArrayRef repeats) const { + return at::_ops::repeat::call(const_cast(*this), repeats); +} + +// aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor +inline at::Tensor Tensor::repeat_interleave(const at::Tensor & repeats, ::std::optional dim, ::std::optional output_size) const { + return at::_ops::repeat_interleave_self_Tensor::call(const_cast(*this), repeats, dim, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); +} + +// aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor +inline at::Tensor Tensor::repeat_interleave_symint(const at::Tensor & repeats, ::std::optional dim, ::std::optional output_size) const { + return at::_ops::repeat_interleave_self_Tensor::call(const_cast(*this), repeats, dim, output_size); +} + +// aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor +inline at::Tensor Tensor::repeat_interleave(int64_t repeats, ::std::optional dim, ::std::optional output_size) const { + return at::_ops::repeat_interleave_self_int::call(const_cast(*this), repeats, dim, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); +} + +// aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? 
output_size=None) -> Tensor +inline at::Tensor Tensor::repeat_interleave_symint(c10::SymInt repeats, ::std::optional dim, ::std::optional output_size) const { + return at::_ops::repeat_interleave_self_int::call(const_cast(*this), repeats, dim, output_size); +} + +// aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) +inline at::Tensor Tensor::reshape(at::IntArrayRef shape) const { + return at::_ops::reshape::call(const_cast(*this), c10::fromIntArrayRefSlow(shape)); +} + +// aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) +inline at::Tensor Tensor::reshape_symint(c10::SymIntArrayRef shape) const { + return at::_ops::reshape::call(const_cast(*this), shape); +} + +// aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) +inline at::Tensor Tensor::_reshape_alias(at::IntArrayRef size, at::IntArrayRef stride) const { + return at::_ops::_reshape_alias::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); +} + +// aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) +inline at::Tensor Tensor::_reshape_alias_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const { + return at::_ops::_reshape_alias::call(const_cast(*this), size, stride); +} + +// aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a) +inline at::Tensor Tensor::reshape_as(const at::Tensor & other) const { + return at::_ops::reshape_as::call(const_cast(*this), other); +} + +// aten::round(Tensor self) -> Tensor +inline at::Tensor Tensor::round() const { + return at::_ops::round::call(const_cast(*this)); +} + +// aten::round_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::round_() const { + return at::_ops::round_::call(const_cast(*this)); +} + +// aten::round.decimals(Tensor self, *, int decimals) -> Tensor +inline at::Tensor Tensor::round(int64_t decimals) const { + return at::_ops::round_decimals::call(const_cast(*this), decimals); +} + +// aten::round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!) +inline at::Tensor & Tensor::round_(int64_t decimals) const { + return at::_ops::round__decimals::call(const_cast(*this), decimals); +} + +// aten::relu(Tensor self) -> Tensor +inline at::Tensor Tensor::relu() const { + return at::_ops::relu::call(const_cast(*this)); +} + +// aten::relu_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::relu_() const { + return at::_ops::relu_::call(const_cast(*this)); +} + +// aten::prelu(Tensor self, Tensor weight) -> Tensor +inline at::Tensor Tensor::prelu(const at::Tensor & weight) const { + return at::_ops::prelu::call(const_cast(*this), weight); +} + +// aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor +inline at::Tensor Tensor::hardshrink(const at::Scalar & lambd) const { + return at::_ops::hardshrink::call(const_cast(*this), lambd); +} + +// aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor +inline at::Tensor Tensor::hardshrink_backward(const at::Tensor & grad_out, const at::Scalar & lambd) const { + return at::_ops::hardshrink_backward::call(grad_out, const_cast(*this), lambd); +} + +// aten::rsqrt(Tensor self) -> Tensor +inline at::Tensor Tensor::rsqrt() const { + return at::_ops::rsqrt::call(const_cast(*this)); +} + +// aten::rsqrt_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::rsqrt_() const { + return at::_ops::rsqrt_::call(const_cast(*this)); +} + +// aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a) +inline at::Tensor Tensor::select(at::Dimname dim, int64_t index) const { + return at::_ops::select_Dimname::call(const_cast(*this), dim, index); +} + +// aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) +inline at::Tensor Tensor::select(int64_t dim, int64_t index) const { + return at::_ops::select_int::call(const_cast(*this), dim, index); +} + +// aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) +inline at::Tensor Tensor::select_symint(int64_t dim, c10::SymInt index) const { + return at::_ops::select_int::call(const_cast(*this), dim, index); +} + +// aten::sigmoid(Tensor self) -> Tensor +inline at::Tensor Tensor::sigmoid() const { + return at::_ops::sigmoid::call(const_cast(*this)); +} + +// aten::sigmoid_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sigmoid_() const { + return at::_ops::sigmoid_::call(const_cast(*this)); +} + +// aten::logit(Tensor self, float? eps=None) -> Tensor +inline at::Tensor Tensor::logit(::std::optional eps) const { + return at::_ops::logit::call(const_cast(*this), eps); +} + +// aten::logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) +inline at::Tensor & Tensor::logit_(::std::optional eps) const { + return at::_ops::logit_::call(const_cast(*this), eps); +} + +// aten::sin(Tensor self) -> Tensor +inline at::Tensor Tensor::sin() const { + return at::_ops::sin::call(const_cast(*this)); +} + +// aten::sin_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sin_() const { + return at::_ops::sin_::call(const_cast(*this)); +} + +// aten::sinc(Tensor self) -> Tensor +inline at::Tensor Tensor::sinc() const { + return at::_ops::sinc::call(const_cast(*this)); +} + +// aten::sinc_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::sinc_() const { + return at::_ops::sinc_::call(const_cast(*this)); +} + +// aten::sinh(Tensor self) -> Tensor +inline at::Tensor Tensor::sinh() const { + return at::_ops::sinh::call(const_cast(*this)); +} + +// aten::sinh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sinh_() const { + return at::_ops::sinh_::call(const_cast(*this)); +} + +// aten::detach(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::detach() const { + return at::_ops::detach::call(const_cast(*this)); +} + +// aten::detach_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::detach_() const { + return at::_ops::detach_::call(const_cast(*this)); +} + +// aten::size.Dimname(Tensor self, Dimname dim) -> int +inline int64_t Tensor::size(at::Dimname dim) const { + return at::_ops::size_Dimname::call(const_cast(*this), dim); +} + +// aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) +inline at::Tensor Tensor::slice(int64_t dim, ::std::optional start, ::std::optional end, int64_t step) const { + return at::_ops::slice_Tensor::call(const_cast(*this), dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); +} + +// aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) +inline at::Tensor Tensor::slice_symint(int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) const { + return at::_ops::slice_Tensor::call(const_cast(*this), dim, start, end, step); +} + +// aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor(a) +inline at::Tensor Tensor::slice_inverse(const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, int64_t step) const { + return at::_ops::slice_inverse::call(const_cast(*this), src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); +} + +// aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) +inline at::Tensor Tensor::slice_inverse_symint(const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) const { + return at::_ops::slice_inverse::call(const_cast(*this), src, dim, start, end, step); +} + +// aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor +inline at::Tensor Tensor::slice_scatter(const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, int64_t step) const { + return at::_ops::slice_scatter::call(const_cast(*this), src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); +} + +// aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor +inline at::Tensor Tensor::slice_scatter_symint(const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) const { + return at::_ops::slice_scatter::call(const_cast(*this), src, dim, start, end, step); +} + +// aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor +inline at::Tensor Tensor::select_scatter(const at::Tensor & src, int64_t dim, int64_t index) const { + return at::_ops::select_scatter::call(const_cast(*this), src, dim, index); +} + +// aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor +inline at::Tensor Tensor::select_scatter_symint(const at::Tensor & src, int64_t dim, c10::SymInt index) const { + return at::_ops::select_scatter::call(const_cast(*this), src, dim, index); +} + +// aten::diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor +inline at::Tensor Tensor::diagonal_scatter(const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2) const { + return at::_ops::diagonal_scatter::call(const_cast(*this), src, offset, dim1, dim2); +} + +// aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor +inline at::Tensor Tensor::as_strided_scatter(const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided_scatter::call(const_cast(*this), src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); +} + +// aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None) -> Tensor +inline at::Tensor Tensor::as_strided_scatter_symint(const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) const { + return at::_ops::as_strided_scatter::call(const_cast(*this), src, size, stride, storage_offset); +} + +// aten::smm(Tensor self, Tensor mat2) -> Tensor +inline at::Tensor Tensor::smm(const at::Tensor & mat2) const { + return at::_ops::smm::call(const_cast(*this), mat2); +} + +// aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::softmax(int64_t dim, ::std::optional dtype) const { + return at::_ops::softmax_int::call(const_cast(*this), dim, dtype); +} + +// aten::softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::softmax(at::Dimname dim, ::std::optional dtype) const { + return at::_ops::softmax_Dimname::call(const_cast(*this), dim, dtype); +} + +// aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_split(int64_t split_size, int64_t dim) const { + return at::_ops::unsafe_split_Tensor::call(const_cast(*this), split_size, dim); +} + +// aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_split_symint(c10::SymInt split_size, int64_t dim) const { + return at::_ops::unsafe_split_Tensor::call(const_cast(*this), split_size, dim); +} + +// aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split(int64_t split_size, int64_t dim) const { + return at::_ops::split_Tensor::call(const_cast(*this), split_size, dim); +} + +// aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split_symint(c10::SymInt split_size, int64_t dim) const { + return at::_ops::split_Tensor::call(const_cast(*this), split_size, dim); +} + +// 
aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split(at::IntArrayRef split_size, int64_t dim) const { + return at::_ops::split_sizes::call(const_cast(*this), c10::fromIntArrayRefSlow(split_size), dim); +} + +// aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split_symint(c10::SymIntArrayRef split_size, int64_t dim) const { + return at::_ops::split_sizes::call(const_cast(*this), split_size, dim); +} + +// aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_split_with_sizes(at::IntArrayRef split_sizes, int64_t dim) const { + return at::_ops::unsafe_split_with_sizes::call(const_cast(*this), c10::fromIntArrayRefSlow(split_sizes), dim); +} + +// aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] +inline ::std::vector Tensor::unsafe_split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim) const { + return at::_ops::unsafe_split_with_sizes::call(const_cast(*this), split_sizes, dim); +} + +// aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split_with_sizes(at::IntArrayRef split_sizes, int64_t dim) const { + return at::_ops::split_with_sizes::call(const_cast(*this), c10::fromIntArrayRefSlow(split_sizes), dim); +} + +// aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim) const { + return at::_ops::split_with_sizes::call(const_cast(*this), split_sizes, dim); +} + +// aten::hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] +inline ::std::vector Tensor::hsplit(int64_t sections) const { + return at::_ops::hsplit_int::call(const_cast(*this), sections); +} + +// aten::hsplit.array(Tensor(a -> *) self, 
int[] indices) -> Tensor(a)[] +inline ::std::vector Tensor::hsplit(at::IntArrayRef indices) const { + return at::_ops::hsplit_array::call(const_cast(*this), indices); +} + +// aten::vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] +inline ::std::vector Tensor::vsplit(int64_t sections) const { + return at::_ops::vsplit_int::call(const_cast(*this), sections); +} + +// aten::vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] +inline ::std::vector Tensor::vsplit(at::IntArrayRef indices) const { + return at::_ops::vsplit_array::call(const_cast(*this), indices); +} + +// aten::dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] +inline ::std::vector Tensor::dsplit(int64_t sections) const { + return at::_ops::dsplit_int::call(const_cast(*this), sections); +} + +// aten::dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] +inline ::std::vector Tensor::dsplit(at::IntArrayRef indices) const { + return at::_ops::dsplit_array::call(const_cast(*this), indices); +} + +// aten::squeeze(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::squeeze() const { + return at::_ops::squeeze::call(const_cast(*this)); +} + +// aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) +inline at::Tensor Tensor::squeeze(int64_t dim) const { + return at::_ops::squeeze_dim::call(const_cast(*this), dim); +} + +// aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a) +inline at::Tensor Tensor::squeeze(at::Dimname dim) const { + return at::_ops::squeeze_dimname::call(const_cast(*this), dim); +} + +// aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) +inline at::Tensor Tensor::squeeze(at::IntArrayRef dim) const { + return at::_ops::squeeze_dims::call(const_cast(*this), dim); +} + +// aten::squeeze_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::squeeze_() const { + return at::_ops::squeeze_::call(const_cast(*this)); +} + +// aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) 
+inline at::Tensor & Tensor::squeeze_(int64_t dim) const { + return at::_ops::squeeze__dim::call(const_cast(*this), dim); +} + +// aten::squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) +inline at::Tensor & Tensor::squeeze_(at::IntArrayRef dim) const { + return at::_ops::squeeze__dims::call(const_cast(*this), dim); +} + +// aten::squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) +inline at::Tensor & Tensor::squeeze_(at::Dimname dim) const { + return at::_ops::squeeze__dimname::call(const_cast(*this), dim); +} + +// aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::sspaddmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::sspaddmm::call(const_cast(*this), mat1, mat2, beta, alpha); +} + +// aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor +inline at::Tensor Tensor::stft(int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window) const { + return at::_ops::stft::call(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided, return_complex, align_to_window); +} + +// aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? 
align_to_window=None) -> Tensor +inline at::Tensor Tensor::stft(int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, c10::string_view pad_mode, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window) const { + return at::_ops::stft_center::call(const_cast(*this), n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex, align_to_window); +} + +// aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor +inline at::Tensor Tensor::istft(int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, bool normalized, ::std::optional onesided, ::std::optional length, bool return_complex) const { + return at::_ops::istft::call(const_cast(*this), n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex); +} + +// aten::stride.Dimname(Tensor self, Dimname dim) -> int +inline int64_t Tensor::stride(at::Dimname dim) const { + return at::_ops::stride_Dimname::call(const_cast(*this), dim); +} + +// aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::sum(::std::optional dtype) const { + return at::_ops::sum::call(const_cast(*this), dtype); +} + +// aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::sum(at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::sum_dim_IntList::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor +inline at::Tensor Tensor::sum(at::DimnameList dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::sum_dim_DimnameList::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::nansum(at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::nansum::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor +inline at::Tensor Tensor::sum_to_size(at::IntArrayRef size) const { + return at::_ops::sum_to_size::call(const_cast(*this), c10::fromIntArrayRefSlow(size)); +} + +// aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor +inline at::Tensor Tensor::sum_to_size_symint(c10::SymIntArrayRef size) const { + return at::_ops::sum_to_size::call(const_cast(*this), size); +} + +// aten::sqrt(Tensor self) -> Tensor +inline at::Tensor Tensor::sqrt() const { + return at::_ops::sqrt::call(const_cast(*this)); +} + +// aten::sqrt_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sqrt_() const { + return at::_ops::sqrt_::call(const_cast(*this)); +} + +// aten::square(Tensor self) -> Tensor +inline at::Tensor Tensor::square() const { + return at::_ops::square::call(const_cast(*this)); +} + +// aten::square_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::square_() const { + return at::_ops::square_::call(const_cast(*this)); +} + +// aten::std(Tensor self, bool unbiased=True) -> Tensor +inline at::Tensor Tensor::std(bool unbiased) const { + return at::_ops::std::call(const_cast(*this), unbiased); +} + +// aten::std.dim(Tensor self, int[1]? 
dim, bool unbiased=True, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::std(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) const { + return at::_ops::std_dim::call(const_cast(*this), dim, unbiased, keepdim); +} + +// aten::std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::std(at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) const { + return at::_ops::std_correction::call(const_cast(*this), dim, correction, keepdim); +} + +// aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::std(at::DimnameList dim, bool unbiased, bool keepdim) const { + return at::_ops::std_names_dim::call(const_cast(*this), dim, unbiased, keepdim); +} + +// aten::std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::std(at::DimnameList dim, const ::std::optional & correction, bool keepdim) const { + return at::_ops::std_correction_names::call(const_cast(*this), dim, correction, keepdim); +} + +// aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::prod(::std::optional dtype) const { + return at::_ops::prod::call(const_cast(*this), dtype); +} + +// aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::prod(int64_t dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::prod_dim_int::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor +inline at::Tensor Tensor::prod(at::Dimname dim, bool keepdim, ::std::optional dtype) const { + return at::_ops::prod_dim_Dimname::call(const_cast(*this), dim, keepdim, dtype); +} + +// aten::t(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::t() const { + return at::_ops::t::call(const_cast(*this)); +} + +// aten::t_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::t_() const { + return at::_ops::t_::call(const_cast(*this)); +} + +// aten::tan(Tensor self) -> Tensor +inline at::Tensor Tensor::tan() const { + return at::_ops::tan::call(const_cast(*this)); +} + +// aten::tan_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::tan_() const { + return at::_ops::tan_::call(const_cast(*this)); +} + +// aten::tanh(Tensor self) -> Tensor +inline at::Tensor Tensor::tanh() const { + return at::_ops::tanh::call(const_cast(*this)); +} + +// aten::tanh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::tanh_() const { + return at::_ops::tanh_::call(const_cast(*this)); +} + +// aten::tile(Tensor self, SymInt[] dims) -> Tensor +inline at::Tensor Tensor::tile(at::IntArrayRef dims) const { + return at::_ops::tile::call(const_cast(*this), c10::fromIntArrayRefSlow(dims)); +} + +// aten::tile(Tensor self, SymInt[] dims) -> Tensor +inline at::Tensor Tensor::tile_symint(c10::SymIntArrayRef dims) const { + return at::_ops::tile::call(const_cast(*this), dims); +} + +// aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) +inline at::Tensor Tensor::transpose(int64_t dim0, int64_t dim1) const { + return at::_ops::transpose_int::call(const_cast(*this), dim0, dim1); +} + +// aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) +inline at::Tensor Tensor::transpose(at::Dimname dim0, at::Dimname dim1) const { + return at::_ops::transpose_Dimname::call(const_cast(*this), dim0, dim1); +} + +// aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) 
+inline at::Tensor & Tensor::transpose_(int64_t dim0, int64_t dim1) const { + return at::_ops::transpose_::call(const_cast(*this), dim0, dim1); +} + +// aten::flip(Tensor self, int[] dims) -> Tensor +inline at::Tensor Tensor::flip(at::IntArrayRef dims) const { + return at::_ops::flip::call(const_cast(*this), dims); +} + +// aten::fliplr(Tensor self) -> Tensor +inline at::Tensor Tensor::fliplr() const { + return at::_ops::fliplr::call(const_cast(*this)); +} + +// aten::flipud(Tensor self) -> Tensor +inline at::Tensor Tensor::flipud() const { + return at::_ops::flipud::call(const_cast(*this)); +} + +// aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor +inline at::Tensor Tensor::roll(at::IntArrayRef shifts, at::IntArrayRef dims) const { + return at::_ops::roll::call(const_cast(*this), c10::fromIntArrayRefSlow(shifts), dims); +} + +// aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor +inline at::Tensor Tensor::roll_symint(c10::SymIntArrayRef shifts, at::IntArrayRef dims) const { + return at::_ops::roll::call(const_cast(*this), shifts, dims); +} + +// aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor +inline at::Tensor Tensor::rot90(int64_t k, at::IntArrayRef dims) const { + return at::_ops::rot90::call(const_cast(*this), k, dims); +} + +// aten::_nested_tensor_size(Tensor self) -> Tensor +inline at::Tensor Tensor::_nested_tensor_size() const { + return at::_ops::_nested_tensor_size::call(const_cast(*this)); +} + +// aten::_nested_tensor_strides(Tensor self) -> Tensor +inline at::Tensor Tensor::_nested_tensor_strides() const { + return at::_ops::_nested_tensor_strides::call(const_cast(*this)); +} + +// aten::_nested_tensor_storage_offsets(Tensor self) -> Tensor +inline at::Tensor Tensor::_nested_tensor_storage_offsets() const { + return at::_ops::_nested_tensor_storage_offsets::call(const_cast(*this)); +} + +// aten::trunc(Tensor self) -> Tensor +inline at::Tensor Tensor::trunc() const { + return 
at::_ops::trunc::call(const_cast(*this)); +} + +// aten::trunc_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::trunc_() const { + return at::_ops::trunc_::call(const_cast(*this)); +} + +// aten::fix(Tensor self) -> Tensor +inline at::Tensor Tensor::fix() const { + return at::_ops::fix::call(const_cast(*this)); +} + +// aten::fix_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::fix_() const { + return at::_ops::fix_::call(const_cast(*this)); +} + +// aten::type_as(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::type_as(const at::Tensor & other) const { + return at::_ops::type_as::call(const_cast(*this), other); +} + +// aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a) +inline at::Tensor Tensor::unsqueeze(int64_t dim) const { + return at::_ops::unsqueeze::call(const_cast(*this), dim); +} + +// aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) +inline at::Tensor & Tensor::unsqueeze_(int64_t dim) const { + return at::_ops::unsqueeze_::call(const_cast(*this), dim); +} + +// aten::var(Tensor self, bool unbiased=True) -> Tensor +inline at::Tensor Tensor::var(bool unbiased) const { + return at::_ops::var::call(const_cast(*this), unbiased); +} + +// aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::var(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) const { + return at::_ops::var_dim::call(const_cast(*this), dim, unbiased, keepdim); +} + +// aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? 
correction=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::var(at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) const { + return at::_ops::var_correction::call(const_cast(*this), dim, correction, keepdim); +} + +// aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::var(at::DimnameList dim, bool unbiased, bool keepdim) const { + return at::_ops::var_names_dim::call(const_cast(*this), dim, unbiased, keepdim); +} + +// aten::var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::var(at::DimnameList dim, const ::std::optional & correction, bool keepdim) const { + return at::_ops::var_correction_names::call(const_cast(*this), dim, correction, keepdim); +} + +// aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a) +inline at::Tensor Tensor::view_as(const at::Tensor & other) const { + return at::_ops::view_as::call(const_cast(*this), other); +} + +// aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::where(const at::Tensor & condition, const at::Tensor & other) const { + return at::_ops::where_self::call(condition, const_cast(*this), other); +} + +// aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::where(const at::Tensor & condition, const at::Scalar & other) const { + return at::_ops::where_ScalarOther::call(condition, const_cast(*this), other); +} + +// aten::norm.ScalarOpt_dtype(Tensor self, Scalar? 
p, *, ScalarType dtype) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::ScalarType dtype) const { + return at::_ops::norm_ScalarOpt_dtype::call(const_cast(*this), p, dtype); +} + +// aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor +inline at::Tensor Tensor::norm(const at::Scalar & p) const { + return at::_ops::norm_Scalar::call(const_cast(*this), p); +} + +// aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) const { + return at::_ops::norm_ScalarOpt_dim_dtype::call(const_cast(*this), p, dim, keepdim, dtype); +} + +// aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::IntArrayRef dim, bool keepdim) const { + return at::_ops::norm_ScalarOpt_dim::call(const_cast(*this), p, dim, keepdim); +} + +// aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) const { + return at::_ops::norm_names_ScalarOpt_dim_dtype::call(const_cast(*this), p, dim, keepdim, dtype); +} + +// aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor +inline at::Tensor Tensor::norm(const ::std::optional & p, at::DimnameList dim, bool keepdim) const { + return at::_ops::norm_names_ScalarOpt_dim::call(const_cast(*this), p, dim, keepdim); +} + +// aten::frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) +inline ::std::tuple Tensor::frexp() const { + return at::_ops::frexp_Tensor::call(const_cast(*this)); +} + +// aten::clone(Tensor self, *, MemoryFormat? 
memory_format=None) -> Tensor +inline at::Tensor Tensor::clone(::std::optional memory_format) const { + return at::_ops::clone::call(const_cast(*this), memory_format); +} + +// aten::positive(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::positive() const { + return at::_ops::positive::call(const_cast(*this)); +} + +// aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) +inline const at::Tensor & Tensor::resize_as_(const at::Tensor & the_template, ::std::optional memory_format) const { + return at::_ops::resize_as_::call(const_cast(*this), the_template, memory_format); +} + +// aten::resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) +inline const at::Tensor & Tensor::resize_as_sparse_(const at::Tensor & the_template) const { + return at::_ops::resize_as_sparse_::call(const_cast(*this), the_template); +} + +// aten::zero_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::zero_() const { + return at::_ops::zero_::call(const_cast(*this)); +} + +// aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::sub(const at::Tensor & other, const at::Scalar & alpha) const { + return at::_ops::sub_Tensor::call(const_cast(*this), other, alpha); +} + +// aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::sub_(const at::Tensor & other, const at::Scalar & alpha) const { + return at::_ops::sub__Tensor::call(const_cast(*this), other, alpha); +} + +// aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::sub(const at::Scalar & other, const at::Scalar & alpha) const { + return at::_ops::sub_Scalar::call(const_cast(*this), other, alpha); +} + +// aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) 
+inline at::Tensor & Tensor::sub_(const at::Scalar & other, const at::Scalar & alpha) const { + return at::_ops::sub__Scalar::call(const_cast(*this), other, alpha); +} + +// aten::subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::subtract(const at::Tensor & other, const at::Scalar & alpha) const { + return at::_ops::subtract_Tensor::call(const_cast(*this), other, alpha); +} + +// aten::subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::subtract_(const at::Tensor & other, const at::Scalar & alpha) const { + return at::_ops::subtract__Tensor::call(const_cast(*this), other, alpha); +} + +// aten::subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::subtract(const at::Scalar & other, const at::Scalar & alpha) const { + return at::_ops::subtract_Scalar::call(const_cast(*this), other, alpha); +} + +// aten::subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::subtract_(const at::Scalar & other, const at::Scalar & alpha) const { + return at::_ops::subtract__Scalar::call(const_cast(*this), other, alpha); +} + +// aten::heaviside(Tensor self, Tensor values) -> Tensor +inline at::Tensor Tensor::heaviside(const at::Tensor & values) const { + return at::_ops::heaviside::call(const_cast(*this), values); +} + +// aten::heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!) +inline at::Tensor & Tensor::heaviside_(const at::Tensor & values) const { + return at::_ops::heaviside_::call(const_cast(*this), values); +} + +// aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::addmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addmm::call(const_cast(*this), mat1, mat2, beta, alpha); +} + +// aten::addmm_(Tensor(a!) 
self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::addmm_(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addmm_::call(const_cast(*this), mat1, mat2, beta, alpha); +} + +// aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor +inline at::Tensor Tensor::_addmm_activation(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu) const { + return at::_ops::_addmm_activation::call(const_cast(*this), mat1, mat2, beta, alpha, use_gelu); +} + +// aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) +inline const at::Tensor & Tensor::sparse_resize_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { + return at::_ops::sparse_resize_::call(const_cast(*this), size, sparse_dim, dense_dim); +} + +// aten::sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) +inline const at::Tensor & Tensor::sparse_resize_and_clear_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { + return at::_ops::sparse_resize_and_clear_::call(const_cast(*this), size, sparse_dim, dense_dim); +} + +// aten::sparse_mask(Tensor self, Tensor mask) -> Tensor +inline at::Tensor Tensor::sparse_mask(const at::Tensor & mask) const { + return at::_ops::sparse_mask::call(const_cast(*this), mask); +} + +// aten::_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor +inline at::Tensor Tensor::_sparse_mask_projection(const at::Tensor & mask, bool accumulate_matches) const { + return at::_ops::_sparse_mask_projection::call(const_cast(*this), mask, accumulate_matches); +} + +// aten::to_dense(Tensor self, ScalarType? dtype=None, *, bool? 
masked_grad=None) -> Tensor +inline at::Tensor Tensor::to_dense(::std::optional dtype, ::std::optional masked_grad) const { + return at::_ops::to_dense::call(const_cast(*this), dtype, masked_grad); +} + +// aten::_to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor +inline at::Tensor Tensor::_to_dense(::std::optional dtype, ::std::optional masked_grad) const { + return at::_ops::_to_dense::call(const_cast(*this), dtype, masked_grad); +} + +// aten::sparse_dim(Tensor self) -> int +inline int64_t Tensor::sparse_dim() const { + return at::_ops::sparse_dim::call(const_cast(*this)); +} + +// aten::_dimI(Tensor self) -> int +inline int64_t Tensor::_dimI() const { + return at::_ops::_dimI::call(const_cast(*this)); +} + +// aten::dense_dim(Tensor self) -> int +inline int64_t Tensor::dense_dim() const { + return at::_ops::dense_dim::call(const_cast(*this)); +} + +// aten::_dimV(Tensor self) -> int +inline int64_t Tensor::_dimV() const { + return at::_ops::_dimV::call(const_cast(*this)); +} + +// aten::_nnz(Tensor self) -> int +inline int64_t Tensor::_nnz() const { + return at::_ops::_nnz::call(const_cast(*this)); +} + +// aten::coalesce(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::coalesce() const { + return at::_ops::coalesce::call(const_cast(*this)); +} + +// aten::is_coalesced(Tensor self) -> bool +inline bool Tensor::is_coalesced() const { + return at::_ops::is_coalesced::call(const_cast(*this)); +} + +// aten::_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::_indices() const { + return at::_ops::_indices::call(const_cast(*this)); +} + +// aten::_values(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::_values() const { + return at::_ops::_values::call(const_cast(*this)); +} + +// aten::_coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) 
+inline at::Tensor & Tensor::_coalesced_(bool coalesced) const { + return at::_ops::_coalesced_::call(const_cast(*this), coalesced); +} + +// aten::indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::indices() const { + return at::_ops::indices::call(const_cast(*this)); +} + +// aten::values(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::values() const { + return at::_ops::values::call(const_cast(*this)); +} + +// aten::crow_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::crow_indices() const { + return at::_ops::crow_indices::call(const_cast(*this)); +} + +// aten::col_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::col_indices() const { + return at::_ops::col_indices::call(const_cast(*this)); +} + +// aten::ccol_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::ccol_indices() const { + return at::_ops::ccol_indices::call(const_cast(*this)); +} + +// aten::row_indices(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::row_indices() const { + return at::_ops::row_indices::call(const_cast(*this)); +} + +// aten::unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] +inline ::std::vector Tensor::unbind(int64_t dim) const { + return at::_ops::unbind_int::call(const_cast(*this), dim); +} + +// aten::unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[] +inline ::std::vector Tensor::unbind(at::Dimname dim) const { + return at::_ops::unbind_Dimname::call(const_cast(*this), dim); +} + +// aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor +inline at::Tensor Tensor::to_sparse(int64_t sparse_dim) const { + return at::_ops::to_sparse_sparse_dim::call(const_cast(*this), sparse_dim); +} + +// aten::_to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor +inline at::Tensor Tensor::_to_sparse(int64_t sparse_dim) const { + return at::_ops::_to_sparse_sparse_dim::call(const_cast(*this), sparse_dim); +} + +// aten::to_sparse(Tensor self, *, Layout? layout=None, int[2]? 
blocksize=None, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse(::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::to_sparse::call(const_cast(*this), layout, blocksize, dense_dim); +} + +// aten::_to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse(::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::_to_sparse::call(const_cast(*this), layout, blocksize, dense_dim); +} + +// aten::to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse_csr(::std::optional dense_dim) const { + return at::_ops::to_sparse_csr::call(const_cast(*this), dense_dim); +} + +// aten::_to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse_csr(::std::optional dense_dim) const { + return at::_ops::_to_sparse_csr::call(const_cast(*this), dense_dim); +} + +// aten::to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse_csc(::std::optional dense_dim) const { + return at::_ops::to_sparse_csc::call(const_cast(*this), dense_dim); +} + +// aten::_to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse_csc(::std::optional dense_dim) const { + return at::_ops::_to_sparse_csc::call(const_cast(*this), dense_dim); +} + +// aten::to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::to_sparse_bsr::call(const_cast(*this), blocksize, dense_dim); +} + +// aten::_to_sparse_bsr(Tensor self, int[2] blocksize, int? 
dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse_bsr(at::IntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::_to_sparse_bsr::call(const_cast(*this), blocksize, dense_dim); +} + +// aten::to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::to_sparse_bsc::call(const_cast(*this), blocksize, dense_dim); +} + +// aten::_to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +inline at::Tensor Tensor::_to_sparse_bsc(at::IntArrayRef blocksize, ::std::optional dense_dim) const { + return at::_ops::_to_sparse_bsc::call(const_cast(*this), blocksize, dense_dim); +} + +// aten::to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor +inline at::Tensor Tensor::to_mkldnn(::std::optional dtype) const { + return at::_ops::to_mkldnn::call(const_cast(*this), dtype); +} + +// aten::dequantize.self(Tensor self) -> Tensor +inline at::Tensor Tensor::dequantize() const { + return at::_ops::dequantize_self::call(const_cast(*this)); +} + +// aten::q_scale(Tensor self) -> float +inline double Tensor::q_scale() const { + return at::_ops::q_scale::call(const_cast(*this)); +} + +// aten::q_zero_point(Tensor self) -> int +inline int64_t Tensor::q_zero_point() const { + return at::_ops::q_zero_point::call(const_cast(*this)); +} + +// aten::q_per_channel_scales(Tensor self) -> Tensor +inline at::Tensor Tensor::q_per_channel_scales() const { + return at::_ops::q_per_channel_scales::call(const_cast(*this)); +} + +// aten::q_per_channel_zero_points(Tensor self) -> Tensor +inline at::Tensor Tensor::q_per_channel_zero_points() const { + return at::_ops::q_per_channel_zero_points::call(const_cast(*this)); +} + +// aten::q_per_channel_axis(Tensor self) -> int +inline int64_t Tensor::q_per_channel_axis() const { + return at::_ops::q_per_channel_axis::call(const_cast(*this)); +} + +// 
aten::int_repr(Tensor self) -> Tensor +inline at::Tensor Tensor::int_repr() const { + return at::_ops::int_repr::call(const_cast(*this)); +} + +// aten::qscheme(Tensor self) -> QScheme +inline at::QScheme Tensor::qscheme() const { + return at::_ops::qscheme::call(const_cast(*this)); +} + +// aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a) +inline at::Tensor Tensor::_autocast_to_reduced_precision(bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) const { + return at::_ops::_autocast_to_reduced_precision::call(const_cast(*this), cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); +} + +// aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a) +inline at::Tensor Tensor::_autocast_to_full_precision(bool cuda_enabled, bool cpu_enabled) const { + return at::_ops::_autocast_to_full_precision::call(const_cast(*this), cuda_enabled, cpu_enabled); +} + +// aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(at::TensorOptions options, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_dtype_layout::call(const_cast(*this), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), non_blocking, copy, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); +} + +// aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? 
memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_dtype_layout::call(const_cast(*this), dtype, layout, device, pin_memory, non_blocking, copy, memory_format); +} + +// aten::to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(at::Device device, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_device::call(const_cast(*this), device, dtype, non_blocking, copy, memory_format); +} + +// aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_dtype::call(const_cast(*this), dtype, non_blocking, copy, memory_format); +} + +// aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) +inline at::Tensor Tensor::to(const at::Tensor & other, bool non_blocking, bool copy, ::std::optional memory_format) const { + return at::_ops::to_other::call(const_cast(*this), other, non_blocking, copy, memory_format); +} + +// aten::item(Tensor self) -> Scalar +inline at::Scalar Tensor::item() const { + return at::_ops::item::call(const_cast(*this)); +} + +// aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!) +inline at::Tensor & Tensor::set_(at::Storage source) const { + return at::_ops::set__source_Storage::call(const_cast(*this), source); +} + +// aten::set_.source_Storage_storage_offset(Tensor(a!) 
self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) +inline at::Tensor & Tensor::set_(at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride) const { + return at::_ops::set__source_Storage_storage_offset::call(const_cast(*this), source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); +} + +// aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) +inline at::Tensor & Tensor::set__symint(at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const { + return at::_ops::set__source_Storage_storage_offset::call(const_cast(*this), source, storage_offset, size, stride); +} + +// aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) +inline at::Tensor & Tensor::set_(const at::Tensor & source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride) const { + return at::_ops::set__source_Tensor_storage_offset::call(const_cast(*this), source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); +} + +// aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) +inline at::Tensor & Tensor::set__symint(const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const { + return at::_ops::set__source_Tensor_storage_offset::call(const_cast(*this), source, storage_offset, size, stride); +} + +// aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) +inline at::Tensor & Tensor::set_(const at::Tensor & source) const { + return at::_ops::set__source_Tensor::call(const_cast(*this), source); +} + +// aten::set_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::set_() const { + return at::_ops::set_::call(const_cast(*this)); +} + +// aten::is_set_to(Tensor self, Tensor tensor) -> bool +inline bool Tensor::is_set_to(const at::Tensor & tensor) const { + return at::_ops::is_set_to::call(const_cast(*this), tensor); +} + +// aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) +inline at::Tensor & Tensor::masked_fill_(const at::Tensor & mask, const at::Scalar & value) const { + return at::_ops::masked_fill__Scalar::call(const_cast(*this), mask, value); +} + +// aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor +inline at::Tensor Tensor::masked_fill(const at::Tensor & mask, const at::Scalar & value) const { + return at::_ops::masked_fill_Scalar::call(const_cast(*this), mask, value); +} + +// aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) +inline at::Tensor & Tensor::masked_fill_(const at::Tensor & mask, const at::Tensor & value) const { + return at::_ops::masked_fill__Tensor::call(const_cast(*this), mask, value); +} + +// aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor +inline at::Tensor Tensor::masked_fill(const at::Tensor & mask, const at::Tensor & value) const { + return at::_ops::masked_fill_Tensor::call(const_cast(*this), mask, value); +} + +// aten::masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!) 
+inline at::Tensor & Tensor::masked_scatter_(const at::Tensor & mask, const at::Tensor & source) const { + return at::_ops::masked_scatter_::call(const_cast(*this), mask, source); +} + +// aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor +inline at::Tensor Tensor::masked_scatter(const at::Tensor & mask, const at::Tensor & source) const { + return at::_ops::masked_scatter::call(const_cast(*this), mask, source); +} + +// aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a) +inline at::Tensor Tensor::view(at::IntArrayRef size) const { + return at::_ops::view::call(const_cast(*this), c10::fromIntArrayRefSlow(size)); +} + +// aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a) +inline at::Tensor Tensor::view_symint(c10::SymIntArrayRef size) const { + return at::_ops::view::call(const_cast(*this), size); +} + +// aten::view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) +inline at::Tensor Tensor::view(at::ScalarType dtype) const { + return at::_ops::view_dtype::call(const_cast(*this), dtype); +} + +// aten::put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) +inline at::Tensor & Tensor::put_(const at::Tensor & index, const at::Tensor & source, bool accumulate) const { + return at::_ops::put_::call(const_cast(*this), index, source, accumulate); +} + +// aten::put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor +inline at::Tensor Tensor::put(const at::Tensor & index, const at::Tensor & source, bool accumulate) const { + return at::_ops::put::call(const_cast(*this), index, source, accumulate); +} + +// aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!) 
+inline at::Tensor & Tensor::index_add_(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) const { + return at::_ops::index_add_::call(const_cast(*this), dim, index, source, alpha); +} + +// aten::index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::index_add(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) const { + return at::_ops::index_add::call(const_cast(*this), dim, index, source, alpha); +} + +// aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::index_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) const { + return at::_ops::index_add_dimname::call(const_cast(*this), dim, index, source, alpha); +} + +// aten::index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!) +inline at::Tensor & Tensor::index_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) const { + return at::_ops::index_reduce_::call(const_cast(*this), dim, index, source, reduce, include_self); +} + +// aten::index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor +inline at::Tensor Tensor::index_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) const { + return at::_ops::index_reduce::call(const_cast(*this), dim, index, source, reduce, include_self); +} + +// aten::index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) 
+inline at::Tensor & Tensor::index_fill_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::index_fill__int_Scalar::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor +inline at::Tensor Tensor::index_fill(int64_t dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::index_fill_int_Scalar::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!) +inline at::Tensor & Tensor::index_fill_(int64_t dim, const at::Tensor & index, const at::Tensor & value) const { + return at::_ops::index_fill__int_Tensor::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor +inline at::Tensor Tensor::index_fill(int64_t dim, const at::Tensor & index, const at::Tensor & value) const { + return at::_ops::index_fill_int_Tensor::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!) +inline at::Tensor & Tensor::index_fill_(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::index_fill__Dimname_Scalar::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!) 
+inline at::Tensor & Tensor::index_fill_(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const { + return at::_ops::index_fill__Dimname_Tensor::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor +inline at::Tensor Tensor::index_fill(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::index_fill_Dimname_Scalar::call(const_cast(*this), dim, index, value); +} + +// aten::index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor +inline at::Tensor Tensor::index_fill(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const { + return at::_ops::index_fill_Dimname_Tensor::call(const_cast(*this), dim, index, value); +} + +// aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor +inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_src::call(const_cast(*this), dim, index, src); +} + +// aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) +inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter__src::call(const_cast(*this), dim, index, src); +} + +// aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor +inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::scatter_value::call(const_cast(*this), dim, index, value); +} + +// aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) 
+inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::scatter__value::call(const_cast(*this), dim, index, value); +} + +// aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor +inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const { + return at::_ops::scatter_reduce::call(const_cast(*this), dim, index, src, reduce); +} + +// aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!) +inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const { + return at::_ops::scatter__reduce::call(const_cast(*this), dim, index, src, reduce); +} + +// aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor +inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const { + return at::_ops::scatter_value_reduce::call(const_cast(*this), dim, index, value, reduce); +} + +// aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!) 
+inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const { + return at::_ops::scatter__value_reduce::call(const_cast(*this), dim, index, value, reduce); +} + +// aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor +inline at::Tensor Tensor::scatter(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_dimname_src::call(const_cast(*this), dim, index, src); +} + +// aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor +inline at::Tensor Tensor::scatter(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const { + return at::_ops::scatter_dimname_value::call(const_cast(*this), dim, index, value); +} + +// aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor +inline at::Tensor Tensor::scatter_add(int64_t dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_add::call(const_cast(*this), dim, index, src); +} + +// aten::scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) 
+inline at::Tensor & Tensor::scatter_add_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_add_::call(const_cast(*this), dim, index, src); +} + +// aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor +inline at::Tensor Tensor::scatter_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const { + return at::_ops::scatter_add_dimname::call(const_cast(*this), dim, index, src); +} + +// aten::scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor +inline at::Tensor Tensor::scatter_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) const { + return at::_ops::scatter_reduce_two::call(const_cast(*this), dim, index, src, reduce, include_self); +} + +// aten::scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!) +inline at::Tensor & Tensor::scatter_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) const { + return at::_ops::scatter_reduce__two::call(const_cast(*this), dim, index, src, reduce, include_self); +} + +// aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::eq_(const at::Scalar & other) const { + return at::_ops::eq__Scalar::call(const_cast(*this), other); +} + +// aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::eq_(const at::Tensor & other) const { + return at::_ops::eq__Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::bitwise_and(const at::Scalar & other) const { + return at::_ops::bitwise_and_Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::bitwise_and(const at::Tensor & other) const { + return at::_ops::bitwise_and_Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_and_(const at::Scalar & other) const { + return at::_ops::bitwise_and__Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_and_(const at::Tensor & other) const { + return at::_ops::bitwise_and__Tensor::call(const_cast(*this), other); +} + +// aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::__and__(const at::Scalar & other) const { + return at::_ops::__and___Scalar::call(const_cast(*this), other); +} + +// aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::__and__(const at::Tensor & other) const { + return at::_ops::__and___Tensor::call(const_cast(*this), other); +} + +// aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::__iand__(const at::Scalar & other) const { + return at::_ops::__iand___Scalar::call(const_cast(*this), other); +} + +// aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::__iand__(const at::Tensor & other) const { + return at::_ops::__iand___Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::bitwise_or(const at::Scalar & other) const { + return at::_ops::bitwise_or_Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::bitwise_or(const at::Tensor & other) const { + return at::_ops::bitwise_or_Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_or_(const at::Scalar & other) const { + return at::_ops::bitwise_or__Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_or_(const at::Tensor & other) const { + return at::_ops::bitwise_or__Tensor::call(const_cast(*this), other); +} + +// aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::__or__(const at::Scalar & other) const { + return at::_ops::__or___Scalar::call(const_cast(*this), other); +} + +// aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::__or__(const at::Tensor & other) const { + return at::_ops::__or___Tensor::call(const_cast(*this), other); +} + +// aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::__ior__(const at::Scalar & other) const { + return at::_ops::__ior___Scalar::call(const_cast(*this), other); +} + +// aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::__ior__(const at::Tensor & other) const { + return at::_ops::__ior___Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::bitwise_xor(const at::Scalar & other) const { + return at::_ops::bitwise_xor_Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::bitwise_xor(const at::Tensor & other) const { + return at::_ops::bitwise_xor_Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_xor_(const at::Scalar & other) const { + return at::_ops::bitwise_xor__Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_xor_(const at::Tensor & other) const { + return at::_ops::bitwise_xor__Tensor::call(const_cast(*this), other); +} + +// aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::__xor__(const at::Scalar & other) const { + return at::_ops::__xor___Scalar::call(const_cast(*this), other); +} + +// aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::__xor__(const at::Tensor & other) const { + return at::_ops::__xor___Tensor::call(const_cast(*this), other); +} + +// aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::__ixor__(const at::Scalar & other) const { + return at::_ops::__ixor___Scalar::call(const_cast(*this), other); +} + +// aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::__ixor__(const at::Tensor & other) const { + return at::_ops::__ixor___Tensor::call(const_cast(*this), other); +} + +// aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::__lshift__(const at::Scalar & other) const { + return at::_ops::__lshift___Scalar::call(const_cast(*this), other); +} + +// aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::__lshift__(const at::Tensor & other) const { + return at::_ops::__lshift___Tensor::call(const_cast(*this), other); +} + +// aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::__ilshift__(const at::Scalar & other) const { + return at::_ops::__ilshift___Scalar::call(const_cast(*this), other); +} + +// aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::__ilshift__(const at::Tensor & other) const { + return at::_ops::__ilshift___Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::bitwise_left_shift(const at::Tensor & other) const { + return at::_ops::bitwise_left_shift_Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_left_shift_(const at::Tensor & other) const { + return at::_ops::bitwise_left_shift__Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::bitwise_left_shift(const at::Scalar & other) const { + return at::_ops::bitwise_left_shift_Tensor_Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+inline at::Tensor & Tensor::bitwise_left_shift_(const at::Scalar & other) const { + return at::_ops::bitwise_left_shift__Tensor_Scalar::call(const_cast(*this), other); +} + +// aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::__rshift__(const at::Scalar & other) const { + return at::_ops::__rshift___Scalar::call(const_cast(*this), other); +} + +// aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::__rshift__(const at::Tensor & other) const { + return at::_ops::__rshift___Tensor::call(const_cast(*this), other); +} + +// aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::__irshift__(const at::Scalar & other) const { + return at::_ops::__irshift___Scalar::call(const_cast(*this), other); +} + +// aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::__irshift__(const at::Tensor & other) const { + return at::_ops::__irshift___Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::bitwise_right_shift(const at::Tensor & other) const { + return at::_ops::bitwise_right_shift_Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::bitwise_right_shift_(const at::Tensor & other) const { + return at::_ops::bitwise_right_shift__Tensor::call(const_cast(*this), other); +} + +// aten::bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::bitwise_right_shift(const at::Scalar & other) const { + return at::_ops::bitwise_right_shift_Tensor_Scalar::call(const_cast(*this), other); +} + +// aten::bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+inline at::Tensor & Tensor::bitwise_right_shift_(const at::Scalar & other) const { + return at::_ops::bitwise_right_shift__Tensor_Scalar::call(const_cast(*this), other); +} + +// aten::tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) +inline at::Tensor & Tensor::tril_(int64_t diagonal) const { + return at::_ops::tril_::call(const_cast(*this), diagonal); +} + +// aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) +inline at::Tensor & Tensor::triu_(int64_t diagonal) const { + return at::_ops::triu_::call(const_cast(*this), diagonal); +} + +// aten::digamma_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::digamma_() const { + return at::_ops::digamma_::call(const_cast(*this)); +} + +// aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!) +inline at::Tensor & Tensor::lerp_(const at::Tensor & end, const at::Scalar & weight) const { + return at::_ops::lerp__Scalar::call(const_cast(*this), end, weight); +} + +// aten::lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!) +inline at::Tensor & Tensor::lerp_(const at::Tensor & end, const at::Tensor & weight) const { + return at::_ops::lerp__Tensor::call(const_cast(*this), end, weight); +} + +// aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) +inline at::Tensor & Tensor::addbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addbmm_::call(const_cast(*this), batch1, batch2, beta, alpha); +} + +// aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor Tensor::addbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const { + return at::_ops::addbmm::call(const_cast(*this), batch1, batch2, beta, alpha); +} + +// aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? 
generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::random_(int64_t from, ::std::optional to, ::std::optional generator) const { + return at::_ops::random__from::call(const_cast(*this), from, to, generator); +} + +// aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::random_(int64_t to, ::std::optional generator) const { + return at::_ops::random__to::call(const_cast(*this), to, generator); +} + +// aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::random_(::std::optional generator) const { + return at::_ops::random_::call(const_cast(*this), generator); +} + +// aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::uniform_(double from, double to, ::std::optional generator) const { + return at::_ops::uniform_::call(const_cast(*this), from, to, generator); +} + +// aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::cauchy_(double median, double sigma, ::std::optional generator) const { + return at::_ops::cauchy_::call(const_cast(*this), median, sigma, generator); +} + +// aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::log_normal_(double mean, double std, ::std::optional generator) const { + return at::_ops::log_normal_::call(const_cast(*this), mean, std, generator); +} + +// aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::exponential_(double lambd, ::std::optional generator) const { + return at::_ops::exponential_::call(const_cast(*this), lambd, generator); +} + +// aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) 
+inline at::Tensor & Tensor::geometric_(double p, ::std::optional generator) const { + return at::_ops::geometric_::call(const_cast(*this), p, generator); +} + +// aten::diag(Tensor self, int diagonal=0) -> Tensor +inline at::Tensor Tensor::diag(int64_t diagonal) const { + return at::_ops::diag::call(const_cast(*this), diagonal); +} + +// aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor +inline at::Tensor Tensor::cross(const at::Tensor & other, ::std::optional dim) const { + return at::_ops::cross::call(const_cast(*this), other, dim); +} + +// aten::triu(Tensor self, int diagonal=0) -> Tensor +inline at::Tensor Tensor::triu(int64_t diagonal) const { + return at::_ops::triu::call(const_cast(*this), diagonal); +} + +// aten::tril(Tensor self, int diagonal=0) -> Tensor +inline at::Tensor Tensor::tril(int64_t diagonal) const { + return at::_ops::tril::call(const_cast(*this), diagonal); +} + +// aten::trace(Tensor self) -> Tensor +inline at::Tensor Tensor::trace() const { + return at::_ops::trace::call(const_cast(*this)); +} + +// aten::ne.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::ne(const at::Scalar & other) const { + return at::_ops::ne_Scalar::call(const_cast(*this), other); +} + +// aten::ne.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::ne(const at::Tensor & other) const { + return at::_ops::ne_Tensor::call(const_cast(*this), other); +} + +// aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::ne_(const at::Scalar & other) const { + return at::_ops::ne__Scalar::call(const_cast(*this), other); +} + +// aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::ne_(const at::Tensor & other) const { + return at::_ops::ne__Tensor::call(const_cast(*this), other); +} + +// aten::not_equal.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::not_equal(const at::Scalar & other) const { + return at::_ops::not_equal_Scalar::call(const_cast(*this), other); +} + +// aten::not_equal.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::not_equal(const at::Tensor & other) const { + return at::_ops::not_equal_Tensor::call(const_cast(*this), other); +} + +// aten::not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::not_equal_(const at::Scalar & other) const { + return at::_ops::not_equal__Scalar::call(const_cast(*this), other); +} + +// aten::not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::not_equal_(const at::Tensor & other) const { + return at::_ops::not_equal__Tensor::call(const_cast(*this), other); +} + +// aten::eq.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::eq(const at::Scalar & other) const { + return at::_ops::eq_Scalar::call(const_cast(*this), other); +} + +// aten::eq.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::eq(const at::Tensor & other) const { + return at::_ops::eq_Tensor::call(const_cast(*this), other); +} + +// aten::ge.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::ge(const at::Scalar & other) const { + return at::_ops::ge_Scalar::call(const_cast(*this), other); +} + +// aten::ge.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::ge(const at::Tensor & other) const { + return at::_ops::ge_Tensor::call(const_cast(*this), other); +} + +// aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::ge_(const at::Scalar & other) const { + return at::_ops::ge__Scalar::call(const_cast(*this), other); +} + +// aten::ge_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::ge_(const at::Tensor & other) const { + return at::_ops::ge__Tensor::call(const_cast(*this), other); +} + +// aten::greater_equal.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::greater_equal(const at::Scalar & other) const { + return at::_ops::greater_equal_Scalar::call(const_cast(*this), other); +} + +// aten::greater_equal.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::greater_equal(const at::Tensor & other) const { + return at::_ops::greater_equal_Tensor::call(const_cast(*this), other); +} + +// aten::greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::greater_equal_(const at::Scalar & other) const { + return at::_ops::greater_equal__Scalar::call(const_cast(*this), other); +} + +// aten::greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::greater_equal_(const at::Tensor & other) const { + return at::_ops::greater_equal__Tensor::call(const_cast(*this), other); +} + +// aten::le.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::le(const at::Scalar & other) const { + return at::_ops::le_Scalar::call(const_cast(*this), other); +} + +// aten::le.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::le(const at::Tensor & other) const { + return at::_ops::le_Tensor::call(const_cast(*this), other); +} + +// aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::le_(const at::Scalar & other) const { + return at::_ops::le__Scalar::call(const_cast(*this), other); +} + +// aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::le_(const at::Tensor & other) const { + return at::_ops::le__Tensor::call(const_cast(*this), other); +} + +// aten::less_equal.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::less_equal(const at::Scalar & other) const { + return at::_ops::less_equal_Scalar::call(const_cast(*this), other); +} + +// aten::less_equal.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::less_equal(const at::Tensor & other) const { + return at::_ops::less_equal_Tensor::call(const_cast(*this), other); +} + +// aten::less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::less_equal_(const at::Scalar & other) const { + return at::_ops::less_equal__Scalar::call(const_cast(*this), other); +} + +// aten::less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::less_equal_(const at::Tensor & other) const { + return at::_ops::less_equal__Tensor::call(const_cast(*this), other); +} + +// aten::gt.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::gt(const at::Scalar & other) const { + return at::_ops::gt_Scalar::call(const_cast(*this), other); +} + +// aten::gt.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::gt(const at::Tensor & other) const { + return at::_ops::gt_Tensor::call(const_cast(*this), other); +} + +// aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::gt_(const at::Scalar & other) const { + return at::_ops::gt__Scalar::call(const_cast(*this), other); +} + +// aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::gt_(const at::Tensor & other) const { + return at::_ops::gt__Tensor::call(const_cast(*this), other); +} + +// aten::greater.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::greater(const at::Scalar & other) const { + return at::_ops::greater_Scalar::call(const_cast(*this), other); +} + +// aten::greater.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::greater(const at::Tensor & other) const { + return at::_ops::greater_Tensor::call(const_cast(*this), other); +} + +// aten::greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::greater_(const at::Scalar & other) const { + return at::_ops::greater__Scalar::call(const_cast(*this), other); +} + +// aten::greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::greater_(const at::Tensor & other) const { + return at::_ops::greater__Tensor::call(const_cast(*this), other); +} + +// aten::lt.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::lt(const at::Scalar & other) const { + return at::_ops::lt_Scalar::call(const_cast(*this), other); +} + +// aten::lt.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::lt(const at::Tensor & other) const { + return at::_ops::lt_Tensor::call(const_cast(*this), other); +} + +// aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::lt_(const at::Scalar & other) const { + return at::_ops::lt__Scalar::call(const_cast(*this), other); +} + +// aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::lt_(const at::Tensor & other) const { + return at::_ops::lt__Tensor::call(const_cast(*this), other); +} + +// aten::less.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::less(const at::Scalar & other) const { + return at::_ops::less_Scalar::call(const_cast(*this), other); +} + +// aten::less.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::less(const at::Tensor & other) const { + return at::_ops::less_Tensor::call(const_cast(*this), other); +} + +// aten::less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::less_(const at::Scalar & other) const { + return at::_ops::less__Scalar::call(const_cast(*this), other); +} + +// aten::less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::less_(const at::Tensor & other) const { + return at::_ops::less__Tensor::call(const_cast(*this), other); +} + +// aten::take(Tensor self, Tensor index) -> Tensor +inline at::Tensor Tensor::take(const at::Tensor & index) const { + return at::_ops::take::call(const_cast(*this), index); +} + +// aten::take_along_dim(Tensor self, Tensor indices, int? 
dim=None) -> Tensor +inline at::Tensor Tensor::take_along_dim(const at::Tensor & indices, ::std::optional dim) const { + return at::_ops::take_along_dim::call(const_cast(*this), indices, dim); +} + +// aten::index_select(Tensor self, int dim, Tensor index) -> Tensor +inline at::Tensor Tensor::index_select(int64_t dim, const at::Tensor & index) const { + return at::_ops::index_select::call(const_cast(*this), dim, index); +} + +// aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor +inline at::Tensor Tensor::index_select(at::Dimname dim, const at::Tensor & index) const { + return at::_ops::index_select_dimname::call(const_cast(*this), dim, index); +} + +// aten::masked_select(Tensor self, Tensor mask) -> Tensor +inline at::Tensor Tensor::masked_select(const at::Tensor & mask) const { + return at::_ops::masked_select::call(const_cast(*this), mask); +} + +// aten::nonzero(Tensor self) -> Tensor +inline at::Tensor Tensor::nonzero() const { + return at::_ops::nonzero::call(const_cast(*this)); +} + +// aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor +inline at::Tensor Tensor::nonzero_static(int64_t size, int64_t fill_value) const { + return at::_ops::nonzero_static::call(const_cast(*this), size, fill_value); +} + +// aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor +inline at::Tensor Tensor::nonzero_static_symint(c10::SymInt size, int64_t fill_value) const { + return at::_ops::nonzero_static::call(const_cast(*this), size, fill_value); +} + +// aten::nonzero_numpy(Tensor self) -> Tensor[] +inline ::std::vector Tensor::nonzero_numpy() const { + return at::_ops::nonzero_numpy::call(const_cast(*this)); +} + +// aten::argwhere(Tensor self) -> Tensor +inline at::Tensor Tensor::argwhere() const { + return at::_ops::argwhere::call(const_cast(*this)); +} + +// aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor +inline at::Tensor Tensor::gather(int64_t dim, const 
at::Tensor & index, bool sparse_grad) const { + return at::_ops::gather::call(const_cast(*this), dim, index, sparse_grad); +} + +// aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor +inline at::Tensor Tensor::gather(at::Dimname dim, const at::Tensor & index, bool sparse_grad) const { + return at::_ops::gather_dimname::call(const_cast(*this), dim, index, sparse_grad); +} + +// aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor +inline at::Tensor Tensor::addcmul(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const { + return at::_ops::addcmul::call(const_cast(*this), tensor1, tensor2, value); +} + +// aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) +inline at::Tensor & Tensor::addcmul_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const { + return at::_ops::addcmul_::call(const_cast(*this), tensor1, tensor2, value); +} + +// aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor +inline at::Tensor Tensor::addcdiv(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const { + return at::_ops::addcdiv::call(const_cast(*this), tensor1, tensor2, value); +} + +// aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) 
+inline at::Tensor & Tensor::addcdiv_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const { + return at::_ops::addcdiv_::call(const_cast(*this), tensor1, tensor2, value); +} + +// aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) +inline ::std::tuple Tensor::triangular_solve(const at::Tensor & A, bool upper, bool transpose, bool unitriangular) const { + return at::_ops::triangular_solve::call(const_cast(*this), A, upper, transpose, unitriangular); +} + +// aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) +inline ::std::tuple Tensor::svd(bool some, bool compute_uv) const { + return at::_ops::svd::call(const_cast(*this), some, compute_uv); +} + +// aten::swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a) +inline at::Tensor Tensor::swapaxes(int64_t axis0, int64_t axis1) const { + return at::_ops::swapaxes::call(const_cast(*this), axis0, axis1); +} + +// aten::swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!) +inline at::Tensor & Tensor::swapaxes_(int64_t axis0, int64_t axis1) const { + return at::_ops::swapaxes_::call(const_cast(*this), axis0, axis1); +} + +// aten::swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a) +inline at::Tensor Tensor::swapdims(int64_t dim0, int64_t dim1) const { + return at::_ops::swapdims::call(const_cast(*this), dim0, dim1); +} + +// aten::swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) 
+inline at::Tensor & Tensor::swapdims_(int64_t dim0, int64_t dim1) const { + return at::_ops::swapdims_::call(const_cast(*this), dim0, dim1); +} + +// aten::cholesky(Tensor self, bool upper=False) -> Tensor +inline at::Tensor Tensor::cholesky(bool upper) const { + return at::_ops::cholesky::call(const_cast(*this), upper); +} + +// aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor +inline at::Tensor Tensor::cholesky_solve(const at::Tensor & input2, bool upper) const { + return at::_ops::cholesky_solve::call(const_cast(*this), input2, upper); +} + +// aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor +inline at::Tensor Tensor::cholesky_inverse(bool upper) const { + return at::_ops::cholesky_inverse::call(const_cast(*this), upper); +} + +// aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R) +inline ::std::tuple Tensor::qr(bool some) const { + return at::_ops::qr::call(const_cast(*this), some); +} + +// aten::geqrf(Tensor self) -> (Tensor a, Tensor tau) +inline ::std::tuple Tensor::geqrf() const { + return at::_ops::geqrf::call(const_cast(*this)); +} + +// aten::orgqr(Tensor self, Tensor input2) -> Tensor +inline at::Tensor Tensor::orgqr(const at::Tensor & input2) const { + return at::_ops::orgqr::call(const_cast(*this), input2); +} + +// aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor +inline at::Tensor Tensor::ormqr(const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose) const { + return at::_ops::ormqr::call(const_cast(*this), input2, input3, left, transpose); +} + +// aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor +inline at::Tensor Tensor::lu_solve(const at::Tensor & LU_data, const at::Tensor & LU_pivots) const { + return at::_ops::lu_solve::call(const_cast(*this), LU_data, LU_pivots); +} + +// aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? 
generator=None) -> Tensor +inline at::Tensor Tensor::multinomial(int64_t num_samples, bool replacement, ::std::optional generator) const { + return at::_ops::multinomial::call(const_cast(*this), num_samples, replacement, generator); +} + +// aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor +inline at::Tensor Tensor::multinomial_symint(c10::SymInt num_samples, bool replacement, ::std::optional generator) const { + return at::_ops::multinomial::call(const_cast(*this), num_samples, replacement, generator); +} + +// aten::lgamma_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::lgamma_() const { + return at::_ops::lgamma_::call(const_cast(*this)); +} + +// aten::lgamma(Tensor self) -> Tensor +inline at::Tensor Tensor::lgamma() const { + return at::_ops::lgamma::call(const_cast(*this)); +} + +// aten::digamma(Tensor self) -> Tensor +inline at::Tensor Tensor::digamma() const { + return at::_ops::digamma::call(const_cast(*this)); +} + +// aten::polygamma(int n, Tensor self) -> Tensor +inline at::Tensor Tensor::polygamma(int64_t n) const { + return at::_ops::polygamma::call(n, const_cast(*this)); +} + +// aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!) +inline at::Tensor & Tensor::polygamma_(int64_t n) const { + return at::_ops::polygamma_::call(const_cast(*this), n); +} + +// aten::erfinv(Tensor self) -> Tensor +inline at::Tensor Tensor::erfinv() const { + return at::_ops::erfinv::call(const_cast(*this)); +} + +// aten::erfinv_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::erfinv_() const { + return at::_ops::erfinv_::call(const_cast(*this)); +} + +// aten::i0(Tensor self) -> Tensor +inline at::Tensor Tensor::i0() const { + return at::_ops::i0::call(const_cast(*this)); +} + +// aten::i0_(Tensor(a!) self) -> Tensor(a!) 
+inline at::Tensor & Tensor::i0_() const { + return at::_ops::i0_::call(const_cast(*this)); +} + +// aten::sign(Tensor self) -> Tensor +inline at::Tensor Tensor::sign() const { + return at::_ops::sign::call(const_cast(*this)); +} + +// aten::sign_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & Tensor::sign_() const { + return at::_ops::sign_::call(const_cast(*this)); +} + +// aten::signbit(Tensor self) -> Tensor +inline at::Tensor Tensor::signbit() const { + return at::_ops::signbit::call(const_cast(*this)); +} + +// aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor +inline at::Tensor Tensor::dist(const at::Tensor & other, const at::Scalar & p) const { + return at::_ops::dist::call(const_cast(*this), other, p); +} + +// aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::atan2_(const at::Tensor & other) const { + return at::_ops::atan2_::call(const_cast(*this), other); +} + +// aten::atan2(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::atan2(const at::Tensor & other) const { + return at::_ops::atan2::call(const_cast(*this), other); +} + +// aten::arctan2(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::arctan2(const at::Tensor & other) const { + return at::_ops::arctan2::call(const_cast(*this), other); +} + +// aten::arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::arctan2_(const at::Tensor & other) const { + return at::_ops::arctan2_::call(const_cast(*this), other); +} + +// aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor +inline at::Tensor Tensor::lerp(const at::Tensor & end, const at::Scalar & weight) const { + return at::_ops::lerp_Scalar::call(const_cast(*this), end, weight); +} + +// aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor +inline at::Tensor Tensor::lerp(const at::Tensor & end, const at::Tensor & weight) const { + return at::_ops::lerp_Tensor::call(const_cast(*this), end, weight); +} + +// aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor +inline at::Tensor Tensor::histc(int64_t bins, const at::Scalar & min, const at::Scalar & max) const { + return at::_ops::histc::call(const_cast(*this), bins, min, max); +} + +// aten::histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) +inline ::std::tuple Tensor::histogram(const at::Tensor & bins, const ::std::optional & weight, bool density) const { + return at::_ops::histogram_bins_tensor::call(const_cast(*this), bins, weight, density); +} + +// aten::histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) +inline ::std::tuple Tensor::histogram(int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density) const { + return at::_ops::histogram_bin_ct::call(const_cast(*this), bins, range, weight, density); +} + +// aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::fmod(const at::Scalar & other) const { + return at::_ops::fmod_Scalar::call(const_cast(*this), other); +} + +// aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+inline at::Tensor & Tensor::fmod_(const at::Scalar & other) const { + return at::_ops::fmod__Scalar::call(const_cast(*this), other); +} + +// aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::fmod(const at::Tensor & other) const { + return at::_ops::fmod_Tensor::call(const_cast(*this), other); +} + +// aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::fmod_(const at::Tensor & other) const { + return at::_ops::fmod__Tensor::call(const_cast(*this), other); +} + +// aten::hypot(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::hypot(const at::Tensor & other) const { + return at::_ops::hypot::call(const_cast(*this), other); +} + +// aten::hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::hypot_(const at::Tensor & other) const { + return at::_ops::hypot_::call(const_cast(*this), other); +} + +// aten::igamma(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::igamma(const at::Tensor & other) const { + return at::_ops::igamma::call(const_cast(*this), other); +} + +// aten::igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::igamma_(const at::Tensor & other) const { + return at::_ops::igamma_::call(const_cast(*this), other); +} + +// aten::igammac(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::igammac(const at::Tensor & other) const { + return at::_ops::igammac::call(const_cast(*this), other); +} + +// aten::igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::igammac_(const at::Tensor & other) const { + return at::_ops::igammac_::call(const_cast(*this), other); +} + +// aten::nextafter(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::nextafter(const at::Tensor & other) const { + return at::_ops::nextafter::call(const_cast(*this), other); +} + +// aten::nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+inline at::Tensor & Tensor::nextafter_(const at::Tensor & other) const { + return at::_ops::nextafter_::call(const_cast(*this), other); +} + +// aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor +inline at::Tensor Tensor::remainder(const at::Scalar & other) const { + return at::_ops::remainder_Scalar::call(const_cast(*this), other); +} + +// aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) +inline at::Tensor & Tensor::remainder_(const at::Scalar & other) const { + return at::_ops::remainder__Scalar::call(const_cast(*this), other); +} + +// aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::remainder(const at::Tensor & other) const { + return at::_ops::remainder_Tensor::call(const_cast(*this), other); +} + +// aten::remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) +inline at::Tensor & Tensor::remainder_(const at::Tensor & other) const { + return at::_ops::remainder__Tensor::call(const_cast(*this), other); +} + +// aten::min(Tensor self) -> Tensor +inline at::Tensor Tensor::min() const { + return at::_ops::min::call(const_cast(*this)); +} + +// aten::fmin(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::fmin(const at::Tensor & other) const { + return at::_ops::fmin::call(const_cast(*this), other); +} + +// aten::max(Tensor self) -> Tensor +inline at::Tensor Tensor::max() const { + return at::_ops::max::call(const_cast(*this)); +} + +// aten::fmax(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::fmax(const at::Tensor & other) const { + return at::_ops::fmax::call(const_cast(*this), other); +} + +// aten::maximum(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::maximum(const at::Tensor & other) const { + return at::_ops::maximum::call(const_cast(*this), other); +} + +// aten::max.other(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::max(const at::Tensor & other) const { + return at::_ops::max_other::call(const_cast(*this), 
other); +} + +// aten::minimum(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::minimum(const at::Tensor & other) const { + return at::_ops::minimum::call(const_cast(*this), other); +} + +// aten::min.other(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::min(const at::Tensor & other) const { + return at::_ops::min_other::call(const_cast(*this), other); +} + +// aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor +inline at::Tensor Tensor::quantile(const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation) const { + return at::_ops::quantile::call(const_cast(*this), q, dim, keepdim, interpolation); +} + +// aten::quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor +inline at::Tensor Tensor::quantile(double q, ::std::optional dim, bool keepdim, c10::string_view interpolation) const { + return at::_ops::quantile_scalar::call(const_cast(*this), q, dim, keepdim, interpolation); +} + +// aten::nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor +inline at::Tensor Tensor::nanquantile(const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation) const { + return at::_ops::nanquantile::call(const_cast(*this), q, dim, keepdim, interpolation); +} + +// aten::nanquantile.scalar(Tensor self, float q, int? 
dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor +inline at::Tensor Tensor::nanquantile(double q, ::std::optional dim, bool keepdim, c10::string_view interpolation) const { + return at::_ops::nanquantile_scalar::call(const_cast(*this), q, dim, keepdim, interpolation); +} + +// aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::sort(int64_t dim, bool descending) const { + return at::_ops::sort::call(const_cast(*this), dim, descending); +} + +// aten::sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::sort(::std::optional stable, int64_t dim, bool descending) const { + return at::_ops::sort_stable::call(const_cast(*this), stable, dim, descending); +} + +// aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::sort(at::Dimname dim, bool descending) const { + return at::_ops::sort_dimname::call(const_cast(*this), dim, descending); +} + +// aten::sort.dimname_stable(Tensor self, *, bool? 
stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::sort(::std::optional stable, at::Dimname dim, bool descending) const { + return at::_ops::sort_dimname_stable::call(const_cast(*this), stable, dim, descending); +} + +// aten::msort(Tensor self) -> Tensor +inline at::Tensor Tensor::msort() const { + return at::_ops::msort::call(const_cast(*this)); +} + +// aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor +inline at::Tensor Tensor::argsort(int64_t dim, bool descending) const { + return at::_ops::argsort::call(const_cast(*this), dim, descending); +} + +// aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor +inline at::Tensor Tensor::argsort(bool stable, int64_t dim, bool descending) const { + return at::_ops::argsort_stable::call(const_cast(*this), stable, dim, descending); +} + +// aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor +inline at::Tensor Tensor::argsort(at::Dimname dim, bool descending) const { + return at::_ops::argsort_dimname::call(const_cast(*this), dim, descending); +} + +// aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const { + return at::_ops::topk::call(const_cast(*this), k, dim, largest, sorted); +} + +// aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) +inline ::std::tuple Tensor::topk_symint(c10::SymInt k, int64_t dim, bool largest, bool sorted) const { + return at::_ops::topk::call(const_cast(*this), k, dim, largest, sorted); +} + +// aten::all(Tensor self) -> Tensor +inline at::Tensor Tensor::all() const { + return at::_ops::all::call(const_cast(*this)); +} + +// aten::any(Tensor self) -> Tensor +inline at::Tensor Tensor::any() const { + return 
at::_ops::any::call(const_cast(*this)); +} + +// aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor +inline at::Tensor Tensor::renorm(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const { + return at::_ops::renorm::call(const_cast(*this), p, dim, maxnorm); +} + +// aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!) +inline at::Tensor & Tensor::renorm_(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const { + return at::_ops::renorm_::call(const_cast(*this), p, dim, maxnorm); +} + +// aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) +inline at::Tensor Tensor::unfold(int64_t dimension, int64_t size, int64_t step) const { + return at::_ops::unfold::call(const_cast(*this), dimension, size, step); +} + +// aten::equal(Tensor self, Tensor other) -> bool +inline bool Tensor::equal(const at::Tensor & other) const { + return at::_ops::equal::call(const_cast(*this), other); +} + +// aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor +inline at::Tensor Tensor::pow(const at::Tensor & exponent) const { + return at::_ops::pow_Tensor_Tensor::call(const_cast(*this), exponent); +} + +// aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor +inline at::Tensor Tensor::pow(const at::Scalar & exponent) const { + return at::_ops::pow_Tensor_Scalar::call(const_cast(*this), exponent); +} + +// aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) +inline at::Tensor & Tensor::pow_(const at::Scalar & exponent) const { + return at::_ops::pow__Scalar::call(const_cast(*this), exponent); +} + +// aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) 
+inline at::Tensor & Tensor::pow_(const at::Tensor & exponent) const { + return at::_ops::pow__Tensor::call(const_cast(*this), exponent); +} + +// aten::float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor +inline at::Tensor Tensor::float_power(const at::Tensor & exponent) const { + return at::_ops::float_power_Tensor_Tensor::call(const_cast(*this), exponent); +} + +// aten::float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor +inline at::Tensor Tensor::float_power(const at::Scalar & exponent) const { + return at::_ops::float_power_Tensor_Scalar::call(const_cast(*this), exponent); +} + +// aten::float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) +inline at::Tensor & Tensor::float_power_(const at::Scalar & exponent) const { + return at::_ops::float_power__Scalar::call(const_cast(*this), exponent); +} + +// aten::float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) +inline at::Tensor & Tensor::float_power_(const at::Tensor & exponent) const { + return at::_ops::float_power__Tensor::call(const_cast(*this), exponent); +} + +// aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) +inline at::Tensor & Tensor::normal_(double mean, double std, ::std::optional generator) const { + return at::_ops::normal_::call(const_cast(*this), mean, std, generator); +} + +// aten::alias(Tensor(a) self) -> Tensor(a) +inline at::Tensor Tensor::alias() const { + return at::_ops::alias::call(const_cast(*this)); +} + +// aten::isfinite(Tensor self) -> Tensor +inline at::Tensor Tensor::isfinite() const { + return at::_ops::isfinite::call(const_cast(*this)); +} + +// aten::isinf(Tensor self) -> Tensor +inline at::Tensor Tensor::isinf() const { + return at::_ops::isinf::call(const_cast(*this)); +} + +// aten::record_stream(Tensor(a!) 
self, Stream s) -> () +inline void Tensor::record_stream(at::Stream s) const { + return at::_ops::record_stream::call(const_cast(*this), s); +} + +// aten::isposinf(Tensor self) -> Tensor +inline at::Tensor Tensor::isposinf() const { + return at::_ops::isposinf::call(const_cast(*this)); +} + +// aten::isneginf(Tensor self) -> Tensor +inline at::Tensor Tensor::isneginf() const { + return at::_ops::isneginf::call(const_cast(*this)); +} + +// aten::det(Tensor self) -> Tensor +inline at::Tensor Tensor::det() const { + return at::_ops::det::call(const_cast(*this)); +} + +// aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) +inline ::std::tuple Tensor::slogdet() const { + return at::_ops::slogdet::call(const_cast(*this)); +} + +// aten::logdet(Tensor self) -> Tensor +inline at::Tensor Tensor::logdet() const { + return at::_ops::logdet::call(const_cast(*this)); +} + +// aten::inverse(Tensor self) -> Tensor +inline at::Tensor Tensor::inverse() const { + return at::_ops::inverse::call(const_cast(*this)); +} + +// aten::inner(Tensor self, Tensor other) -> Tensor +inline at::Tensor Tensor::inner(const at::Tensor & other) const { + return at::_ops::inner::call(const_cast(*this), other); +} + +// aten::outer(Tensor self, Tensor vec2) -> Tensor +inline at::Tensor Tensor::outer(const at::Tensor & vec2) const { + return at::_ops::outer::call(const_cast(*this), vec2); +} + +// aten::ger(Tensor self, Tensor vec2) -> Tensor +inline at::Tensor Tensor::ger(const at::Tensor & vec2) const { + return at::_ops::ger::call(const_cast(*this), vec2); +} + +// aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor +inline at::Tensor Tensor::to_padded_tensor(double padding, at::OptionalIntArrayRef output_size) const { + return at::_ops::to_padded_tensor::call(const_cast(*this), padding, output_size.has_value() ? 
::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt); +} + +// aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor +inline at::Tensor Tensor::to_padded_tensor_symint(double padding, at::OptionalSymIntArrayRef output_size) const { + return at::_ops::to_padded_tensor::call(const_cast(*this), padding, output_size); +} +} // namespace at + + +namespace c10 { +template <> +struct MaybeOwnedTraits { + using owned_type = at::Tensor; + using borrow_type = at::Tensor; + + static borrow_type createBorrow(const owned_type& from) { + // NOTE: this can be implemented without the special + // unsafe_borrow_t Tensor constructor as + // + // return borrow_type(c10::intrusive_ptr::reclaim(from.unsafeGetTensorImpl())); + // + // but that hurts inlining due to the nullptr check in the + // Tensor(c10::intrusive_ptr<...>) constructor. We already know + // that from.impl_ isn't null because from is a valid Tensor, so + // we needn't do the check again. (using __builtin_assume can + // avoid this, but wouldn't be portable to MSVC.) + return borrow_type(borrow_type::unsafe_borrow_t{}, from); + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.unsafeReleaseTensorImpl(); + // See above note: this can be implemented with public API + // similarly to createBorrow(), but that would hurt inlining. + lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0. 
+ } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { + return true; + } +}; + +template <> +struct ExclusivelyOwnedTraits { + using repr_type = at::Tensor; + using pointer_type = at::Tensor*; + using const_pointer_type = const at::Tensor*; + + static repr_type nullRepr() { + return at::Tensor(); + } + + template + static repr_type createInPlace(Args&&... args) { + return at::Tensor(std::forward(args)...); + } + + static repr_type moveToRepr(at::Tensor&& x) { + return std::move(x); + } + + static void destroyOwned(at::Tensor& x) { + return ExclusivelyOwnedTraits::destroyOwned(x); + } + + static at::Tensor take(at::Tensor& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; +} // namespace c10 + +namespace at { + +inline c10::MaybeOwned borrow_from_optional_tensor( + const std::optional& opt) { + return opt.has_value() + ? 
c10::MaybeOwned::borrowed(*opt) + : c10::MaybeOwned::owned(std::in_place); +} + +inline c10::MaybeOwned Tensor::expect_contiguous(MemoryFormat memory_format) const & { + if (is_contiguous(memory_format)) { + return c10::MaybeOwned::borrowed(*this); + } else { + return c10::MaybeOwned::owned(__dispatch_contiguous(memory_format)); + } +} +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TorchDispatchUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TorchDispatchUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..7cb9934102852c3788ebdefa4b1738151391cf14 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TorchDispatchUtils.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at::impl { + +TORCH_API bool tensor_has_dispatch(const at::Tensor& t); +TORCH_API bool tensorlist_has_dispatch(at::ITensorListRef li); +TORCH_API bool tensorlist_has_dispatch( + const c10::List>& li); +using c10::impl::dispatch_mode_enabled; + +} // namespace at::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TransformationHelper.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TransformationHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..f81018a8e674f21027ef665c69a2788affe4f15a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/TransformationHelper.h @@ -0,0 +1,175 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// Using DistAccumType in accumulate types for distributions. +// Note: Ideally we'd be using ATen/AccumulateType.h but looks +// like the there is some inconsistency in how accumulate types +// are mapped currently, e.g. for the cpu side, float is mapped +// to double. 
+template +struct DistAccumType { }; + +#if defined(__CUDACC__) || defined(__HIPCC__) +template <> struct DistAccumType { using type = float; }; +#endif +template <> struct DistAccumType { using type = float; }; +template <> struct DistAccumType { using type = float; }; +template <> struct DistAccumType { using type = float; }; +template <> struct DistAccumType { using type = double; }; + +template +using dist_acctype = typename DistAccumType::type; + +namespace transformation { + +/** + * A transformation function for `torch.Tensor.random_()`, when both `from` and `to` are specified. + * `range` is `to - from` + * `base` is `from` + */ +template +C10_HOST_DEVICE inline T uniform_int_from_to(V val, uint64_t range, int64_t base) { + return static_cast(static_cast((val % range) + base)); +} + +/** + * A transformation function for `torch.Tensor.random_()`, when `from=min_value(int64_t)` and to=None + */ +template +C10_HOST_DEVICE inline T uniform_int_full_range(V val) { + return static_cast(static_cast(val)); +} + +/** + * A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`. 
+ * In order to prevent compiler warnings reported in GitHub issue 46391, T can't be float or double + * in this overloaded version + */ +template +C10_HOST_DEVICE inline std::enable_if_t), T>uniform_int(V val) { + if constexpr (std::is_same_v) { + return static_cast(val & 1); + } else if constexpr (std::is_same_v) { + return static_cast(val % (static_cast(std::numeric_limits::max()) + 1)); + } else if constexpr (std::is_same_v || std::is_same_v) { + return static_cast(val % static_cast((1ULL << std::numeric_limits::digits) + 1)); + } else if constexpr (std::is_integral_v) { + return static_cast(val % (static_cast(std::numeric_limits::max()) + 1)); + } else { + assert(false); + return 0; + } +} + +/** + * An overloaded transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`, + * added to fix compiler warnings reported in GitHub issue 46391. T is either float or double in this version. + */ +template +C10_HOST_DEVICE inline std::enable_if_t, T>uniform_int(V val) { + return static_cast(val % static_cast((1ULL << std::numeric_limits::digits) + 1)); +} + +template +C10_HOST_DEVICE inline dist_acctype uniform_real(V val, T from, T to) { + constexpr auto MASK = static_cast((static_cast(1) << std::numeric_limits::digits) - 1); + constexpr auto DIVISOR = static_cast>(1) / (static_cast(1) << std::numeric_limits::digits); + dist_acctype x = (val & MASK) * DIVISOR; + return (x * (to - from) + from); +} + +/** + * Transforms normally distributed `val` with mean 0.0 and standard deviation 1.0 to + * normally distributed with `mean` and standard deviation `std`. + */ +template +C10_HOST_DEVICE inline T normal(T val, T mean, T std) { + return val * std + mean; +} + +/** + * Transforms uniformly distributed `val` between 0.0 and 1.0 to + * Cauchy distribution with location parameter `median` and scale parameter `sigma`. 
+ */ +template +C10_HOST_DEVICE inline T cauchy(T val, T median, T sigma) { + // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function + // __tanf overflows and returns `inf/-inf` when (val > 1 - eps) or (val < 0 + eps), + // thus we clip those values. + constexpr T eps = std::numeric_limits::epsilon(); + constexpr T one_minus_eps = 1 - eps; + constexpr T zero_plus_eps = 0 + eps; + val = (val > one_minus_eps ? one_minus_eps : val); + val = (val < zero_plus_eps ? zero_plus_eps : val); + return median + sigma * at::tan(c10::pi * (val - static_cast(0.5))); +} + +template <> +C10_HOST_DEVICE inline double cauchy(double val, double median, double sigma) { + // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function + return median + sigma * at::tan(c10::pi * (val - static_cast(0.5))); +} + +/** + * Transforms uniformly distributed `val` between 0.0 and 1.0 to + * exponentially distributed with `lambda` parameter of the distribution. + */ +template +C10_HOST_DEVICE inline T exponential(T val, T lambda) { + // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates + // Different implementations for CUDA and CPU to preserve original logic + // TODO: must be investigated and unified!!! + // https://github.com/pytorch/pytorch/issues/38662 +#if defined(__CUDACC__) || defined(__HIPCC__) + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706 + // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0. + // we need log to be not 0, and not underflow when converted to half + // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args + auto log = val >= static_cast(1.) - std::numeric_limits::epsilon() / 2 + ? 
-std::numeric_limits::epsilon() / 2 + : at::log(val); + return static_cast(-1.0) / lambda * log; +#else + return static_cast(-1.0) / lambda * at::log1p(-val); +#endif +} + +/** + * Transforms uniformly distributed `val` between 0.0 and 1.0 to + * geometrically distributed with success probability `p`. + */ +template +C10_HOST_DEVICE inline T geometric(T val, T p) { + // https://en.wikipedia.org/wiki/Geometric_distribution#Related_distributions + return static_cast(::ceil(at::log(val) / at::log1p(-p))); +} + +/** + * Transforms normally distributed `val` to log-normally distributed. + */ +template +C10_HOST_DEVICE inline T log_normal(T val) { + // https://en.wikipedia.org/wiki/Log-normal_distribution#Mode,_median,_quantiles + return at::exp(val); +} + +/** + * Transforms uniformly distributed `val` between 0.0 and 1.0 to + * bernoulli distributed with success probability `p`. + */ +template +C10_HOST_DEVICE inline T bernoulli(T val, T p) { + return val < p; +} + +}} // namespace at::transformation diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..885f6e195f05d37ab4253315242167f8e546dcc1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h @@ -0,0 +1 @@ +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UnsafeFromTH.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UnsafeFromTH.h new file mode 100644 index 0000000000000000000000000000000000000000..9ad5c45d3ab6ca499bc10b71c68bc04ededfeb87 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/UnsafeFromTH.h @@ -0,0 +1,21 @@ +#pragma once +#include + +namespace at { + +inline Tensor unsafeTensorFromTH(void * th_pointer, bool retain) { + auto tensor_impl = c10::intrusive_ptr::reclaim(static_cast(th_pointer)); + if 
(retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) { + c10::raw::intrusive_ptr::incref(tensor_impl.get()); + } + return Tensor(std::move(tensor_impl)); +} + +inline Storage unsafeStorageFromTH(void * th_pointer, bool retain) { + if (retain && th_pointer) { + c10::raw::intrusive_ptr::incref(static_cast(th_pointer)); + } + return Storage(c10::intrusive_ptr::reclaim(static_cast(th_pointer))); +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/VariableHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/VariableHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..f9c0aa4a5fc1482bff955cda5768c9dad41031d7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/VariableHooksInterface.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include + +// A little explanation about why this file exists at all. We have +// a few methods on Tensor class which require access to reified access to +// AutogradMeta. In open source, this isn't a big deal: we just access +// torch/csrc/autograd/variable.h from aten/src/ATen/core/Tensor.cpp and +// we can put the definitions inline. This is because everything gets balled +// into a single dynamic library in the end. +// +// However, inside our Facebook internal version of our build system, we +// have a split between aten and torch/csrc. So we cannot simply just +// cross this boundary. "Now wait," you might say, "Why don't we just +// merge the libraries inside Facebook". Well, the problem is that there +// are some downstream applications which are at binary size limit, and +// incorporating all of the extra code from libtorch would push them +// over (admarket/adreview/service:adreviewservice, see also +// https://github.com/pytorch/pytorch/pull/29299) So if you want to do that, +// we have to fix all of the services like this. 
+// +// I didn't want to block eliminating Tensor-Variable on this work, so I +// had to introduce another dynamic dispatch to get to the variable +// implementations (which live in torch/csrc/autograd/variable.cpp, FYI). +// +// I also considered using our existing dynamic dispatch mechanism, c10 +// dispatcher, to do this. However, (1) some of the functions on Tensor +// have weird signatures that are not supported by autograd, and (2) +// see this bug https://github.com/pytorch/pytorch/issues/30102 + +namespace torch::autograd { + +struct Node; + +} // namespace torch::autograd + +namespace at::impl { + +struct TORCH_API VariableHooksInterface { + virtual ~VariableHooksInterface() = default; + virtual TensorBase tensor_data(const TensorBase&) const = 0; + virtual TensorBase variable_data(const TensorBase&) const = 0; + virtual const std::shared_ptr& grad_fn( + const TensorBase&) const = 0; + virtual unsigned _register_hook( + const TensorBase&, + std::function hook) const = 0; + virtual void remove_hook(const TensorBase&, unsigned pos) const = 0; + virtual bool is_view(const TensorBase&) const = 0; + virtual const TensorBase& base(const TensorBase&) const = 0; + virtual const std::string& name(const TensorBase&) const = 0; + virtual bool is_leaf(const TensorBase&) const = 0; + virtual int64_t output_nr(const TensorBase&) const = 0; + virtual void set_data(const TensorBase&, const TensorBase&) const = 0; + virtual TensorBase data(const TensorBase&) const = 0; + virtual int64_t _version(const TensorBase&) const = 0; + virtual void retain_grad(const TensorBase&) const = 0; + virtual bool retains_grad(const TensorBase&) const = 0; + virtual void _backward( + const Tensor&, + TensorList, + const std::optional&, + std::optional, + bool) const = 0; + virtual void requires_grad_(const TensorBase&, bool) const = 0; + virtual void basic_autograd_not_implemented_fallback( + const c10::OperatorHandle& op, + c10::DispatchKeySet dispatch_keys, + torch::jit::Stack* stack) 
const = 0; +}; + +TORCH_API void SetVariableHooks(VariableHooksInterface* hooks); +TORCH_API VariableHooksInterface* GetVariableHooks(); +TORCH_API bool HasVariableHooks(); + +struct TORCH_API VariableHooksRegisterer { + explicit VariableHooksRegisterer(VariableHooksInterface* hooks) { + SetVariableHooks(hooks); + } +}; + +} // namespace at::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Variadic.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Variadic.h new file mode 100644 index 0000000000000000000000000000000000000000..da4df1b1b1a6628f76852d1012c7451fbdd85c3e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Variadic.h @@ -0,0 +1,92 @@ +#pragma once + +#include + +#include +#include + +namespace at { + +// This class allows you to write variadic functions which +// call a (possibly overloaded) function on each argument, +// in order. This is most commonly used in autogenerated code, +// where it is convenient to have a function that can uniformly +// take arguments of different types. If your arguments +// are homogenous consider using a std::initializer_list instead. +// +// For examples of this in use, see torch/csrc/utils/variadic.h +template +struct IterArgs { + template + inline F& apply() { + return self(); + } + + // NB: Use perfect forwarding here, otherwise we'll make value + // copies of all arguments! + template + inline F& apply(T&& arg, Args&&... args) { + self()(std::forward(arg)); + if (self().short_circuit()) { + return self(); + } else { + return apply(std::forward(args)...); + } + } + + // Here are some handy overloads which provide sensible + // defaults for container-like structures that one might + // be interested in recursing into. You can enable them + // by adding: + // + // using IterArgs::operator() + // + // to your struct. 
These are not enabled by default because + // you may be able to process these structures more efficiently + // than handling them one-by-one. + + template + void operator()(c10::IListRef args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + + template + void operator()(at::ArrayRef args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + + template + void operator()(const torch::List& args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + + // NB: we need to specify std::vector manually as C++ won't + // do an implicit conversion to make a template deduction go through. + template + void operator()(const std::vector& args) { + self()(at::ArrayRef{args}); + } + + constexpr bool short_circuit() const { + return false; + } + + private: + inline F& self() { + return *static_cast(this); + } +}; + +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Vitals.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Vitals.h new file mode 100644 index 0000000000000000000000000000000000000000..2fd7729744a10398ccc55b79cf02a2e823b1496f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/Vitals.h @@ -0,0 +1,94 @@ +#pragma once +#include +#include +#include + +#include + +namespace at::vitals { + +TORCH_API bool torchVitalEnabled(); + +struct TORCH_API TorchVitalAttr { + // always initialized to empty + std::string value; + template + TorchVitalAttr& operator<<(const T& t) { + if (torchVitalEnabled()) { + std::stringstream ss; + ss << t; + value += ss.str(); + } + return *this; + } + + template + void write(const T& t, bool force) { + if (force || torchVitalEnabled()) { + std::stringstream ss; + ss << t; + value = ss.str(); + } + } +}; + +struct TORCH_API TorchVital { + std::string name; + std::unordered_map attrs; + + explicit TorchVital(std::string n) : 
name(std::move(n)) {} + TorchVital(const TorchVital&) = default; + TorchVital(TorchVital&&) = default; + TorchVital& operator=(const TorchVital&) = default; + TorchVital& operator=(TorchVital&&) = default; + TorchVital() = delete; + + TorchVitalAttr& create(const std::string& attr); + TorchVitalAttr& create(const std::string& attr, bool force); + friend std::ostream& operator<<(std::ostream& os, const TorchVital& dt); + + ~TorchVital(); +}; + +std::ostream& operator<<(std::ostream& os, TorchVital const& tv); + +// A way to access vitals by string names instead of by global reference. +// This enables access to vitals from the PythonAPI. +class TORCH_API APIVitals { + public: + bool vitals_enabled; + + // Set any vital sign that was added to the map. + bool setVital( + const std::string& vital_name, + const std::string& attr_name, + const std::string& value, + bool force = false); + std::string readVitals(); + + APIVitals(); + + // Ensure this stays a singleton + APIVitals(APIVitals const& other) = delete; + APIVitals(APIVitals&& other) = delete; + APIVitals& operator=(const APIVitals&) = delete; + APIVitals& operator=(APIVitals&&) = delete; + ~APIVitals() = default; + + private: + std::unordered_map name_map_; +}; + +extern TORCH_API APIVitals VitalsAPI; + +} // namespace at::vitals + +#define TORCH_VITAL_DECLARE(name) \ + TORCH_API at::vitals::TorchVital TorchVital_##name; + +#define TORCH_VITAL_DEFINE(name) \ + TORCH_API at::vitals::TorchVital TorchVital_##name(#name); + +#define TORCH_VITAL_BASE(name) TorchVital_##name + +#define TORCH_VITAL(name, attr) TORCH_VITAL_BASE(name).create(#attr) diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/alias_info.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/alias_info.h new file mode 100644 index 0000000000000000000000000000000000000000..bf0ff6ee72d3b29d8123f8204cee6a00dd82d11f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/alias_info.h @@ -0,0 +1,162 @@ 
+#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +/** + * class AliasInfo + * + * Data structure to hold aliasing information for an `Argument`. They can be + * nested to represent aliasing information on contained types. + * + * There is a `beforeSet` which describes the aliasing information before the + * operator executes, and an `afterSet` that describes aliasing info + * after execution. + */ +class AliasInfo { + public: + AliasInfo() = default; + AliasInfo(bool is_write, const std::set& before_qual_strings, const std::set& after_qual_strings) : isWrite_(is_write) { + for (const auto& s: before_qual_strings) { + beforeSets_.insert(Symbol::fromQualString(s)); + } + for (const auto& s : after_qual_strings) { + afterSets_.insert(Symbol::fromQualString(s)); + } + } + // Symbol for the set that can alias anything + static Symbol wildcardSet() { + static const Symbol wc = Symbol::fromQualString("alias::*"); + return wc; + } + + void setIsWrite(bool isWrite) { + isWrite_ = isWrite; + } + + bool isWrite() const { + return isWrite_; + } + + void addBeforeSet(Symbol aliasSet) { + beforeSets_.insert(aliasSet); + } + + void addAfterSet(Symbol aliasSet) { + afterSets_.insert(aliasSet); + } + + const std::unordered_set& beforeSets() const { + return beforeSets_; + } + + const std::unordered_set& afterSets() const { + return afterSets_; + } + + Symbol beforeSet() const { + AT_ASSERT(beforeSets_.size() == 1); + return *beforeSets_.begin(); + } + + bool isWildcardBefore() const { + return beforeSets_.count(wildcardSet()) != 0; + } + + bool isWildcardAfter() const { + return afterSets_.count(wildcardSet()) != 0; + } + + // the alias info for the contained types of the type + // e.g. 
if this is an annotation on List[T], `sets` refers to + // the alias sets that the list may be in + // while containedTypes()[0] refers to the sets that members of the list + // may be in + void addContainedType(AliasInfo aliasInfo) { + containedTypes_.push_back(std::move(aliasInfo)); + } + const std::vector& containedTypes() const { + return containedTypes_; + } + + private: + std::unordered_set beforeSets_; + std::unordered_set afterSets_; + std::vector containedTypes_; + bool isWrite_ = false; +}; + +inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) { + return lhs.isWrite() == rhs.isWrite() + && lhs.beforeSets() == rhs.beforeSets() + && lhs.afterSets() == rhs.afterSets() + && lhs.containedTypes() == rhs.containedTypes(); +} + +// this does match the way things are represented in the schema +inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { + out << "("; + bool first = true; + for (const auto& set : aliasInfo.beforeSets()) { + if (first) { + first = false; + } else { + out << "|"; + } + out << set.toUnqualString(); + } + if (aliasInfo.isWrite()) { + out << "!"; + } + if (aliasInfo.beforeSets() != aliasInfo.afterSets()) { + out << " -> "; + first = true; + for (const auto& set : aliasInfo.afterSets()) { + if (first) { + first = false; + } else { + out << "|"; + } + out << set.toUnqualString(); + } + } + out << ")"; + return out; +} +} // namespace c10 + +namespace std { +template <> + struct hash { + size_t operator()(const c10::AliasInfo& aliasInfo) const { + auto hash = std::hash()(aliasInfo.isWrite()); + + // NOTE: for unordered_set hashes, we couldn't use hash_combine + // because hash_combine is order dependent. Instead, we choose to + // use XOR as the combining function as XOR is commutative. 
+ size_t before_set_hash_seed = 0; + for (auto &e: aliasInfo.beforeSets()) { + auto symbol_hash = std::hash()(e); + before_set_hash_seed = before_set_hash_seed ^ symbol_hash; + } + size_t after_set_hash_seed = 0; + for (auto &e: aliasInfo.afterSets()) { + auto symbol_hash = std::hash()(e); + after_set_hash_seed = after_set_hash_seed ^ symbol_hash; + } + + hash = c10::hash_combine(hash, before_set_hash_seed); + hash = c10::hash_combine(hash, after_set_hash_seed); + for (auto &e: aliasInfo.containedTypes()) { + auto contained_type_hash = std::hash()(e); + hash = c10::hash_combine(hash, contained_type_hash); + } + return hash; + } + }; +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/aten_interned_strings.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/aten_interned_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..62035bd5ab351a437a275d9f8f7bd914231be896 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/aten_interned_strings.h @@ -0,0 +1,2294 @@ +#pragma once + +// @generated by torchgen/gen.py from aten_interned_strings.h + +#if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if including for \ + the c10::Symbol class would be sufficient, or if your change would be \ + better placed in another file. +#endif + +// ATen symbols correspond exactly to operators defined in ATen. Every +// symbol here corresponds exactly to an ATen operation defined in +// native_functions.yaml; attributes are in one-to-one correspondence +// with their ATen name. 
+ +#define FORALL_ATEN_BASE_SYMBOLS(_) \ +_(aten, __and__) \ +_(aten, __iand__) \ +_(aten, __ilshift__) \ +_(aten, __ior__) \ +_(aten, __irshift__) \ +_(aten, __ixor__) \ +_(aten, __lshift__) \ +_(aten, __or__) \ +_(aten, __rshift__) \ +_(aten, __xor__) \ +_(aten, _adaptive_avg_pool2d) \ +_(aten, _adaptive_avg_pool2d_backward) \ +_(aten, _adaptive_avg_pool3d) \ +_(aten, _adaptive_avg_pool3d_backward) \ +_(aten, _add_batch_dim) \ +_(aten, _add_relu) \ +_(aten, _add_relu_) \ +_(aten, _addmm_activation) \ +_(aten, _aminmax) \ +_(aten, _amp_foreach_non_finite_check_and_unscale) \ +_(aten, _amp_foreach_non_finite_check_and_unscale_) \ +_(aten, _amp_update_scale) \ +_(aten, _amp_update_scale_) \ +_(aten, _assert_async) \ +_(aten, _assert_scalar) \ +_(aten, _assert_tensor_metadata) \ +_(aten, _autocast_to_full_precision) \ +_(aten, _autocast_to_reduced_precision) \ +_(aten, _backward) \ +_(aten, _batch_norm_impl_index) \ +_(aten, _batch_norm_impl_index_backward) \ +_(aten, _batch_norm_no_update) \ +_(aten, _batch_norm_with_update) \ +_(aten, _batch_norm_with_update_functional) \ +_(aten, _cast_Byte) \ +_(aten, _cast_Char) \ +_(aten, _cast_Double) \ +_(aten, _cast_Float) \ +_(aten, _cast_Half) \ +_(aten, _cast_Int) \ +_(aten, _cast_Long) \ +_(aten, _cast_Short) \ +_(aten, _cdist_backward) \ +_(aten, _cdist_forward) \ +_(aten, _cholesky_solve_helper) \ +_(aten, _choose_qparams_per_tensor) \ +_(aten, _chunk_cat) \ +_(aten, _coalesce) \ +_(aten, _coalesced) \ +_(aten, _coalesced_) \ +_(aten, _compute_linear_combination) \ +_(aten, _conj) \ +_(aten, _conj_copy) \ +_(aten, _conj_physical) \ +_(aten, _conv_depthwise2d) \ +_(aten, _convert_indices_from_coo_to_csr) \ +_(aten, _convert_indices_from_csr_to_coo) \ +_(aten, _convert_weight_to_int4pack) \ +_(aten, _convert_weight_to_int4pack_for_cpu) \ +_(aten, _convolution) \ +_(aten, _convolution_double_backward) \ +_(aten, _convolution_mode) \ +_(aten, _copy_from) \ +_(aten, _copy_from_and_resize) \ +_(aten, _cslt_compress) \ 
+_(aten, _cslt_sparse_mm) \ +_(aten, _cslt_sparse_mm_search) \ +_(aten, _ctc_loss) \ +_(aten, _ctc_loss_backward) \ +_(aten, _cudnn_attention_forward) \ +_(aten, _cudnn_ctc_loss) \ +_(aten, _cudnn_init_dropout_state) \ +_(aten, _cudnn_rnn) \ +_(aten, _cudnn_rnn_backward) \ +_(aten, _cudnn_rnn_flatten_weight) \ +_(aten, _cufft_clear_plan_cache) \ +_(aten, _cufft_get_plan_cache_max_size) \ +_(aten, _cufft_get_plan_cache_size) \ +_(aten, _cufft_set_plan_cache_max_size) \ +_(aten, _cummax_helper) \ +_(aten, _cummin_helper) \ +_(aten, _debug_has_internal_overlap) \ +_(aten, _dimI) \ +_(aten, _dimV) \ +_(aten, _dim_arange) \ +_(aten, _dirichlet_grad) \ +_(aten, _dyn_quant_matmul_4bit) \ +_(aten, _dyn_quant_pack_4bit_weight) \ +_(aten, _efficient_attention_backward) \ +_(aten, _efficient_attention_forward) \ +_(aten, _efficientzerotensor) \ +_(aten, _embedding_bag) \ +_(aten, _embedding_bag_backward) \ +_(aten, _embedding_bag_dense_backward) \ +_(aten, _embedding_bag_forward_only) \ +_(aten, _embedding_bag_per_sample_weights_backward) \ +_(aten, _embedding_bag_sparse_backward) \ +_(aten, _empty_affine_quantized) \ +_(aten, _empty_per_channel_affine_quantized) \ +_(aten, _euclidean_dist) \ +_(aten, _fake_quantize_learnable_per_channel_affine) \ +_(aten, _fake_quantize_learnable_per_channel_affine_backward) \ +_(aten, _fake_quantize_learnable_per_tensor_affine) \ +_(aten, _fake_quantize_learnable_per_tensor_affine_backward) \ +_(aten, _fake_quantize_per_tensor_affine_cachemask_tensor_qparams) \ +_(aten, _fft_c2c) \ +_(aten, _fft_c2r) \ +_(aten, _fft_r2c) \ +_(aten, _fill_mem_eff_dropout_mask) \ +_(aten, _fill_mem_eff_dropout_mask_) \ +_(aten, _flash_attention_backward) \ +_(aten, _flash_attention_forward) \ +_(aten, _foobar) \ +_(aten, _foreach_abs) \ +_(aten, _foreach_abs_) \ +_(aten, _foreach_acos) \ +_(aten, _foreach_acos_) \ +_(aten, _foreach_add) \ +_(aten, _foreach_add_) \ +_(aten, _foreach_addcdiv) \ +_(aten, _foreach_addcdiv_) \ +_(aten, _foreach_addcmul) \ +_(aten, 
_foreach_addcmul_) \ +_(aten, _foreach_asin) \ +_(aten, _foreach_asin_) \ +_(aten, _foreach_atan) \ +_(aten, _foreach_atan_) \ +_(aten, _foreach_ceil) \ +_(aten, _foreach_ceil_) \ +_(aten, _foreach_clamp_max) \ +_(aten, _foreach_clamp_max_) \ +_(aten, _foreach_clamp_min) \ +_(aten, _foreach_clamp_min_) \ +_(aten, _foreach_copy) \ +_(aten, _foreach_copy_) \ +_(aten, _foreach_cos) \ +_(aten, _foreach_cos_) \ +_(aten, _foreach_cosh) \ +_(aten, _foreach_cosh_) \ +_(aten, _foreach_div) \ +_(aten, _foreach_div_) \ +_(aten, _foreach_erf) \ +_(aten, _foreach_erf_) \ +_(aten, _foreach_erfc) \ +_(aten, _foreach_erfc_) \ +_(aten, _foreach_exp) \ +_(aten, _foreach_exp_) \ +_(aten, _foreach_expm1) \ +_(aten, _foreach_expm1_) \ +_(aten, _foreach_floor) \ +_(aten, _foreach_floor_) \ +_(aten, _foreach_frac) \ +_(aten, _foreach_frac_) \ +_(aten, _foreach_lerp) \ +_(aten, _foreach_lerp_) \ +_(aten, _foreach_lgamma) \ +_(aten, _foreach_lgamma_) \ +_(aten, _foreach_log) \ +_(aten, _foreach_log10) \ +_(aten, _foreach_log10_) \ +_(aten, _foreach_log1p) \ +_(aten, _foreach_log1p_) \ +_(aten, _foreach_log2) \ +_(aten, _foreach_log2_) \ +_(aten, _foreach_log_) \ +_(aten, _foreach_max) \ +_(aten, _foreach_maximum) \ +_(aten, _foreach_maximum_) \ +_(aten, _foreach_minimum) \ +_(aten, _foreach_minimum_) \ +_(aten, _foreach_mul) \ +_(aten, _foreach_mul_) \ +_(aten, _foreach_neg) \ +_(aten, _foreach_neg_) \ +_(aten, _foreach_norm) \ +_(aten, _foreach_pow) \ +_(aten, _foreach_pow_) \ +_(aten, _foreach_reciprocal) \ +_(aten, _foreach_reciprocal_) \ +_(aten, _foreach_round) \ +_(aten, _foreach_round_) \ +_(aten, _foreach_rsqrt) \ +_(aten, _foreach_rsqrt_) \ +_(aten, _foreach_sigmoid) \ +_(aten, _foreach_sigmoid_) \ +_(aten, _foreach_sign) \ +_(aten, _foreach_sign_) \ +_(aten, _foreach_sin) \ +_(aten, _foreach_sin_) \ +_(aten, _foreach_sinh) \ +_(aten, _foreach_sinh_) \ +_(aten, _foreach_sqrt) \ +_(aten, _foreach_sqrt_) \ +_(aten, _foreach_sub) \ +_(aten, _foreach_sub_) \ +_(aten, _foreach_tan) \ 
+_(aten, _foreach_tan_) \ +_(aten, _foreach_tanh) \ +_(aten, _foreach_tanh_) \ +_(aten, _foreach_trunc) \ +_(aten, _foreach_trunc_) \ +_(aten, _foreach_zero) \ +_(aten, _foreach_zero_) \ +_(aten, _functional_assert_async) \ +_(aten, _functional_assert_scalar) \ +_(aten, _functional_sym_constrain_range) \ +_(aten, _functional_sym_constrain_range_for_size) \ +_(aten, _fused_adagrad) \ +_(aten, _fused_adagrad_) \ +_(aten, _fused_adam) \ +_(aten, _fused_adam_) \ +_(aten, _fused_adamw) \ +_(aten, _fused_adamw_) \ +_(aten, _fused_dropout) \ +_(aten, _fused_moving_avg_obs_fq_helper) \ +_(aten, _fused_moving_avg_obs_fq_helper_functional) \ +_(aten, _fused_rms_norm) \ +_(aten, _fused_sdp_choice) \ +_(aten, _fused_sgd) \ +_(aten, _fused_sgd_) \ +_(aten, _fw_primal) \ +_(aten, _fw_primal_copy) \ +_(aten, _gather_sparse_backward) \ +_(aten, _grid_sampler_2d_cpu_fallback) \ +_(aten, _grid_sampler_2d_cpu_fallback_backward) \ +_(aten, _grouped_mm) \ +_(aten, _has_compatible_shallow_copy_type) \ +_(aten, _has_same_storage_numel) \ +_(aten, _histogramdd_bin_edges) \ +_(aten, _histogramdd_from_bin_cts) \ +_(aten, _histogramdd_from_bin_tensors) \ +_(aten, _index_put_impl) \ +_(aten, _index_put_impl_) \ +_(aten, _indices) \ +_(aten, _indices_copy) \ +_(aten, _int_mm) \ +_(aten, _is_all_true) \ +_(aten, _is_any_true) \ +_(aten, _is_zerotensor) \ +_(aten, _jagged_to_padded_dense_forward) \ +_(aten, _lazy_clone) \ +_(aten, _linalg_check_errors) \ +_(aten, _linalg_det) \ +_(aten, _linalg_eigh) \ +_(aten, _linalg_eigvals) \ +_(aten, _linalg_slogdet) \ +_(aten, _linalg_solve_ex) \ +_(aten, _linalg_svd) \ +_(aten, _local_scalar_dense) \ +_(aten, _log_softmax) \ +_(aten, _log_softmax_backward_data) \ +_(aten, _logcumsumexp) \ +_(aten, _lstm_mps) \ +_(aten, _lu_with_info) \ +_(aten, _make_dep_token) \ +_(aten, _make_dual) \ +_(aten, _make_dual_copy) \ +_(aten, _make_per_channel_quantized_tensor) \ +_(aten, _make_per_tensor_quantized_tensor) \ +_(aten, _masked_scale) \ +_(aten, _masked_softmax) 
\ +_(aten, _masked_softmax_backward) \ +_(aten, _mixed_dtypes_linear) \ +_(aten, _mkldnn_reshape) \ +_(aten, _mkldnn_transpose) \ +_(aten, _mkldnn_transpose_) \ +_(aten, _mps_convolution) \ +_(aten, _mps_convolution_transpose) \ +_(aten, _native_batch_norm_legit) \ +_(aten, _native_batch_norm_legit_functional) \ +_(aten, _native_batch_norm_legit_no_training) \ +_(aten, _native_multi_head_attention) \ +_(aten, _neg_view) \ +_(aten, _neg_view_copy) \ +_(aten, _nested_compute_contiguous_strides_offsets) \ +_(aten, _nested_from_padded) \ +_(aten, _nested_from_padded_and_nested_example) \ +_(aten, _nested_from_padded_tensor) \ +_(aten, _nested_get_jagged_dummy) \ +_(aten, _nested_get_lengths) \ +_(aten, _nested_get_max_seqlen) \ +_(aten, _nested_get_min_seqlen) \ +_(aten, _nested_get_offsets) \ +_(aten, _nested_get_ragged_idx) \ +_(aten, _nested_get_values) \ +_(aten, _nested_get_values_copy) \ +_(aten, _nested_select_backward) \ +_(aten, _nested_sum_backward) \ +_(aten, _nested_tensor_from_mask) \ +_(aten, _nested_tensor_from_mask_left_aligned) \ +_(aten, _nested_tensor_from_tensor_list) \ +_(aten, _nested_tensor_size) \ +_(aten, _nested_tensor_softmax_with_shape) \ +_(aten, _nested_tensor_storage_offsets) \ +_(aten, _nested_tensor_strides) \ +_(aten, _nested_view_from_buffer) \ +_(aten, _nested_view_from_buffer_copy) \ +_(aten, _nested_view_from_jagged) \ +_(aten, _nested_view_from_jagged_copy) \ +_(aten, _new_zeros_with_same_feature_meta) \ +_(aten, _nnpack_available) \ +_(aten, _nnpack_spatial_convolution) \ +_(aten, _nnz) \ +_(aten, _pack_padded_sequence) \ +_(aten, _pack_padded_sequence_backward) \ +_(aten, _pad_circular) \ +_(aten, _pad_enum) \ +_(aten, _pad_packed_sequence) \ +_(aten, _padded_dense_to_jagged_forward) \ +_(aten, _pdist_backward) \ +_(aten, _pdist_forward) \ +_(aten, _pin_memory) \ +_(aten, _prelu_kernel) \ +_(aten, _prelu_kernel_backward) \ +_(aten, _print) \ +_(aten, _propagate_xla_data) \ +_(aten, _remove_batch_dim) \ +_(aten, _reshape_alias) \ 
+_(aten, _reshape_alias_copy) \ +_(aten, _reshape_copy) \ +_(aten, _reshape_from_tensor) \ +_(aten, _resize_output) \ +_(aten, _resize_output_) \ +_(aten, _rowwise_prune) \ +_(aten, _safe_softmax) \ +_(aten, _sample_dirichlet) \ +_(aten, _saturate_weight_to_fp16) \ +_(aten, _scaled_dot_product_attention_math) \ +_(aten, _scaled_dot_product_attention_math_for_mps) \ +_(aten, _scaled_dot_product_cudnn_attention) \ +_(aten, _scaled_dot_product_cudnn_attention_backward) \ +_(aten, _scaled_dot_product_efficient_attention) \ +_(aten, _scaled_dot_product_efficient_attention_backward) \ +_(aten, _scaled_dot_product_flash_attention) \ +_(aten, _scaled_dot_product_flash_attention_backward) \ +_(aten, _scaled_dot_product_flash_attention_for_cpu) \ +_(aten, _scaled_dot_product_flash_attention_for_cpu_backward) \ +_(aten, _scaled_dot_product_fused_attention_overrideable) \ +_(aten, _scaled_dot_product_fused_attention_overrideable_backward) \ +_(aten, _scaled_grouped_mm) \ +_(aten, _scaled_mm) \ +_(aten, _segment_reduce_backward) \ +_(aten, _shape_as_tensor) \ +_(aten, _slow_conv2d_backward) \ +_(aten, _slow_conv2d_forward) \ +_(aten, _sobol_engine_draw) \ +_(aten, _sobol_engine_ff) \ +_(aten, _sobol_engine_ff_) \ +_(aten, _sobol_engine_initialize_state) \ +_(aten, _sobol_engine_initialize_state_) \ +_(aten, _sobol_engine_scramble) \ +_(aten, _sobol_engine_scramble_) \ +_(aten, _softmax) \ +_(aten, _softmax_backward_data) \ +_(aten, _sparse_addmm) \ +_(aten, _sparse_broadcast_to) \ +_(aten, _sparse_broadcast_to_copy) \ +_(aten, _sparse_bsc_tensor_unsafe) \ +_(aten, _sparse_bsr_tensor_unsafe) \ +_(aten, _sparse_compressed_tensor_unsafe) \ +_(aten, _sparse_compressed_tensor_with_dims) \ +_(aten, _sparse_coo_tensor_unsafe) \ +_(aten, _sparse_coo_tensor_with_dims) \ +_(aten, _sparse_coo_tensor_with_dims_and_tensors) \ +_(aten, _sparse_csc_tensor_unsafe) \ +_(aten, _sparse_csr_prod) \ +_(aten, _sparse_csr_sum) \ +_(aten, _sparse_csr_tensor_unsafe) \ +_(aten, _sparse_log_softmax) \ 
+_(aten, _sparse_log_softmax_backward_data) \ +_(aten, _sparse_mask_projection) \ +_(aten, _sparse_mm) \ +_(aten, _sparse_mm_reduce_impl) \ +_(aten, _sparse_mm_reduce_impl_backward) \ +_(aten, _sparse_semi_structured_addmm) \ +_(aten, _sparse_semi_structured_apply) \ +_(aten, _sparse_semi_structured_apply_dense) \ +_(aten, _sparse_semi_structured_linear) \ +_(aten, _sparse_semi_structured_mm) \ +_(aten, _sparse_semi_structured_tile) \ +_(aten, _sparse_softmax) \ +_(aten, _sparse_softmax_backward_data) \ +_(aten, _sparse_sparse_matmul) \ +_(aten, _sparse_sum) \ +_(aten, _sparse_sum_backward) \ +_(aten, _spdiags) \ +_(aten, _spsolve) \ +_(aten, _stack) \ +_(aten, _standard_gamma) \ +_(aten, _standard_gamma_grad) \ +_(aten, _test_ambiguous_defaults) \ +_(aten, _test_autograd_multiple_dispatch) \ +_(aten, _test_autograd_multiple_dispatch_view) \ +_(aten, _test_autograd_multiple_dispatch_view_copy) \ +_(aten, _test_check_tensor) \ +_(aten, _test_functorch_fallback) \ +_(aten, _test_optional_filled_intlist) \ +_(aten, _test_optional_floatlist) \ +_(aten, _test_optional_intlist) \ +_(aten, _test_parallel_materialize) \ +_(aten, _test_serialization_subcmul) \ +_(aten, _test_string_default) \ +_(aten, _test_warn_in_autograd) \ +_(aten, _thnn_differentiable_gru_cell_backward) \ +_(aten, _thnn_differentiable_lstm_cell_backward) \ +_(aten, _thnn_fused_gru_cell) \ +_(aten, _thnn_fused_gru_cell_backward) \ +_(aten, _thnn_fused_lstm_cell) \ +_(aten, _thnn_fused_lstm_cell_backward) \ +_(aten, _thnn_fused_lstm_cell_backward_impl) \ +_(aten, _to_copy) \ +_(aten, _to_cpu) \ +_(aten, _to_dense) \ +_(aten, _to_sparse) \ +_(aten, _to_sparse_bsc) \ +_(aten, _to_sparse_bsr) \ +_(aten, _to_sparse_csc) \ +_(aten, _to_sparse_csr) \ +_(aten, _to_sparse_semi_structured) \ +_(aten, _transform_bias_rescale_qkv) \ +_(aten, _transformer_encoder_layer_fwd) \ +_(aten, _trilinear) \ +_(aten, _triton_multi_head_attention) \ +_(aten, _triton_scaled_dot_attention) \ +_(aten, _unique) \ +_(aten, 
_unique2) \ +_(aten, _unpack_dual) \ +_(aten, _unsafe_index) \ +_(aten, _unsafe_index_put) \ +_(aten, _unsafe_masked_index) \ +_(aten, _unsafe_masked_index_put_accumulate) \ +_(aten, _unsafe_view) \ +_(aten, _upsample_bicubic2d_aa) \ +_(aten, _upsample_bicubic2d_aa_backward) \ +_(aten, _upsample_bilinear2d_aa) \ +_(aten, _upsample_bilinear2d_aa_backward) \ +_(aten, _upsample_nearest_exact1d) \ +_(aten, _upsample_nearest_exact1d_backward) \ +_(aten, _upsample_nearest_exact2d) \ +_(aten, _upsample_nearest_exact2d_backward) \ +_(aten, _upsample_nearest_exact3d) \ +_(aten, _upsample_nearest_exact3d_backward) \ +_(aten, _use_cudnn_ctc_loss) \ +_(aten, _use_cudnn_rnn_flatten_weight) \ +_(aten, _validate_compressed_sparse_indices) \ +_(aten, _validate_sparse_bsc_tensor_args) \ +_(aten, _validate_sparse_bsr_tensor_args) \ +_(aten, _validate_sparse_compressed_tensor_args) \ +_(aten, _validate_sparse_coo_tensor_args) \ +_(aten, _validate_sparse_csc_tensor_args) \ +_(aten, _validate_sparse_csr_tensor_args) \ +_(aten, _values) \ +_(aten, _values_copy) \ +_(aten, _version) \ +_(aten, _weight_int4pack_mm) \ +_(aten, _weight_int4pack_mm_for_cpu) \ +_(aten, _weight_int4pack_mm_with_scales_and_zeros) \ +_(aten, _weight_int8pack_mm) \ +_(aten, _weight_norm) \ +_(aten, _weight_norm_differentiable_backward) \ +_(aten, _weight_norm_interface) \ +_(aten, _weight_norm_interface_backward) \ +_(aten, _wrapped_linear_prepack) \ +_(aten, _wrapped_quantized_linear_prepacked) \ +_(aten, abs) \ +_(aten, abs_) \ +_(aten, absolute) \ +_(aten, absolute_) \ +_(aten, acos) \ +_(aten, acos_) \ +_(aten, acosh) \ +_(aten, acosh_) \ +_(aten, adaptive_avg_pool1d) \ +_(aten, adaptive_avg_pool2d) \ +_(aten, adaptive_avg_pool3d) \ +_(aten, adaptive_avg_pool3d_backward) \ +_(aten, adaptive_max_pool1d) \ +_(aten, adaptive_max_pool2d) \ +_(aten, adaptive_max_pool2d_backward) \ +_(aten, adaptive_max_pool3d) \ +_(aten, adaptive_max_pool3d_backward) \ +_(aten, add) \ +_(aten, add_) \ +_(aten, addbmm) \ +_(aten, 
addbmm_) \ +_(aten, addcdiv) \ +_(aten, addcdiv_) \ +_(aten, addcmul) \ +_(aten, addcmul_) \ +_(aten, addmm) \ +_(aten, addmm_) \ +_(aten, addmv) \ +_(aten, addmv_) \ +_(aten, addr) \ +_(aten, addr_) \ +_(aten, adjoint) \ +_(aten, affine_grid_generator) \ +_(aten, affine_grid_generator_backward) \ +_(aten, alias) \ +_(aten, alias_copy) \ +_(aten, align_as) \ +_(aten, align_tensors) \ +_(aten, align_to) \ +_(aten, all) \ +_(aten, allclose) \ +_(aten, alpha_dropout) \ +_(aten, alpha_dropout_) \ +_(aten, amax) \ +_(aten, amin) \ +_(aten, aminmax) \ +_(aten, angle) \ +_(aten, any) \ +_(aten, arange) \ +_(aten, arccos) \ +_(aten, arccos_) \ +_(aten, arccosh) \ +_(aten, arccosh_) \ +_(aten, arcsin) \ +_(aten, arcsin_) \ +_(aten, arcsinh) \ +_(aten, arcsinh_) \ +_(aten, arctan) \ +_(aten, arctan2) \ +_(aten, arctan2_) \ +_(aten, arctan_) \ +_(aten, arctanh) \ +_(aten, arctanh_) \ +_(aten, argmax) \ +_(aten, argmin) \ +_(aten, argsort) \ +_(aten, argwhere) \ +_(aten, as_strided) \ +_(aten, as_strided_) \ +_(aten, as_strided_copy) \ +_(aten, as_strided_scatter) \ +_(aten, asin) \ +_(aten, asin_) \ +_(aten, asinh) \ +_(aten, asinh_) \ +_(aten, atan) \ +_(aten, atan2) \ +_(aten, atan2_) \ +_(aten, atan_) \ +_(aten, atanh) \ +_(aten, atanh_) \ +_(aten, atleast_1d) \ +_(aten, atleast_2d) \ +_(aten, atleast_3d) \ +_(aten, avg_pool1d) \ +_(aten, avg_pool2d) \ +_(aten, avg_pool2d_backward) \ +_(aten, avg_pool3d) \ +_(aten, avg_pool3d_backward) \ +_(aten, baddbmm) \ +_(aten, baddbmm_) \ +_(aten, bartlett_window) \ +_(aten, batch_norm) \ +_(aten, batch_norm_backward) \ +_(aten, batch_norm_backward_elemt) \ +_(aten, batch_norm_backward_reduce) \ +_(aten, batch_norm_elemt) \ +_(aten, batch_norm_gather_stats) \ +_(aten, batch_norm_gather_stats_with_counts) \ +_(aten, batch_norm_stats) \ +_(aten, batch_norm_update_stats) \ +_(aten, bernoulli) \ +_(aten, bernoulli_) \ +_(aten, bilinear) \ +_(aten, binary_cross_entropy) \ +_(aten, binary_cross_entropy_backward) \ +_(aten, 
binary_cross_entropy_with_logits) \ +_(aten, bincount) \ +_(aten, binomial) \ +_(aten, bitwise_and) \ +_(aten, bitwise_and_) \ +_(aten, bitwise_left_shift) \ +_(aten, bitwise_left_shift_) \ +_(aten, bitwise_not) \ +_(aten, bitwise_not_) \ +_(aten, bitwise_or) \ +_(aten, bitwise_or_) \ +_(aten, bitwise_right_shift) \ +_(aten, bitwise_right_shift_) \ +_(aten, bitwise_xor) \ +_(aten, bitwise_xor_) \ +_(aten, blackman_window) \ +_(aten, block_diag) \ +_(aten, bmm) \ +_(aten, broadcast_tensors) \ +_(aten, broadcast_to) \ +_(aten, bucketize) \ +_(aten, can_cast) \ +_(aten, cartesian_prod) \ +_(aten, cat) \ +_(aten, cauchy) \ +_(aten, cauchy_) \ +_(aten, ccol_indices) \ +_(aten, ccol_indices_copy) \ +_(aten, cdist) \ +_(aten, ceil) \ +_(aten, ceil_) \ +_(aten, celu) \ +_(aten, celu_) \ +_(aten, chain_matmul) \ +_(aten, chalf) \ +_(aten, channel_shuffle) \ +_(aten, cholesky) \ +_(aten, cholesky_inverse) \ +_(aten, cholesky_solve) \ +_(aten, choose_qparams_optimized) \ +_(aten, chunk) \ +_(aten, clamp) \ +_(aten, clamp_) \ +_(aten, clamp_max) \ +_(aten, clamp_max_) \ +_(aten, clamp_min) \ +_(aten, clamp_min_) \ +_(aten, clip) \ +_(aten, clip_) \ +_(aten, clone) \ +_(aten, coalesce) \ +_(aten, col2im) \ +_(aten, col_indices) \ +_(aten, col_indices_copy) \ +_(aten, column_stack) \ +_(aten, combinations) \ +_(aten, complex) \ +_(aten, concat) \ +_(aten, concatenate) \ +_(aten, conj) \ +_(aten, conj_physical) \ +_(aten, conj_physical_) \ +_(aten, constant_pad_nd) \ +_(aten, contiguous) \ +_(aten, conv1d) \ +_(aten, conv2d) \ +_(aten, conv3d) \ +_(aten, conv_depthwise3d) \ +_(aten, conv_tbc) \ +_(aten, conv_tbc_backward) \ +_(aten, conv_transpose1d) \ +_(aten, conv_transpose2d) \ +_(aten, conv_transpose3d) \ +_(aten, convolution) \ +_(aten, convolution_backward) \ +_(aten, convolution_backward_overrideable) \ +_(aten, convolution_overrideable) \ +_(aten, copy) \ +_(aten, copy_) \ +_(aten, copy_sparse_to_sparse) \ +_(aten, copy_sparse_to_sparse_) \ +_(aten, copysign) \ +_(aten, 
copysign_) \ +_(aten, corrcoef) \ +_(aten, cos) \ +_(aten, cos_) \ +_(aten, cosh) \ +_(aten, cosh_) \ +_(aten, cosine_embedding_loss) \ +_(aten, cosine_similarity) \ +_(aten, count_nonzero) \ +_(aten, cov) \ +_(aten, cross) \ +_(aten, cross_entropy_loss) \ +_(aten, crow_indices) \ +_(aten, crow_indices_copy) \ +_(aten, ctc_loss) \ +_(aten, cudnn_affine_grid_generator) \ +_(aten, cudnn_affine_grid_generator_backward) \ +_(aten, cudnn_batch_norm) \ +_(aten, cudnn_batch_norm_backward) \ +_(aten, cudnn_convolution) \ +_(aten, cudnn_convolution_add_relu) \ +_(aten, cudnn_convolution_relu) \ +_(aten, cudnn_convolution_transpose) \ +_(aten, cudnn_grid_sampler) \ +_(aten, cudnn_grid_sampler_backward) \ +_(aten, cudnn_is_acceptable) \ +_(aten, cummax) \ +_(aten, cummaxmin_backward) \ +_(aten, cummin) \ +_(aten, cumprod) \ +_(aten, cumprod_) \ +_(aten, cumprod_backward) \ +_(aten, cumsum) \ +_(aten, cumsum_) \ +_(aten, cumulative_trapezoid) \ +_(aten, data) \ +_(aten, deg2rad) \ +_(aten, deg2rad_) \ +_(aten, dense_dim) \ +_(aten, dequantize) \ +_(aten, det) \ +_(aten, detach) \ +_(aten, detach_) \ +_(aten, detach_copy) \ +_(aten, diag) \ +_(aten, diag_embed) \ +_(aten, diagflat) \ +_(aten, diagonal) \ +_(aten, diagonal_backward) \ +_(aten, diagonal_copy) \ +_(aten, diagonal_scatter) \ +_(aten, diff) \ +_(aten, digamma) \ +_(aten, digamma_) \ +_(aten, dist) \ +_(aten, div) \ +_(aten, div_) \ +_(aten, divide) \ +_(aten, divide_) \ +_(aten, dot) \ +_(aten, dropout) \ +_(aten, dropout_) \ +_(aten, dsplit) \ +_(aten, dstack) \ +_(aten, einsum) \ +_(aten, elu) \ +_(aten, elu_) \ +_(aten, elu_backward) \ +_(aten, embedding) \ +_(aten, embedding_backward) \ +_(aten, embedding_bag) \ +_(aten, embedding_dense_backward) \ +_(aten, embedding_renorm) \ +_(aten, embedding_renorm_) \ +_(aten, embedding_sparse_backward) \ +_(aten, empty) \ +_(aten, empty_like) \ +_(aten, empty_permuted) \ +_(aten, empty_quantized) \ +_(aten, empty_strided) \ +_(aten, eq) \ +_(aten, eq_) \ +_(aten, equal) \ 
+_(aten, erf) \ +_(aten, erf_) \ +_(aten, erfc) \ +_(aten, erfc_) \ +_(aten, erfinv) \ +_(aten, erfinv_) \ +_(aten, exp) \ +_(aten, exp2) \ +_(aten, exp2_) \ +_(aten, exp_) \ +_(aten, expand) \ +_(aten, expand_as) \ +_(aten, expand_copy) \ +_(aten, expm1) \ +_(aten, expm1_) \ +_(aten, exponential) \ +_(aten, exponential_) \ +_(aten, eye) \ +_(aten, fake_quantize_per_channel_affine) \ +_(aten, fake_quantize_per_channel_affine_cachemask) \ +_(aten, fake_quantize_per_channel_affine_cachemask_backward) \ +_(aten, fake_quantize_per_tensor_affine) \ +_(aten, fake_quantize_per_tensor_affine_cachemask) \ +_(aten, fake_quantize_per_tensor_affine_cachemask_backward) \ +_(aten, fbgemm_linear_fp16_weight) \ +_(aten, fbgemm_linear_fp16_weight_fp32_activation) \ +_(aten, fbgemm_linear_int8_weight) \ +_(aten, fbgemm_linear_int8_weight_fp32_activation) \ +_(aten, fbgemm_linear_quantize_weight) \ +_(aten, fbgemm_pack_gemm_matrix_fp16) \ +_(aten, fbgemm_pack_quantized_matrix) \ +_(aten, feature_alpha_dropout) \ +_(aten, feature_alpha_dropout_) \ +_(aten, feature_dropout) \ +_(aten, feature_dropout_) \ +_(aten, fft_fft) \ +_(aten, fft_fft2) \ +_(aten, fft_fftfreq) \ +_(aten, fft_fftn) \ +_(aten, fft_fftshift) \ +_(aten, fft_hfft) \ +_(aten, fft_hfft2) \ +_(aten, fft_hfftn) \ +_(aten, fft_ifft) \ +_(aten, fft_ifft2) \ +_(aten, fft_ifftn) \ +_(aten, fft_ifftshift) \ +_(aten, fft_ihfft) \ +_(aten, fft_ihfft2) \ +_(aten, fft_ihfftn) \ +_(aten, fft_irfft) \ +_(aten, fft_irfft2) \ +_(aten, fft_irfftn) \ +_(aten, fft_rfft) \ +_(aten, fft_rfft2) \ +_(aten, fft_rfftfreq) \ +_(aten, fft_rfftn) \ +_(aten, fill) \ +_(aten, fill_) \ +_(aten, fill_diagonal) \ +_(aten, fill_diagonal_) \ +_(aten, fix) \ +_(aten, fix_) \ +_(aten, flatten) \ +_(aten, flatten_dense_tensors) \ +_(aten, flip) \ +_(aten, fliplr) \ +_(aten, flipud) \ +_(aten, float_power) \ +_(aten, float_power_) \ +_(aten, floor) \ +_(aten, floor_) \ +_(aten, floor_divide) \ +_(aten, floor_divide_) \ +_(aten, fmax) \ +_(aten, fmin) \ 
+_(aten, fmod) \ +_(aten, fmod_) \ +_(aten, frac) \ +_(aten, frac_) \ +_(aten, fractional_max_pool2d) \ +_(aten, fractional_max_pool2d_backward) \ +_(aten, fractional_max_pool3d) \ +_(aten, fractional_max_pool3d_backward) \ +_(aten, frexp) \ +_(aten, frobenius_norm) \ +_(aten, from_file) \ +_(aten, full) \ +_(aten, full_like) \ +_(aten, fused_moving_avg_obs_fake_quant) \ +_(aten, gather) \ +_(aten, gather_backward) \ +_(aten, gcd) \ +_(aten, gcd_) \ +_(aten, ge) \ +_(aten, ge_) \ +_(aten, gelu) \ +_(aten, gelu_) \ +_(aten, gelu_backward) \ +_(aten, geometric) \ +_(aten, geometric_) \ +_(aten, geqrf) \ +_(aten, ger) \ +_(aten, glu) \ +_(aten, glu_backward) \ +_(aten, glu_backward_jvp) \ +_(aten, glu_jvp) \ +_(aten, gradient) \ +_(aten, greater) \ +_(aten, greater_) \ +_(aten, greater_equal) \ +_(aten, greater_equal_) \ +_(aten, grid_sampler) \ +_(aten, grid_sampler_2d) \ +_(aten, grid_sampler_2d_backward) \ +_(aten, grid_sampler_3d) \ +_(aten, grid_sampler_3d_backward) \ +_(aten, group_norm) \ +_(aten, gru) \ +_(aten, gru_cell) \ +_(aten, gt) \ +_(aten, gt_) \ +_(aten, hamming_window) \ +_(aten, hann_window) \ +_(aten, hardshrink) \ +_(aten, hardshrink_backward) \ +_(aten, hardsigmoid) \ +_(aten, hardsigmoid_) \ +_(aten, hardsigmoid_backward) \ +_(aten, hardswish) \ +_(aten, hardswish_) \ +_(aten, hardswish_backward) \ +_(aten, hardtanh) \ +_(aten, hardtanh_) \ +_(aten, hardtanh_backward) \ +_(aten, heaviside) \ +_(aten, heaviside_) \ +_(aten, hinge_embedding_loss) \ +_(aten, histc) \ +_(aten, histogram) \ +_(aten, histogramdd) \ +_(aten, hsplit) \ +_(aten, hspmm) \ +_(aten, hstack) \ +_(aten, huber_loss) \ +_(aten, huber_loss_backward) \ +_(aten, hypot) \ +_(aten, hypot_) \ +_(aten, i0) \ +_(aten, i0_) \ +_(aten, igamma) \ +_(aten, igamma_) \ +_(aten, igammac) \ +_(aten, igammac_) \ +_(aten, im2col) \ +_(aten, imag) \ +_(aten, index) \ +_(aten, index_add) \ +_(aten, index_add_) \ +_(aten, index_copy) \ +_(aten, index_copy_) \ +_(aten, index_fill) \ +_(aten, 
index_fill_) \ +_(aten, index_put) \ +_(aten, index_put_) \ +_(aten, index_reduce) \ +_(aten, index_reduce_) \ +_(aten, index_select) \ +_(aten, index_select_backward) \ +_(aten, indices) \ +_(aten, indices_copy) \ +_(aten, infinitely_differentiable_gelu_backward) \ +_(aten, inner) \ +_(aten, instance_norm) \ +_(aten, int_repr) \ +_(aten, inverse) \ +_(aten, is_coalesced) \ +_(aten, is_complex) \ +_(aten, is_conj) \ +_(aten, is_distributed) \ +_(aten, is_floating_point) \ +_(aten, is_inference) \ +_(aten, is_leaf) \ +_(aten, is_neg) \ +_(aten, is_nonzero) \ +_(aten, is_pinned) \ +_(aten, is_same_size) \ +_(aten, is_set_to) \ +_(aten, is_signed) \ +_(aten, is_vulkan_available) \ +_(aten, isclose) \ +_(aten, isfinite) \ +_(aten, isin) \ +_(aten, isinf) \ +_(aten, isnan) \ +_(aten, isneginf) \ +_(aten, isposinf) \ +_(aten, isreal) \ +_(aten, istft) \ +_(aten, item) \ +_(aten, kaiser_window) \ +_(aten, kl_div) \ +_(aten, kron) \ +_(aten, kthvalue) \ +_(aten, l1_loss) \ +_(aten, layer_norm) \ +_(aten, lcm) \ +_(aten, lcm_) \ +_(aten, ldexp) \ +_(aten, ldexp_) \ +_(aten, le) \ +_(aten, le_) \ +_(aten, leaky_relu) \ +_(aten, leaky_relu_) \ +_(aten, leaky_relu_backward) \ +_(aten, lerp) \ +_(aten, lerp_) \ +_(aten, less) \ +_(aten, less_) \ +_(aten, less_equal) \ +_(aten, less_equal_) \ +_(aten, lgamma) \ +_(aten, lgamma_) \ +_(aten, lift) \ +_(aten, lift_fresh) \ +_(aten, lift_fresh_copy) \ +_(aten, linalg_cholesky) \ +_(aten, linalg_cholesky_ex) \ +_(aten, linalg_cond) \ +_(aten, linalg_cross) \ +_(aten, linalg_det) \ +_(aten, linalg_diagonal) \ +_(aten, linalg_eig) \ +_(aten, linalg_eigh) \ +_(aten, linalg_eigvals) \ +_(aten, linalg_eigvalsh) \ +_(aten, linalg_householder_product) \ +_(aten, linalg_inv) \ +_(aten, linalg_inv_ex) \ +_(aten, linalg_ldl_factor) \ +_(aten, linalg_ldl_factor_ex) \ +_(aten, linalg_ldl_solve) \ +_(aten, linalg_lstsq) \ +_(aten, linalg_lu) \ +_(aten, linalg_lu_factor) \ +_(aten, linalg_lu_factor_ex) \ +_(aten, linalg_lu_solve) \ +_(aten, 
linalg_matmul) \ +_(aten, linalg_matrix_exp) \ +_(aten, linalg_matrix_norm) \ +_(aten, linalg_matrix_power) \ +_(aten, linalg_matrix_rank) \ +_(aten, linalg_multi_dot) \ +_(aten, linalg_norm) \ +_(aten, linalg_pinv) \ +_(aten, linalg_qr) \ +_(aten, linalg_slogdet) \ +_(aten, linalg_solve) \ +_(aten, linalg_solve_ex) \ +_(aten, linalg_solve_triangular) \ +_(aten, linalg_svd) \ +_(aten, linalg_svdvals) \ +_(aten, linalg_tensorinv) \ +_(aten, linalg_tensorsolve) \ +_(aten, linalg_vander) \ +_(aten, linalg_vecdot) \ +_(aten, linalg_vector_norm) \ +_(aten, linear) \ +_(aten, linear_backward) \ +_(aten, linspace) \ +_(aten, log) \ +_(aten, log10) \ +_(aten, log10_) \ +_(aten, log1p) \ +_(aten, log1p_) \ +_(aten, log2) \ +_(aten, log2_) \ +_(aten, log_) \ +_(aten, log_normal) \ +_(aten, log_normal_) \ +_(aten, log_sigmoid) \ +_(aten, log_sigmoid_backward) \ +_(aten, log_sigmoid_forward) \ +_(aten, log_softmax) \ +_(aten, logaddexp) \ +_(aten, logaddexp2) \ +_(aten, logcumsumexp) \ +_(aten, logdet) \ +_(aten, logical_and) \ +_(aten, logical_and_) \ +_(aten, logical_not) \ +_(aten, logical_not_) \ +_(aten, logical_or) \ +_(aten, logical_or_) \ +_(aten, logical_xor) \ +_(aten, logical_xor_) \ +_(aten, logit) \ +_(aten, logit_) \ +_(aten, logit_backward) \ +_(aten, logspace) \ +_(aten, logsumexp) \ +_(aten, lshift) \ +_(aten, lstm) \ +_(aten, lstm_cell) \ +_(aten, lstm_mps_backward) \ +_(aten, lt) \ +_(aten, lt_) \ +_(aten, lu_solve) \ +_(aten, lu_unpack) \ +_(aten, mH) \ +_(aten, mT) \ +_(aten, margin_ranking_loss) \ +_(aten, masked_fill) \ +_(aten, masked_fill_) \ +_(aten, masked_scatter) \ +_(aten, masked_scatter_) \ +_(aten, masked_scatter_backward) \ +_(aten, masked_select) \ +_(aten, masked_select_backward) \ +_(aten, matmul) \ +_(aten, matmul_backward) \ +_(aten, matrix_H) \ +_(aten, matrix_exp) \ +_(aten, matrix_exp_backward) \ +_(aten, matrix_power) \ +_(aten, max) \ +_(aten, max_pool1d) \ +_(aten, max_pool1d_with_indices) \ +_(aten, max_pool2d) \ +_(aten, 
max_pool2d_backward) \ +_(aten, max_pool2d_with_indices) \ +_(aten, max_pool2d_with_indices_backward) \ +_(aten, max_pool3d) \ +_(aten, max_pool3d_with_indices) \ +_(aten, max_pool3d_with_indices_backward) \ +_(aten, max_unpool2d) \ +_(aten, max_unpool3d) \ +_(aten, maximum) \ +_(aten, mean) \ +_(aten, median) \ +_(aten, meshgrid) \ +_(aten, min) \ +_(aten, minimum) \ +_(aten, miopen_batch_norm) \ +_(aten, miopen_batch_norm_backward) \ +_(aten, miopen_convolution) \ +_(aten, miopen_convolution_add_relu) \ +_(aten, miopen_convolution_relu) \ +_(aten, miopen_convolution_transpose) \ +_(aten, miopen_depthwise_convolution) \ +_(aten, miopen_rnn) \ +_(aten, miopen_rnn_backward) \ +_(aten, mish) \ +_(aten, mish_) \ +_(aten, mish_backward) \ +_(aten, mkldnn_adaptive_avg_pool2d) \ +_(aten, mkldnn_adaptive_avg_pool2d_backward) \ +_(aten, mkldnn_convolution) \ +_(aten, mkldnn_linear) \ +_(aten, mkldnn_linear_backward) \ +_(aten, mkldnn_linear_backward_input) \ +_(aten, mkldnn_linear_backward_weights) \ +_(aten, mkldnn_max_pool2d) \ +_(aten, mkldnn_max_pool2d_backward) \ +_(aten, mkldnn_max_pool3d) \ +_(aten, mkldnn_max_pool3d_backward) \ +_(aten, mkldnn_reorder_conv2d_weight) \ +_(aten, mkldnn_reorder_conv3d_weight) \ +_(aten, mkldnn_rnn_layer) \ +_(aten, mkldnn_rnn_layer_backward) \ +_(aten, mm) \ +_(aten, mode) \ +_(aten, moveaxis) \ +_(aten, movedim) \ +_(aten, mps_convolution_backward) \ +_(aten, mps_convolution_transpose_backward) \ +_(aten, mse_loss) \ +_(aten, mse_loss_backward) \ +_(aten, msort) \ +_(aten, mul) \ +_(aten, mul_) \ +_(aten, multi_margin_loss) \ +_(aten, multi_margin_loss_backward) \ +_(aten, multilabel_margin_loss) \ +_(aten, multilabel_margin_loss_backward) \ +_(aten, multilabel_margin_loss_forward) \ +_(aten, multinomial) \ +_(aten, multiply) \ +_(aten, multiply_) \ +_(aten, mv) \ +_(aten, mvlgamma) \ +_(aten, mvlgamma_) \ +_(aten, nan_to_num) \ +_(aten, nan_to_num_) \ +_(aten, nanmean) \ +_(aten, nanmedian) \ +_(aten, nanquantile) \ +_(aten, nansum) 
\ +_(aten, narrow) \ +_(aten, narrow_copy) \ +_(aten, native_batch_norm) \ +_(aten, native_batch_norm_backward) \ +_(aten, native_channel_shuffle) \ +_(aten, native_dropout) \ +_(aten, native_dropout_backward) \ +_(aten, native_group_norm) \ +_(aten, native_group_norm_backward) \ +_(aten, native_layer_norm) \ +_(aten, native_layer_norm_backward) \ +_(aten, native_norm) \ +_(aten, ne) \ +_(aten, ne_) \ +_(aten, neg) \ +_(aten, neg_) \ +_(aten, negative) \ +_(aten, negative_) \ +_(aten, nested_to_padded_tensor) \ +_(aten, new_empty) \ +_(aten, new_empty_strided) \ +_(aten, new_full) \ +_(aten, new_ones) \ +_(aten, new_zeros) \ +_(aten, nextafter) \ +_(aten, nextafter_) \ +_(aten, nll_loss) \ +_(aten, nll_loss2d) \ +_(aten, nll_loss2d_backward) \ +_(aten, nll_loss2d_forward) \ +_(aten, nll_loss_backward) \ +_(aten, nll_loss_forward) \ +_(aten, nll_loss_nd) \ +_(aten, nonzero) \ +_(aten, nonzero_numpy) \ +_(aten, nonzero_static) \ +_(aten, norm) \ +_(aten, norm_except_dim) \ +_(aten, normal) \ +_(aten, normal_) \ +_(aten, normal_functional) \ +_(aten, not_equal) \ +_(aten, not_equal_) \ +_(aten, nuclear_norm) \ +_(aten, numpy_T) \ +_(aten, one_hot) \ +_(aten, ones) \ +_(aten, ones_like) \ +_(aten, orgqr) \ +_(aten, ormqr) \ +_(aten, outer) \ +_(aten, output_nr) \ +_(aten, pad) \ +_(aten, pad_sequence) \ +_(aten, pairwise_distance) \ +_(aten, pdist) \ +_(aten, permute) \ +_(aten, permute_copy) \ +_(aten, pin_memory) \ +_(aten, pinverse) \ +_(aten, pixel_shuffle) \ +_(aten, pixel_unshuffle) \ +_(aten, poisson) \ +_(aten, poisson_nll_loss) \ +_(aten, polar) \ +_(aten, polygamma) \ +_(aten, polygamma_) \ +_(aten, positive) \ +_(aten, pow) \ +_(aten, pow_) \ +_(aten, prelu) \ +_(aten, prod) \ +_(aten, promote_types) \ +_(aten, put) \ +_(aten, put_) \ +_(aten, q_per_channel_axis) \ +_(aten, q_per_channel_scales) \ +_(aten, q_per_channel_zero_points) \ +_(aten, q_scale) \ +_(aten, q_zero_point) \ +_(aten, qr) \ +_(aten, qscheme) \ +_(aten, quantile) \ +_(aten, 
quantize_per_channel) \ +_(aten, quantize_per_tensor) \ +_(aten, quantize_per_tensor_dynamic) \ +_(aten, quantized_batch_norm) \ +_(aten, quantized_gru_cell) \ +_(aten, quantized_lstm_cell) \ +_(aten, quantized_max_pool1d) \ +_(aten, quantized_max_pool2d) \ +_(aten, quantized_max_pool3d) \ +_(aten, quantized_rnn_relu_cell) \ +_(aten, quantized_rnn_tanh_cell) \ +_(aten, rad2deg) \ +_(aten, rad2deg_) \ +_(aten, rand) \ +_(aten, rand_like) \ +_(aten, randint) \ +_(aten, randint_like) \ +_(aten, randn) \ +_(aten, randn_like) \ +_(aten, random) \ +_(aten, random_) \ +_(aten, randperm) \ +_(aten, range) \ +_(aten, ravel) \ +_(aten, real) \ +_(aten, reciprocal) \ +_(aten, reciprocal_) \ +_(aten, record_stream) \ +_(aten, refine_names) \ +_(aten, reflection_pad1d) \ +_(aten, reflection_pad1d_backward) \ +_(aten, reflection_pad2d) \ +_(aten, reflection_pad2d_backward) \ +_(aten, reflection_pad3d) \ +_(aten, reflection_pad3d_backward) \ +_(aten, relu) \ +_(aten, relu6) \ +_(aten, relu6_) \ +_(aten, relu_) \ +_(aten, remainder) \ +_(aten, remainder_) \ +_(aten, rename) \ +_(aten, rename_) \ +_(aten, renorm) \ +_(aten, renorm_) \ +_(aten, repeat) \ +_(aten, repeat_interleave) \ +_(aten, replication_pad1d) \ +_(aten, replication_pad1d_backward) \ +_(aten, replication_pad2d) \ +_(aten, replication_pad2d_backward) \ +_(aten, replication_pad3d) \ +_(aten, replication_pad3d_backward) \ +_(aten, requires_grad) \ +_(aten, requires_grad_) \ +_(aten, reshape) \ +_(aten, reshape_as) \ +_(aten, resize) \ +_(aten, resize_) \ +_(aten, resize_as) \ +_(aten, resize_as_) \ +_(aten, resize_as_sparse) \ +_(aten, resize_as_sparse_) \ +_(aten, resolve_conj) \ +_(aten, resolve_neg) \ +_(aten, result_type) \ +_(aten, retain_grad) \ +_(aten, retains_grad) \ +_(aten, rms_norm) \ +_(aten, rnn_relu) \ +_(aten, rnn_relu_cell) \ +_(aten, rnn_tanh) \ +_(aten, rnn_tanh_cell) \ +_(aten, roll) \ +_(aten, rot90) \ +_(aten, round) \ +_(aten, round_) \ +_(aten, row_indices) \ +_(aten, row_indices_copy) \ 
+_(aten, row_stack) \ +_(aten, rrelu) \ +_(aten, rrelu_) \ +_(aten, rrelu_with_noise) \ +_(aten, rrelu_with_noise_) \ +_(aten, rrelu_with_noise_backward) \ +_(aten, rrelu_with_noise_functional) \ +_(aten, rshift) \ +_(aten, rsqrt) \ +_(aten, rsqrt_) \ +_(aten, rsub) \ +_(aten, scalar_tensor) \ +_(aten, scaled_dot_product_attention) \ +_(aten, scatter) \ +_(aten, scatter_) \ +_(aten, scatter_add) \ +_(aten, scatter_add_) \ +_(aten, scatter_reduce) \ +_(aten, scatter_reduce_) \ +_(aten, searchsorted) \ +_(aten, segment_reduce) \ +_(aten, select) \ +_(aten, select_backward) \ +_(aten, select_copy) \ +_(aten, select_scatter) \ +_(aten, selu) \ +_(aten, selu_) \ +_(aten, set) \ +_(aten, set_) \ +_(aten, set_data) \ +_(aten, sgn) \ +_(aten, sgn_) \ +_(aten, sigmoid) \ +_(aten, sigmoid_) \ +_(aten, sigmoid_backward) \ +_(aten, sign) \ +_(aten, sign_) \ +_(aten, signbit) \ +_(aten, silu) \ +_(aten, silu_) \ +_(aten, silu_backward) \ +_(aten, sin) \ +_(aten, sin_) \ +_(aten, sinc) \ +_(aten, sinc_) \ +_(aten, sinh) \ +_(aten, sinh_) \ +_(aten, size) \ +_(aten, slice) \ +_(aten, slice_backward) \ +_(aten, slice_copy) \ +_(aten, slice_inverse) \ +_(aten, slice_scatter) \ +_(aten, slogdet) \ +_(aten, slow_conv3d) \ +_(aten, slow_conv3d_forward) \ +_(aten, slow_conv_dilated2d) \ +_(aten, slow_conv_dilated3d) \ +_(aten, slow_conv_transpose2d) \ +_(aten, slow_conv_transpose3d) \ +_(aten, smm) \ +_(aten, smooth_l1_loss) \ +_(aten, smooth_l1_loss_backward) \ +_(aten, soft_margin_loss) \ +_(aten, soft_margin_loss_backward) \ +_(aten, softmax) \ +_(aten, softplus) \ +_(aten, softplus_backward) \ +_(aten, softshrink) \ +_(aten, softshrink_backward) \ +_(aten, sort) \ +_(aten, sparse_bsc_tensor) \ +_(aten, sparse_bsr_tensor) \ +_(aten, sparse_compressed_tensor) \ +_(aten, sparse_coo_tensor) \ +_(aten, sparse_csc_tensor) \ +_(aten, sparse_csr_tensor) \ +_(aten, sparse_dim) \ +_(aten, sparse_mask) \ +_(aten, sparse_resize) \ +_(aten, sparse_resize_) \ +_(aten, sparse_resize_and_clear) \ 
+_(aten, sparse_resize_and_clear_) \ +_(aten, sparse_sampled_addmm) \ +_(aten, special_airy_ai) \ +_(aten, special_bessel_j0) \ +_(aten, special_bessel_j1) \ +_(aten, special_bessel_y0) \ +_(aten, special_bessel_y1) \ +_(aten, special_chebyshev_polynomial_t) \ +_(aten, special_chebyshev_polynomial_u) \ +_(aten, special_chebyshev_polynomial_v) \ +_(aten, special_chebyshev_polynomial_w) \ +_(aten, special_digamma) \ +_(aten, special_entr) \ +_(aten, special_erf) \ +_(aten, special_erfc) \ +_(aten, special_erfcx) \ +_(aten, special_erfinv) \ +_(aten, special_exp2) \ +_(aten, special_expit) \ +_(aten, special_expm1) \ +_(aten, special_gammainc) \ +_(aten, special_gammaincc) \ +_(aten, special_gammaln) \ +_(aten, special_hermite_polynomial_h) \ +_(aten, special_hermite_polynomial_he) \ +_(aten, special_i0) \ +_(aten, special_i0e) \ +_(aten, special_i1) \ +_(aten, special_i1e) \ +_(aten, special_laguerre_polynomial_l) \ +_(aten, special_legendre_polynomial_p) \ +_(aten, special_log1p) \ +_(aten, special_log_ndtr) \ +_(aten, special_log_softmax) \ +_(aten, special_logit) \ +_(aten, special_logsumexp) \ +_(aten, special_modified_bessel_i0) \ +_(aten, special_modified_bessel_i1) \ +_(aten, special_modified_bessel_k0) \ +_(aten, special_modified_bessel_k1) \ +_(aten, special_multigammaln) \ +_(aten, special_ndtr) \ +_(aten, special_ndtri) \ +_(aten, special_polygamma) \ +_(aten, special_psi) \ +_(aten, special_round) \ +_(aten, special_scaled_modified_bessel_k0) \ +_(aten, special_scaled_modified_bessel_k1) \ +_(aten, special_shifted_chebyshev_polynomial_t) \ +_(aten, special_shifted_chebyshev_polynomial_u) \ +_(aten, special_shifted_chebyshev_polynomial_v) \ +_(aten, special_shifted_chebyshev_polynomial_w) \ +_(aten, special_sinc) \ +_(aten, special_softmax) \ +_(aten, special_spherical_bessel_j0) \ +_(aten, special_xlog1py) \ +_(aten, special_xlogy) \ +_(aten, special_zeta) \ +_(aten, split) \ +_(aten, split_copy) \ +_(aten, split_with_sizes) \ +_(aten, 
split_with_sizes_copy) \ +_(aten, sqrt) \ +_(aten, sqrt_) \ +_(aten, square) \ +_(aten, square_) \ +_(aten, squeeze) \ +_(aten, squeeze_) \ +_(aten, squeeze_copy) \ +_(aten, sspaddmm) \ +_(aten, stack) \ +_(aten, std) \ +_(aten, std_mean) \ +_(aten, stft) \ +_(aten, stride) \ +_(aten, sub) \ +_(aten, sub_) \ +_(aten, subtract) \ +_(aten, subtract_) \ +_(aten, sum) \ +_(aten, sum_to_size) \ +_(aten, svd) \ +_(aten, swapaxes) \ +_(aten, swapaxes_) \ +_(aten, swapdims) \ +_(aten, swapdims_) \ +_(aten, sym_constrain_range) \ +_(aten, sym_constrain_range_for_size) \ +_(aten, sym_numel) \ +_(aten, sym_size) \ +_(aten, sym_storage_offset) \ +_(aten, sym_stride) \ +_(aten, t) \ +_(aten, t_) \ +_(aten, t_copy) \ +_(aten, take) \ +_(aten, take_along_dim) \ +_(aten, tan) \ +_(aten, tan_) \ +_(aten, tanh) \ +_(aten, tanh_) \ +_(aten, tanh_backward) \ +_(aten, tensor_split) \ +_(aten, tensordot) \ +_(aten, thnn_conv2d) \ +_(aten, threshold) \ +_(aten, threshold_) \ +_(aten, threshold_backward) \ +_(aten, tile) \ +_(aten, to) \ +_(aten, to_dense) \ +_(aten, to_dense_backward) \ +_(aten, to_mkldnn) \ +_(aten, to_mkldnn_backward) \ +_(aten, to_padded_tensor) \ +_(aten, to_sparse) \ +_(aten, to_sparse_bsc) \ +_(aten, to_sparse_bsr) \ +_(aten, to_sparse_csc) \ +_(aten, to_sparse_csr) \ +_(aten, topk) \ +_(aten, trace) \ +_(aten, trace_backward) \ +_(aten, transpose) \ +_(aten, transpose_) \ +_(aten, transpose_copy) \ +_(aten, trapezoid) \ +_(aten, trapz) \ +_(aten, triangular_solve) \ +_(aten, tril) \ +_(aten, tril_) \ +_(aten, tril_indices) \ +_(aten, triplet_margin_loss) \ +_(aten, triu) \ +_(aten, triu_) \ +_(aten, triu_indices) \ +_(aten, true_divide) \ +_(aten, true_divide_) \ +_(aten, trunc) \ +_(aten, trunc_) \ +_(aten, type_as) \ +_(aten, unbind) \ +_(aten, unbind_copy) \ +_(aten, unflatten) \ +_(aten, unflatten_dense_tensors) \ +_(aten, unfold) \ +_(aten, unfold_backward) \ +_(aten, unfold_copy) \ +_(aten, uniform) \ +_(aten, uniform_) \ +_(aten, unique_consecutive) \ 
+_(aten, unique_dim) \ +_(aten, unique_dim_consecutive) \ +_(aten, unsafe_chunk) \ +_(aten, unsafe_split) \ +_(aten, unsafe_split_with_sizes) \ +_(aten, unsqueeze) \ +_(aten, unsqueeze_) \ +_(aten, unsqueeze_copy) \ +_(aten, upsample_bicubic2d) \ +_(aten, upsample_bicubic2d_backward) \ +_(aten, upsample_bilinear2d) \ +_(aten, upsample_bilinear2d_backward) \ +_(aten, upsample_linear1d) \ +_(aten, upsample_linear1d_backward) \ +_(aten, upsample_nearest1d) \ +_(aten, upsample_nearest1d_backward) \ +_(aten, upsample_nearest2d) \ +_(aten, upsample_nearest2d_backward) \ +_(aten, upsample_nearest3d) \ +_(aten, upsample_nearest3d_backward) \ +_(aten, upsample_trilinear3d) \ +_(aten, upsample_trilinear3d_backward) \ +_(aten, value_selecting_reduction_backward) \ +_(aten, values) \ +_(aten, values_copy) \ +_(aten, vander) \ +_(aten, var) \ +_(aten, var_mean) \ +_(aten, vdot) \ +_(aten, view) \ +_(aten, view_as) \ +_(aten, view_as_complex) \ +_(aten, view_as_complex_copy) \ +_(aten, view_as_real) \ +_(aten, view_as_real_copy) \ +_(aten, view_copy) \ +_(aten, vsplit) \ +_(aten, vstack) \ +_(aten, where) \ +_(aten, xlogy) \ +_(aten, xlogy_) \ +_(aten, zero) \ +_(aten, zero_) \ +_(aten, zeros) \ +_(aten, zeros_like) + +#define FORALL_ATTR_BASE_SYMBOLS(_) \ +_(attr, A) \ +_(attr, B) \ +_(attr, C) \ +_(attr, H) \ +_(attr, HxW) \ +_(attr, K) \ +_(attr, L) \ +_(attr, LD) \ +_(attr, LU) \ +_(attr, LU_data) \ +_(attr, LU_pivots) \ +_(attr, M) \ +_(attr, N) \ +_(attr, P) \ +_(attr, Q) \ +_(attr, R) \ +_(attr, S) \ +_(attr, U) \ +_(attr, UPLO) \ +_(attr, V) \ +_(attr, Vh) \ +_(attr, W) \ +_(attr, X) \ +_(attr, a) \ +_(attr, abs) \ +_(attr, accumulate) \ +_(attr, accumulate_matches) \ +_(attr, activation) \ +_(attr, addends) \ +_(attr, adjoint) \ +_(attr, alg_id) \ +_(attr, algorithm) \ +_(attr, alibi_slopes) \ +_(attr, align_corners) \ +_(attr, align_to_window) \ +_(attr, allow_tf32) \ +_(attr, alpha) \ +_(attr, amsgrad) \ +_(attr, anchor) \ +_(attr, angle) \ +_(attr, any) \ +_(attr, 
api_name) \ +_(attr, append) \ +_(attr, approximate) \ +_(attr, arg1) \ +_(attr, arg2) \ +_(attr, arg3) \ +_(attr, arg_out) \ +_(attr, assert_msg) \ +_(attr, assume_unique) \ +_(attr, atol) \ +_(attr, attn_bias) \ +_(attr, attn_mask) \ +_(attr, average_attn_weights) \ +_(attr, averaging_const) \ +_(attr, aweights) \ +_(attr, axis) \ +_(attr, axis0) \ +_(attr, axis1) \ +_(attr, b) \ +_(attr, b_hh) \ +_(attr, b_ih) \ +_(attr, bag_size) \ +_(attr, base) \ +_(attr, batch1) \ +_(attr, batch2) \ +_(attr, batch_dim) \ +_(attr, batch_first) \ +_(attr, batch_size) \ +_(attr, batch_sizes) \ +_(attr, benchmark) \ +_(attr, beta) \ +_(attr, beta1) \ +_(attr, beta2) \ +_(attr, bias) \ +_(attr, bias_defined) \ +_(attr, bias_g) \ +_(attr, bias_requires_grad) \ +_(attr, bias_sizes) \ +_(attr, bidirectional) \ +_(attr, bin_edges) \ +_(attr, bins) \ +_(attr, bit_width) \ +_(attr, blank) \ +_(attr, block_size) \ +_(attr, blocksize) \ +_(attr, boundaries) \ +_(attr, buffer) \ +_(attr, ccol_indices) \ +_(attr, cdim) \ +_(attr, cdist) \ +_(attr, ceil_mode) \ +_(attr, cell_state_fwd) \ +_(attr, center) \ +_(attr, ch_axis) \ +_(attr, check_errors) \ +_(attr, check_pinning) \ +_(attr, chunks) \ +_(attr, coalesced) \ +_(attr, coefficients) \ +_(attr, col) \ +_(attr, col_indices) \ +_(attr, col_offsets) \ +_(attr, col_offsets_hh) \ +_(attr, col_offsets_ih) \ +_(attr, compressed_A) \ +_(attr, compressed_idx) \ +_(attr, compressed_indices) \ +_(attr, compressed_indices_dtype) \ +_(attr, compute_log_sumexp) \ +_(attr, compute_mode) \ +_(attr, compute_uv) \ +_(attr, compute_v) \ +_(attr, condition) \ +_(attr, copy) \ +_(attr, correction) \ +_(attr, count) \ +_(attr, count_include_pad) \ +_(attr, counts) \ +_(attr, cpu_dtype) \ +_(attr, cpu_enabled) \ +_(attr, cpu_nested_shape_example) \ +_(attr, create_graph) \ +_(attr, crow_indices) \ +_(attr, cu_seqlens_k) \ +_(attr, cu_seqlens_q) \ +_(attr, cuda_dtype) \ +_(attr, cuda_enabled) \ +_(attr, cudnn_enable) \ +_(attr, cudnn_enabled) \ +_(attr, 
cum_seq_k) \ +_(attr, cum_seq_q) \ +_(attr, custom_mask_type) \ +_(attr, cx) \ +_(attr, cx_) \ +_(attr, cx_tmp) \ +_(attr, cy) \ +_(attr, cy_) \ +_(attr, d) \ +_(attr, dampening) \ +_(attr, data) \ +_(attr, decimals) \ +_(attr, delta) \ +_(attr, dense) \ +_(attr, dense_B) \ +_(attr, dense_dim) \ +_(attr, density) \ +_(attr, dep_token) \ +_(attr, descending) \ +_(attr, destination) \ +_(attr, deterministic) \ +_(attr, device) \ +_(attr, device_index) \ +_(attr, dgrad_glu) \ +_(attr, diagonal) \ +_(attr, diagonals) \ +_(attr, dilation) \ +_(attr, dim) \ +_(attr, dim0) \ +_(attr, dim1) \ +_(attr, dim2) \ +_(attr, dimension) \ +_(attr, dims) \ +_(attr, dims_other) \ +_(attr, dims_self) \ +_(attr, divisor_override) \ +_(attr, downscale_factor) \ +_(attr, driver) \ +_(attr, dropout) \ +_(attr, dropout_mask) \ +_(attr, dropout_p) \ +_(attr, dropout_seed) \ +_(attr, dropout_state) \ +_(attr, dst) \ +_(attr, dtype) \ +_(attr, dual) \ +_(attr, dummy) \ +_(attr, dx) \ +_(attr, edge_order) \ +_(attr, eigenvalues) \ +_(attr, eigenvectors) \ +_(attr, eigvals) \ +_(attr, eigvecs) \ +_(attr, element) \ +_(attr, elements) \ +_(attr, ellipsis_idx) \ +_(attr, embed_dim) \ +_(attr, enable_gqa) \ +_(attr, end) \ +_(attr, end_dim) \ +_(attr, eps) \ +_(attr, epsilon) \ +_(attr, equal_nan) \ +_(attr, equation) \ +_(attr, exp_avg_sqs) \ +_(attr, exp_avgs) \ +_(attr, expand1) \ +_(attr, expand2) \ +_(attr, expand3) \ +_(attr, exponent) \ +_(attr, exponential_average_factor) \ +_(attr, fake_quant_enabled) \ +_(attr, fake_quant_on) \ +_(attr, ffn_bias_1) \ +_(attr, ffn_bias_2) \ +_(attr, ffn_weight_1) \ +_(attr, ffn_weight_2) \ +_(attr, filename) \ +_(attr, fill) \ +_(attr, fill_value) \ +_(attr, flat) \ +_(attr, forward) \ +_(attr, found_inf) \ +_(attr, from) \ +_(attr, from_) \ +_(attr, full) \ +_(attr, full_matrices) \ +_(attr, fuse_transform_0213) \ +_(attr, fweights) \ +_(attr, g) \ +_(attr, gO) \ +_(attr, generator) \ +_(attr, ggI) \ +_(attr, ggW) \ +_(attr, ggb) \ +_(attr, glu) \ 
+_(attr, grad) \ +_(attr, grad_bias) \ +_(attr, grad_cy) \ +_(attr, grad_factor) \ +_(attr, grad_glu) \ +_(attr, grad_hy) \ +_(attr, grad_in) \ +_(attr, grad_input) \ +_(attr, grad_input_mask) \ +_(attr, grad_out) \ +_(attr, grad_out_) \ +_(attr, grad_output) \ +_(attr, grad_scale) \ +_(attr, grad_w) \ +_(attr, grad_weight) \ +_(attr, grad_x) \ +_(attr, grad_y) \ +_(attr, gradient) \ +_(attr, grads) \ +_(attr, grid) \ +_(attr, group) \ +_(attr, groups) \ +_(attr, growth_interval) \ +_(attr, growth_tracker) \ +_(attr, half_to_float) \ +_(attr, has_bias) \ +_(attr, has_biases) \ +_(attr, hermitian) \ +_(attr, hidden_bias) \ +_(attr, hidden_gates) \ +_(attr, hidden_size) \ +_(attr, high) \ +_(attr, hist) \ +_(attr, hop_length) \ +_(attr, hx) \ +_(attr, hx_) \ +_(attr, hy_) \ +_(attr, i1) \ +_(attr, i2) \ +_(attr, i3) \ +_(attr, ignore_index) \ +_(attr, imag) \ +_(attr, impl_index) \ +_(attr, implicit) \ +_(attr, in_features) \ +_(attr, include_last_offset) \ +_(attr, include_self) \ +_(attr, increasing) \ +_(attr, ind) \ +_(attr, index) \ +_(attr, index_dtype) \ +_(attr, indexing) \ +_(attr, indices) \ +_(attr, info) \ +_(attr, initial) \ +_(attr, innerKTiles) \ +_(attr, inp) \ +_(attr, input) \ +_(attr, input1) \ +_(attr, input2) \ +_(attr, input3) \ +_(attr, input_bias) \ +_(attr, input_dtype) \ +_(attr, input_g) \ +_(attr, input_gates) \ +_(attr, input_lengths) \ +_(attr, input_scale) \ +_(attr, input_size) \ +_(attr, input_sizes) \ +_(attr, input_zero_point) \ +_(attr, inputs) \ +_(attr, interpolation) \ +_(attr, interpolation_mode) \ +_(attr, inv_scale) \ +_(attr, inverse) \ +_(attr, invert) \ +_(attr, invstd) \ +_(attr, is_causal) \ +_(attr, is_coalesced) \ +_(attr, is_crow) \ +_(attr, is_first_step) \ +_(attr, is_matrix) \ +_(attr, is_result) \ +_(attr, is_target) \ +_(attr, k) \ +_(attr, keepdim) \ +_(attr, kernel_size) \ +_(attr, key) \ +_(attr, label_smoothing) \ +_(attr, lambd) \ +_(attr, largest) \ +_(attr, last_dim_size) \ +_(attr, layersOutputs) \ 
+_(attr, layout) \ +_(attr, left) \ +_(attr, length) \ +_(attr, lengths) \ +_(attr, level) \ +_(attr, like) \ +_(attr, list) \ +_(attr, log_alpha) \ +_(attr, log_input) \ +_(attr, log_probs) \ +_(attr, log_target) \ +_(attr, logabsdet) \ +_(attr, logsumexp) \ +_(attr, low) \ +_(attr, lower) \ +_(attr, lr) \ +_(attr, lr_decay) \ +_(attr, ltm) \ +_(attr, m) \ +_(attr, mantissa) \ +_(attr, margin) \ +_(attr, mask) \ +_(attr, mask_check) \ +_(attr, mask_type) \ +_(attr, masked_grad) \ +_(attr, mat) \ +_(attr, mat1) \ +_(attr, mat1_meta) \ +_(attr, mat2) \ +_(attr, matrices) \ +_(attr, max) \ +_(attr, max_exp_avg_sqs) \ +_(attr, max_k) \ +_(attr, max_lengths) \ +_(attr, max_norm) \ +_(attr, max_q) \ +_(attr, max_seqlen) \ +_(attr, max_seqlen_k) \ +_(attr, max_seqlen_q) \ +_(attr, max_size) \ +_(attr, max_val) \ +_(attr, max_values) \ +_(attr, maximize) \ +_(attr, maximum_indices) \ +_(attr, maxnorm) \ +_(attr, mean) \ +_(attr, median) \ +_(attr, memory_format) \ +_(attr, meta) \ +_(attr, min) \ +_(attr, min_indices) \ +_(attr, min_seqlen) \ +_(attr, min_val) \ +_(attr, minlength) \ +_(attr, mode) \ +_(attr, momentum) \ +_(attr, momentum_buffer_list) \ +_(attr, n) \ +_(attr, n_bins) \ +_(attr, n_fft) \ +_(attr, names) \ +_(attr, nan) \ +_(attr, need_weights) \ +_(attr, neg_log_likelihood) \ +_(attr, negative) \ +_(attr, negative_slope) \ +_(attr, neginf) \ +_(attr, nested_size) \ +_(attr, nested_strides) \ +_(attr, nesterov) \ +_(attr, new_data) \ +_(attr, nnz) \ +_(attr, noise) \ +_(attr, non_blocking) \ +_(attr, norm) \ +_(attr, norm_bias_1) \ +_(attr, norm_bias_2) \ +_(attr, norm_first) \ +_(attr, norm_type) \ +_(attr, norm_weight_1) \ +_(attr, norm_weight_2) \ +_(attr, normalization) \ +_(attr, normalized) \ +_(attr, normalized_shape) \ +_(attr, normalized_shape_ndim) \ +_(attr, nt_example) \ +_(attr, num_chunks) \ +_(attr, num_classes) \ +_(attr, num_generated) \ +_(attr, num_groups) \ +_(attr, num_head) \ +_(attr, num_heads) \ +_(attr, num_layers) \ +_(attr, 
num_parallel) \ +_(attr, num_samples) \ +_(attr, num_splits_key) \ +_(attr, num_weights) \ +_(attr, numel) \ +_(attr, observer_on) \ +_(attr, offs) \ +_(attr, offset) \ +_(attr, offset2bag) \ +_(attr, offsets) \ +_(attr, onesided) \ +_(attr, ord) \ +_(attr, order) \ +_(attr, other) \ +_(attr, out) \ +_(attr, out0) \ +_(attr, out1) \ +_(attr, out2) \ +_(attr, out3) \ +_(attr, out4) \ +_(attr, out5) \ +_(attr, out6) \ +_(attr, out_channel) \ +_(attr, out_dim) \ +_(attr, out_dtype) \ +_(attr, out_features) \ +_(attr, out_int32) \ +_(attr, outdim) \ +_(attr, output) \ +_(attr, output_mask) \ +_(attr, output_padding) \ +_(attr, output_scale) \ +_(attr, output_size) \ +_(attr, output_zero_point) \ +_(attr, p) \ +_(attr, packed) \ +_(attr, packed_hh) \ +_(attr, packed_ih) \ +_(attr, packed_weight) \ +_(attr, packed_weights) \ +_(attr, pad) \ +_(attr, pad_mode) \ +_(attr, padded) \ +_(attr, padding) \ +_(attr, padding_idx) \ +_(attr, padding_mode) \ +_(attr, padding_side) \ +_(attr, padding_value) \ +_(attr, params) \ +_(attr, path) \ +_(attr, pdist) \ +_(attr, per_row_fake_quant) \ +_(attr, per_sample_weights) \ +_(attr, periodic) \ +_(attr, philox_offset) \ +_(attr, philox_seed) \ +_(attr, physical_layout) \ +_(attr, pin_memory) \ +_(attr, pivot) \ +_(attr, pivots) \ +_(attr, plain_idx) \ +_(attr, plain_indices) \ +_(attr, pos_weight) \ +_(attr, posinf) \ +_(attr, positive) \ +_(attr, pow) \ +_(attr, prepend) \ +_(attr, primal) \ +_(attr, prob) \ +_(attr, proj_bias) \ +_(attr, proj_size) \ +_(attr, proj_weight) \ +_(attr, q) \ +_(attr, qGroupSize) \ +_(attr, qScale) \ +_(attr, qScaleAndZeros) \ +_(attr, qZeros) \ +_(attr, qkv) \ +_(attr, qkv_bias) \ +_(attr, qkv_weight) \ +_(attr, qtensor) \ +_(attr, quant_max) \ +_(attr, quant_min) \ +_(attr, quasi) \ +_(attr, query) \ +_(attr, r) \ +_(attr, ragged_idx) \ +_(attr, random_samples) \ +_(attr, range) \ +_(attr, rank) \ +_(attr, ratio) \ +_(attr, rcond) \ +_(attr, real) \ +_(attr, reduce) \ +_(attr, reduce_range) \ +_(attr, 
reduction) \ +_(attr, repeats) \ +_(attr, replacement) \ +_(attr, requires_grad) \ +_(attr, reserve) \ +_(attr, reserveSpace) \ +_(attr, reservedSpace) \ +_(attr, residuals) \ +_(attr, result) \ +_(attr, retain_graph) \ +_(attr, return_complex) \ +_(attr, return_counts) \ +_(attr, return_debug_mask) \ +_(attr, return_inverse) \ +_(attr, reverse) \ +_(attr, right) \ +_(attr, rng_state) \ +_(attr, rounding_mode) \ +_(attr, row) \ +_(attr, row_indices) \ +_(attr, rstd) \ +_(attr, rtol) \ +_(attr, running_max) \ +_(attr, running_mean) \ +_(attr, running_min) \ +_(attr, running_var) \ +_(attr, s) \ +_(attr, save_invstd) \ +_(attr, save_mean) \ +_(attr, save_var) \ +_(attr, save_var_transform) \ +_(attr, saved_g) \ +_(attr, saved_norms) \ +_(attr, saved_v) \ +_(attr, scalar) \ +_(attr, scalar1) \ +_(attr, scalar2) \ +_(attr, scalars) \ +_(attr, scale) \ +_(attr, scale_a) \ +_(attr, scale_b) \ +_(attr, scale_backoff_factor) \ +_(attr, scale_factors) \ +_(attr, scale_grad_by_freq) \ +_(attr, scale_growth_factor) \ +_(attr, scale_hh) \ +_(attr, scale_ih) \ +_(attr, scale_result) \ +_(attr, scales) \ +_(attr, scales_d) \ +_(attr, scales_h) \ +_(attr, scales_w) \ +_(attr, scales_zeros) \ +_(attr, sections) \ +_(attr, seed) \ +_(attr, self) \ +_(attr, self_is_result) \ +_(attr, self_num_batch_dims) \ +_(attr, self_or_result) \ +_(attr, self_sizes) \ +_(attr, seqlen_k) \ +_(attr, sequences) \ +_(attr, seqused_k) \ +_(attr, shape) \ +_(attr, shared) \ +_(attr, shared_storage_dqdkdv) \ +_(attr, shifts) \ +_(attr, side) \ +_(attr, sigma) \ +_(attr, sign) \ +_(attr, singular_values) \ +_(attr, size) \ +_(attr, sizes) \ +_(attr, skip_first) \ +_(attr, sobolstate) \ +_(attr, solution) \ +_(attr, some) \ +_(attr, sorted) \ +_(attr, sorted_sequence) \ +_(attr, sorter) \ +_(attr, source) \ +_(attr, spacing) \ +_(attr, sparse) \ +_(attr, sparse_dim) \ +_(attr, sparse_grad) \ +_(attr, split_k) \ +_(attr, split_k_mode) \ +_(attr, split_size) \ +_(attr, split_sizes) \ +_(attr, src) \ 
+_(attr, stable) \ +_(attr, start) \ +_(attr, start_dim) \ +_(attr, state_steps) \ +_(attr, state_sums) \ +_(attr, std) \ +_(attr, step) \ +_(attr, steps) \ +_(attr, storage_offset) \ +_(attr, stride) \ +_(attr, sum_S) \ +_(attr, sum_dy) \ +_(attr, sum_dy_xmu) \ +_(attr, sumdim) \ +_(attr, swap) \ +_(attr, symmetric_quant) \ +_(attr, t) \ +_(attr, tangent) \ +_(attr, target) \ +_(attr, target_lengths) \ +_(attr, targets) \ +_(attr, tau) \ +_(attr, tensor) \ +_(attr, tensor1) \ +_(attr, tensor2) \ +_(attr, tensor_indices_or_sections) \ +_(attr, tensors) \ +_(attr, tensors1) \ +_(attr, test_element) \ +_(attr, test_elements) \ +_(attr, the_template) \ +_(attr, theta) \ +_(attr, thread_masks) \ +_(attr, threshold) \ +_(attr, to) \ +_(attr, tol) \ +_(attr, total) \ +_(attr, total_L) \ +_(attr, total_length) \ +_(attr, total_weight) \ +_(attr, train) \ +_(attr, training) \ +_(attr, transpose) \ +_(attr, transpose_result) \ +_(attr, transposed) \ +_(attr, type1) \ +_(attr, type2) \ +_(attr, unbiased) \ +_(attr, unitriangular) \ +_(attr, unpack_data) \ +_(attr, unpack_pivots) \ +_(attr, unroll_dim) \ +_(attr, unsafe) \ +_(attr, unused) \ +_(attr, update) \ +_(attr, upper) \ +_(attr, upscale_factor) \ +_(attr, use_cutlass) \ +_(attr, use_fast_accum) \ +_(attr, use_gelu) \ +_(attr, use_input_stats) \ +_(attr, v) \ +_(attr, value) \ +_(attr, values) \ +_(attr, var) \ +_(attr, vec) \ +_(attr, vec1) \ +_(attr, vec2) \ +_(attr, w_hh) \ +_(attr, w_ih) \ +_(attr, weight) \ +_(attr, weight0) \ +_(attr, weight1) \ +_(attr, weight2) \ +_(attr, weight3) \ +_(attr, weight4) \ +_(attr, weight_arr) \ +_(attr, weight_buf) \ +_(attr, weight_decay) \ +_(attr, weight_g) \ +_(attr, weight_scale) \ +_(attr, weight_stride0) \ +_(attr, weight_zero_point) \ +_(attr, weights) \ +_(attr, win_length) \ +_(attr, window) \ +_(attr, window_length) \ +_(attr, window_size) \ +_(attr, window_size_left) \ +_(attr, window_size_right) \ +_(attr, with_replacement) \ +_(attr, workspace) \ +_(attr, wrap) \ 
+_(attr, x) \ +_(attr, x1) \ +_(attr, x2) \ +_(attr, y) \ +_(attr, z) \ +_(attr, z_state) \ +_(attr, zero_infinity) \ +_(attr, zero_point) \ +_(attr, zero_point_hh) \ +_(attr, zero_point_ih) \ +_(attr, zero_points) diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/blob.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/blob.h new file mode 100644 index 0000000000000000000000000000000000000000..251da65e0896fbea56af545f5d90c731911ab9b8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/blob.h @@ -0,0 +1,204 @@ +#pragma once + +#include + +#include +#include +#include + +namespace caffe2 { + +class Tensor; + +/** + * @brief Blob is a general container that hosts a typed pointer. + * + * A Blob hosts a pointer as well as its type, and takes charge of deleting it + * properly when the blob is deallocated or re-allocated with a new type. A blob + * could contain anything, although the most common case is to contain a Tensor. + */ +class TORCH_API Blob final : public c10::intrusive_ptr_target { + public: + /** + * Initializes an empty Blob. + */ + Blob() noexcept = default; + ~Blob() override { + Reset(); + } + + Blob(Blob&& other) noexcept : Blob() { + swap(other); + } + + Blob& operator=(Blob&& other) noexcept { + Blob(std::move(other)).swap(*this); + return *this; + } + + /** + * Checks if the content stored in the blob is of type T. + */ + template + bool IsType() const noexcept { + return meta_.Match(); + } + + /** + * Returns the meta info of the blob. + */ + const TypeMeta meta() const noexcept { + return meta_; + } + + /** + * Returns a printable typename of the blob. + */ + std::string_view TypeName() const noexcept { + return meta_.name(); + } + + /** + * @brief Gets the const reference of the stored object. The code checks if + * the stored object is of the desired type. + */ + // TODO(jerryzh): add a Get(c10::DeviceType) function? 
+ template + const T& Get() const { + TORCH_INTERNAL_ASSERT( + IsType(), + "wrong type for the Blob instance. Blob contains ", + meta_.name(), + " while caller expects ", + TypeMeta::TypeName()); + // TODO: after we add Get(c10::DeviceType) + // and changed all the callsites, we can add + // a static assert here to enforce T != Tensor + return *static_cast(pointer_); + } + + const void* GetRaw() const noexcept { + return pointer_; + } + void* GetRaw() noexcept { + return pointer_; + } + + /** + * @brief Gets a mutable pointer to the stored object. + * + * If the current object is not of the right type, a new object is created + * and the old object is freed. Note that type T should have a default + * constructor. Otherwise, create the object yourself first, and use + * Reset(). + */ + template + T* GetMutable() { + static_assert( + std::is_default_constructible_v, + "GetMutable can't be called with non-default-constructible types. " + "Try using specialized methods"); + if (IsType()) { + return static_cast(pointer_); + } else { + // TODO Re-enable logging + // VLOG(1) << "Create new mutable object " << TypeMeta::TypeName(); + return Reset(new T()); + } + } + + template + T* GetMutableOrNull() { + if (IsType()) { + return static_cast(pointer_); + } else { + return nullptr; + } + } + + /** + * Sets the underlying object to the allocated one. The Blob then takes over + * the ownership of the passed in pointer. If there is already an object in + * the Blob, the old object is freed. + * + * This is used when the underlying class T does not have a default ctor, or + * complex initializations needs to be done outside the blob. + */ + template + T* Reset(T* allocated) { + free_(); + meta_ = TypeMeta::Make(); + pointer_ = static_cast(allocated); + has_ownership_ = true; + return allocated; + } + + /** + * Sets the underlying object to the allocated one, but does not take over + * the ownership of the passed in pointer. 
If there is already an object in + * the Blob, the old object is freed. + * + * Unlike Reset, this does not take over the ownership of the pointer and the + * caller is responsible for making sure that the lifetime of the allocated + * blob outlasts the lifetime of any access to this blob, until another Reset + * call is made or the blob is destructed. + */ + template + std::remove_const_t* ShareExternal( + std::remove_const_t* allocated) { + return static_cast(ShareExternal( + static_cast(allocated), + TypeMeta::Make>())); + } + + void* ShareExternal(void* allocated, const TypeMeta meta) { + free_(); + meta_ = meta; + pointer_ = allocated; + has_ownership_ = false; + return allocated; + } + + /** + * Resets the Blob to an empty one. + */ + void Reset() { + free_(); + pointer_ = nullptr; + meta_ = TypeMeta(); + has_ownership_ = false; + } + + /** + * @brief Swaps the underlying storage of two blobs. + */ + void swap(Blob& rhs) noexcept { + using std::swap; + swap(meta_, rhs.meta_); + swap(pointer_, rhs.pointer_); + swap(has_ownership_, rhs.has_ownership_); + } + + private: + void free_() { + if (has_ownership_ && pointer_ != nullptr) { + (*meta_.deleteFn())(pointer_); + } + } + + TypeMeta meta_; + void* pointer_{nullptr}; + bool has_ownership_{false}; + + C10_DISABLE_COPY_AND_ASSIGN(Blob); +}; + +inline void swap(Blob& lhs, Blob& rhs) noexcept { + lhs.swap(rhs); +} + +inline std::ostream& operator<<(std::ostream& out, const Blob& v) { + return out << "Blob[" << v.TypeName() << "]"; +} + +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..62b915885a80a99934962bb7582a7ba7b7dd3754 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h @@ -0,0 +1,213 @@ +#pragma once + +#include +#include +#include + +namespace 
c10 { + +struct IValue; +using Stack = std::vector; + +class OperatorHandle; +class KernelFunction; + +// This kernel implements the behavior of falling through to the next available +// registered dispatch key. The implementation of this function is FAST; it is +// no overhead to fallthrough to the next key. See cpp file for some more +// implementation notes; notably, this does NOT actually go through the +// boxing/unboxing codepath. +TORCH_API void fallthrough_kernel( + OperatorKernel*, + const OperatorHandle&, + DispatchKeySet, + Stack*); + +// Note [Ambiguity in AutogradOther kernel] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This error-reporting kernel is registered to the AutogradOther entry in the +// dispatch table when there is both a CompositeImplicitAutograd kernel and a +// backend kernel for ANY backend that maps to AutogradOther. To see why +// this is necessary in the AutogradOther case, it's helpful to first see +// why everything works out fine for a backend that has a reserved Autograd +// entry (see rule 2.2 in [Note] DispatchTable computation): +// +// CPU AutogradCPU +// reg? registers with... +// ------------------------------------------------- +// y Autograd registration takes precedence +// over CompositeImplicitAutograd. +// This is good, because the CPU specific backend +// implementation is more specialized and typically better; +// if we used the composite, we would bypass it. +// (NB: the Autograd key is guaranteed to exist because +// the autograd codegen requires it!) +// +// n CompositeImplicitAutograd takes precedence. +// This is also good, because the Autograd +// registration (if it exists) would try to redispatch +// to the (non-existent) CPU implementation; by +// using the composite, we ensure the operator +// actually works. 
+// +// As you can see, when we have a specific Autograd key (AutogradCPU), we can +// decide whether or not to use the CompositeImplicitAutograd kernel or the +// Autograd kernel based on whether or not the backend kernel exists. +// +// However, for AutogradOther (which is the catchall autograd kernel for +// everything that doesn't have a specific Autograd key), we can't do this +// trick because there isn't any unique backend to peek at to disambiguate; +// if there are some backends that have implementations they prefer Autograd, +// but unimplemented backends would prefer CompositeImplicitAutograd. Rather +// than arbitrarily pick one or the other, we just register a kernel that raises +// an error and let the user decide how to proceed. +TORCH_API void ambiguous_autogradother_kernel( + OperatorKernel*, + const OperatorHandle&, + DispatchKeySet, + Stack*); + +// Note [named_not_supported_kernel] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This kernel implements reporting an error message saying that named tensor is +// not supported. This kernel doesn't rely on the Stack, and so it is special +// cased in the dispatcher to be triggered before we attempt boxing (so we can +// give a good error message in cases when boxing is not supported). When +// boxing is universally supported this can be removed. +[[noreturn]] TORCH_API void named_not_supported_kernel( + OperatorKernel*, + const OperatorHandle&, + DispatchKeySet, + Stack*); + +/** + * BoxedKernel is similar to a std::function storing a boxed kernel. + */ +class TORCH_API BoxedKernel final { + public: + // This is how boxed kernels are actually stored + // + // Note [Plumbing Keys Through The Dispatcher] + // Benchmarks have shown that it is expensive for the dispatcher to read from + // thread-local storage (TLS) upon every dispatch call into order to compute + // which kernel to dispatch to. 
+ // + // To mitigate this, we've updated the calling convention inside the + // dispatcher to expect every kernel that it stores to have a first argument + // of type DispatchKeySet. + // + // What are the invariants of the DispatchKeySet when it gets passed to a + // kernel? + // - All keys to the left of the current dispatch key have been masked out. + // (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the + // highest bit to be DispatchKey::Tracer) + // - All other keys that dispatcher normally would have computed through TLS + + // global state + op arguments + // are still in the set. + // + // Kernels can then opt into using this keyset to save the dispatcher from + // doing repeated work during redispatches: recalculating the highest-priority + // dispatch key, which involves reading from TLS. Instead, the kernels that + // opt in will calculate an updated DispatchKeySet directly from the old one, + // and pass the updated set directly into the dispatcher upon redispatching. + // + // This is an opt-in mechanism: Kernels can automatically opt in by setting + // the first argument in their signature to be of type DispatchKeySet. See the + // kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for + // examples. + // + // The mechanism for optionally passing that DispatchKeySet into the kernel + // lives in make_boxed_from_unboxed_functor.h. See Note [Plumbing Keys Through + // The Dispatcher 2] for details. + using InternalBoxedKernelFunction = + void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); + // This is the public API for how boxed kernels are defined + using BoxedKernelFunction = void(const OperatorHandle&, Stack*); + using BoxedKernelFunction_withDispatchKeys = + void(const OperatorHandle&, DispatchKeySet, Stack*); + + BoxedKernel(); + + // Fast path for dispatch to allow not touching the boxed kernel in + // the common case where unboxed is available. 
+ bool isValid() const; + bool isFallthrough() const; + + /** + * Call the function with boxed arguments. + */ + void callBoxed( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Stack* stack) const; + + /** + * Create a KernelFunction from a boxed function. + * + * Example: + * + * > void boxed_func(OperatorKernel*, Stack* stack) {...} + * > BoxedFunction func = BoxedKernel::makeFromFunction<&boxed_func>(); + */ + template + static BoxedKernel makeFromFunction(); + + /** + * TODO: This will only be useful if we write a backend fallback that plumbs + * dispatch keys (currently there are none) See Note [Plumbing Keys Through + * The Dispatcher] for details. + */ + template + static BoxedKernel makeFromFunction(); + + /** + * Create a KernelFunction from a boxed functor. + * + * Example: + * + * > class MyFunctor final : public c10::OperatorKernel { + * > public: + * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...} + * > }; + * > BoxedKernel func = + * BoxedKernel::makeFromFunctor(std::make_unique()); + */ + template + static BoxedKernel makeFromFunctor( + std::unique_ptr kernelFunctor); + + static BoxedKernel makeFallthrough(); + static BoxedKernel makeAmbiguousAutogradOther(); + static BoxedKernel makeNamedNotSupported(); + + private: + friend class KernelFunction; + + template + static void make_boxed_function( + OperatorKernel*, + const OperatorHandle& opHandle, + DispatchKeySet, + Stack* stack); + + template + static void make_boxed_function( + OperatorKernel*, + const OperatorHandle& opHandle, + DispatchKeySet, + Stack* stack); + + explicit BoxedKernel( + std::unique_ptr functor, + InternalBoxedKernelFunction* boxed_kernel_func); + + OperatorKernel* getFunctor() const; + InternalBoxedKernelFunction* getFnPtr() const; + + c10::intrusive_ptr functor_; + InternalBoxedKernelFunction* boxed_kernel_func_; +}; + +} // namespace c10 + +#include diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..1960607c6bc8ea1807d6b16a3eabb39ccc650d11 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h @@ -0,0 +1,106 @@ +#pragma once + +namespace c10 { + +inline BoxedKernel::BoxedKernel() : functor_(), boxed_kernel_func_(nullptr) {} + +inline BoxedKernel::BoxedKernel( + std::unique_ptr functor, + InternalBoxedKernelFunction* boxed_kernel_func) + : functor_(std::move(functor)), boxed_kernel_func_(boxed_kernel_func) {} + +template +inline void BoxedKernel::make_boxed_function( + OperatorKernel*, + const OperatorHandle& opHandle, + DispatchKeySet, + Stack* stack) { + // Note that we're dropping the DispatchKeySet argument. + // See Note [Plumbing Keys Through The Dispatcher 2] for details. + func(opHandle, stack); +} + +template +inline void BoxedKernel::make_boxed_function( + OperatorKernel*, + const OperatorHandle& opHandle, + DispatchKeySet ks, + Stack* stack) { + // See Note [Plumbing Keys Through The Dispatcher 2] for details. 
+ func(opHandle, ks, stack); +} + +inline bool BoxedKernel::isValid() const { + return boxed_kernel_func_ != nullptr; +} + +inline bool BoxedKernel::isFallthrough() const { + return boxed_kernel_func_ == &fallthrough_kernel; +} + +inline void BoxedKernel::callBoxed( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Stack* stack) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + boxed_kernel_func_ != nullptr, + "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel."); + (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack); +} + +template +inline BoxedKernel BoxedKernel::makeFromFunction() { + return BoxedKernel( + nullptr, // no functor_ object + &make_boxed_function); +} + +template +inline BoxedKernel BoxedKernel::makeFromFunction() { + return BoxedKernel( + nullptr, // no functor_ object + &make_boxed_function); +} + +inline BoxedKernel BoxedKernel::makeFallthrough() { + return BoxedKernel( + nullptr, // no functor_ object + &fallthrough_kernel); +} + +inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() { + return BoxedKernel( + nullptr, // no functor_ object + &ambiguous_autogradother_kernel); +} + +inline BoxedKernel BoxedKernel::makeNamedNotSupported() { + return BoxedKernel( + nullptr, // no functor_ object + &named_not_supported_kernel); +} + +template +inline BoxedKernel BoxedKernel::makeFromFunctor( + std::unique_ptr kernelFunctor) { + static_assert( + std::is_base_of_v, + "Tried to call BoxedKernel::makeFromFunctor, but the functor doesn't inherit from c10::OperatorKernel. 
Please have the functor inherit from it."); + return BoxedKernel( + std::move(kernelFunctor), + [](OperatorKernel* kernel, + const OperatorHandle& op, + DispatchKeySet ks, + Stack* stack) { + (*static_cast(kernel))(op, ks, stack); + }); +} + +inline OperatorKernel* BoxedKernel::getFunctor() const { + return functor_.get(); +} +inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const { + return boxed_kernel_func_; +} + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction.h new file mode 100644 index 0000000000000000000000000000000000000000..06bcc5d4f49b88ad60c7927108c0a906282c84f3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction.h @@ -0,0 +1,283 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack + // to the c10 namespace. 
+ +class OperatorHandle; +struct OperatorKernel; +class KernelFunction; + +template +using has_symint = std::disjunction< + std::is_same, + std::is_same, + std::is_same, + std::is_same, T>>; + +template +struct remove_symint { + using type = T; +}; + +template <> +struct remove_symint { + using type = int64_t; +}; + +template <> +struct remove_symint { + using type = OptionalIntArrayRef; +}; + +template <> +struct remove_symint { + using type = c10::IntArrayRef; +}; + +template <> +struct remove_symint> { + using type = std::optional; +}; + +template +struct maybe_keep_symint final {}; + +template +struct maybe_keep_symint { + using type = T; +}; + +template +struct maybe_keep_symint { + using type = typename remove_symint::type; +}; + +template +using fn_has_symint = typename guts::typelist::true_for_any_type< + has_symint, + typename guts::infer_function_traits::type::parameter_types>; + +template +struct fn_remove_symint; + +template +struct fn_remove_symint { + using type = Ret(typename remove_symint::type...); +}; + +/** + * KernelFunction is similar to std::function but stores a kernel function. + * You can create a KernelFunction from a boxed or unboxed + * function/functor/lambda and call it in a boxed or unboxed way. If the way it + * was created doesn't match the way it was called, it will do boxing or + * unboxing as necessary. + */ +class TORCH_API KernelFunction final { + public: + using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction; + using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction; + using BoxedKernelFunction_withDispatchKeys = + BoxedKernel::BoxedKernelFunction_withDispatchKeys; + + KernelFunction(); + + // Fast path for dispatch to allow not touching the boxed kernel in + // the common case where unboxed is available. + bool isValidUnboxed() const; + bool isValidSymUnboxed() const; + bool isValid() const; + bool isFallthrough() const; + + /** + * Call the function in a boxed way. 
+ * If the kernel function was created with an unboxed function, + * this will call an unboxing wrapper which then calls into that + * unboxed function. + * + * Example: + * + * > void boxed_func(OperatorKernel*, Stack* stack) {...} + * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func); + * > Tensor result = func.callBoxed(stack); + * + * Or, with an unboxed implementation: + * + * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( + * > [] (Tensor a, bool b) -> Tensor {...}); + * > Tensor result = func.callBoxed(stack); + */ + void callBoxed( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Stack* stack) const; + + /** + * Call the function in an unboxed way. + * If the kernel function was created with a boxed function, + * this will box all inputs and then call into that boxed function. + * + * Note that this doesn't work for all types yet. + * + * Example: + * + * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( + * > [] (Tensor a, bool b) -> Tensor {...}); + * > Tensor result = func.call(tensor1, true); + * + * Or, with a boxed implementation: + * + * > void boxed_func(OperatorKernel*, Stack* stack) {...} + * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func); + * > Tensor result = func.call(tensor1, true); + */ + template + Return call( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Args... args) const; + + /** + * Create a KernelFunction from a BoxedKernel. + */ + static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn); + + /** + * Create a KernelFunction from a boxed function. 
+ * + * Example: + * + * > void boxed_func(OperatorKernel*, Stack* stack) {...} + * > KernelFunction func = + * KernelFunction::makeFromBoxedFunction<&boxed_func>(); + */ + template + static KernelFunction makeFromBoxedFunction(); + + /** + * TODO: This will only be useful if we write a backend fallback that plumbs + * dispatch keys (currently there are none) See Note [Plumbing Keys Through + * The Dispatcher] for details. + */ + template + static KernelFunction makeFromBoxedFunction(); + + /** + * Create a KernelFunction from an unboxed functor. + * + * Example: + * + * > class MyFunctor final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > KernelFunction func = + * KernelFunction::makeFromUnboxedFunctor(std::make_unique()); + */ + template + static KernelFunction makeFromUnboxedFunctor( + std::unique_ptr kernelFunctor); + + /** + * Create a KernelFunction from a boxed functor. + * + * Example: + * + * > class MyFunctor final : public c10::OperatorKernel { + * > public: + * > void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...} + * > }; + * > KernelFunction func = + * KernelFunction::makeFromBoxedFunctor(std::make_unique()); + */ + template + static KernelFunction makeFromBoxedFunctor( + std::unique_ptr kernelFunctor); + + /** + * Create a KernelFunction from an unboxed function. + * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction + * because knowing the function pointer as a template argument (i.e. at + * compile time) allows the compiler to inline the function into its + * unboxing wrapper and yields better performance when calling the function. + * + * Example: + * + * > Tensor unboxed_func(Tensor a, Tensor b) {...} + * > KernelFunction func = + * KernelFunction::makeFromUnboxedFunction(); + */ + template + static KernelFunction makeFromUnboxedFunction(FuncPtr); + + /** + * Create a KernelFunction from an unboxed function. 
+ * KernelFunction::makeFromUnboxedFunction is usually a better choice than + * this if you know the function pointer at compile time, see doc comment + * there for an explanation. + * + * Example: + * + * > Tensor unboxed_func(Tensor a, Tensor b) {...} + * > KernelFunction func = + * KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func); + */ + template + static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func); + + static KernelFunction makeFallthrough(); + static KernelFunction makeAmbiguousAutogradOther(); + static KernelFunction makeNamedNotSupported(); + + /** + * Create a KernelFunction from an unboxed lambda. + * + * Example: + * + * > KernelFunction func = KernelFunction::makeFromUnboxedLambda( + * > [] (Tensor a, bool b) -> Tensor {...}); + */ + template + static std::enable_if_t< + guts::is_stateless_lambda>::value, + KernelFunction> + makeFromUnboxedLambda(Lambda&& lambda); + template + static std::enable_if_t< + !guts::is_stateless_lambda>::value, + KernelFunction> + makeFromUnboxedLambda(Lambda&& lambda); + + std::string dumpState() const; + // For testing internal invariants only + bool _equalsBoxedAndUnboxed(const KernelFunction&) const; + + private: + explicit KernelFunction( + std::unique_ptr functor, + InternalBoxedKernelFunction* boxed_kernel_func, + void* unboxed_kernel_func, + void* sym_unboxed_kernel_func); + explicit KernelFunction( + BoxedKernel boxed_fn, + void* unboxed_kernel_func, + void* sym_unboxed_kernel_func); + + BoxedKernel boxed_kernel_func_; + void* unboxed_kernel_func_; + void* sym_unboxed_kernel_func_; +}; + +} // namespace c10 + +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..df49d6227ee93ce014ad7b90db0b6cf997061edb --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h @@ -0,0 +1,320 @@ +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +namespace detail { +template +std::enable_if_t< + !std::is_array_v && !std::is_array_v && + std::is_base_of_v, + std::unique_ptr> +make_unique_base(Args&&... args) { + return std::unique_ptr(new Child(std::forward(args)...)); +} +} // namespace detail + +inline KernelFunction::KernelFunction() + : boxed_kernel_func_(), + unboxed_kernel_func_(nullptr), + sym_unboxed_kernel_func_(nullptr) {} + +inline KernelFunction::KernelFunction( + std::unique_ptr functor, + InternalBoxedKernelFunction* boxed_kernel_func, + void* unboxed_kernel_func, + void* sym_unboxed_kernel_func = nullptr) + : boxed_kernel_func_(std::move(functor), boxed_kernel_func), + unboxed_kernel_func_(unboxed_kernel_func), + sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {} + +inline KernelFunction::KernelFunction( + BoxedKernel boxed_fn, + void* unboxed_kernel_func, + void* sym_unboxed_kernel_func = nullptr) + : boxed_kernel_func_(std::move(boxed_fn)), + unboxed_kernel_func_(unboxed_kernel_func), + sym_unboxed_kernel_func_(sym_unboxed_kernel_func) {} + +inline bool KernelFunction::isValidUnboxed() const { + return unboxed_kernel_func_ != nullptr; +} + +inline bool KernelFunction::isValidSymUnboxed() const { + return sym_unboxed_kernel_func_ != nullptr; +} + +inline bool KernelFunction::isValid() const { + return boxed_kernel_func_.isValid(); +} + +inline bool KernelFunction::isFallthrough() const { + return boxed_kernel_func_.isFallthrough(); +} + +inline void KernelFunction::callBoxed( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Stack* stack) const { + boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack); +} + +template +inline Return callUnboxedKernelFunction( + void* unboxed_kernel_func, + OperatorKernel* functor, + DispatchKeySet dispatchKeySet, + Args&&... 
args) { + using ActualSignature = Return(OperatorKernel*, DispatchKeySet, Args...); + ActualSignature* func = + reinterpret_cast(unboxed_kernel_func); + return (*func)(functor, dispatchKeySet, std::forward(args)...); +} + +// This template requires you to explicitly specify the argument you want to +// forward; it doesn't work if you try to deduce it +// NB: keep this in sync with cloneWithRealTypes in function_schema.cpp + +template +inline typename remove_symint::type unpackSymInt(T x) { + return x; +} + +template <> +inline typename remove_symint::type unpackSymInt(c10::SymInt x) { + return x.guard_int(__FILE__, __LINE__); +} + +template <> +inline typename remove_symint::type unpackSymInt( + c10::SymIntArrayRef x) { + return C10_AS_INTARRAYREF_SLOW(x); +} + +template <> +inline typename remove_symint>::type unpackSymInt( + std::optional x) { + return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__)) + : std::nullopt; +} + +template <> +inline typename remove_symint::type unpackSymInt( + at::OptionalSymIntArrayRef x) { + return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) + : std::nullopt; +} + +template +C10_ALWAYS_INLINE Return KernelFunction::call( + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Args... args) const { + // note: Args above is intentionally not Args&&. We don't want perfect + // forwarding, which would require Args to be deduced, but instead we + // want callers to explicitly specify the Args. 
+ + if constexpr (std::disjunction_v...>) { + if (sym_unboxed_kernel_func_ != nullptr) { + auto* functor = boxed_kernel_func_.getFunctor(); + return callUnboxedKernelFunction( + sym_unboxed_kernel_func_, + functor, + dispatchKeySet, + std::forward(args)...); + } + + if (unboxed_kernel_func_ != nullptr) { + auto* functor = boxed_kernel_func_.getFunctor(); + return callUnboxedKernelFunction< + Return, + typename remove_symint::type...>( + unboxed_kernel_func_, + functor, + dispatchKeySet, + unpackSymInt(args)...); + } + } else { + if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) { + auto* functor = boxed_kernel_func_.getFunctor(); + return callUnboxedKernelFunction( + unboxed_kernel_func_, + functor, + dispatchKeySet, + std::forward(args)...); + } + } + + return impl::BoxedKernelWrapper::call( + boxed_kernel_func_, + opHandle, + dispatchKeySet, + std::forward(args)...); +} + +inline KernelFunction KernelFunction::makeFromBoxedKernel( + BoxedKernel boxed_fn) { + return KernelFunction( + std::move(boxed_fn), nullptr); // no unboxed function pointer +} + +template +inline KernelFunction KernelFunction::makeFromBoxedFunction() { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFromFunction()); +} + +template +inline KernelFunction KernelFunction::makeFromBoxedFunction() { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFromFunction()); +} + +inline KernelFunction KernelFunction::makeFallthrough() { + return KernelFunction::makeFromBoxedKernel(BoxedKernel::makeFallthrough()); +} + +inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeAmbiguousAutogradOther()); +} + +inline KernelFunction KernelFunction::makeNamedNotSupported() { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeNamedNotSupported()); +} + +template +inline KernelFunction KernelFunction::makeFromUnboxedFunctor( + std::unique_ptr kernelFunctor) { +#ifndef NDEBUG + // This assertion 
is costly for build time so it's debug-gated. + static_assert( + guts::is_functor::value, + "Tried to call KernelFunction::makeFromUnboxedFunctor but the argument is not a functor."); +#endif + static_assert( + std::is_base_of_v, + "Tried to call KernelFunction::makeFromUnboxedFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); + + auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed::call; + void* void_unboxed_fn = reinterpret_cast(unboxed_fn); + bool is_symint = fn_has_symint::value; + return KernelFunction( + std::move(kernelFunctor), + &impl::make_boxed_from_unboxed_functor:: + call, + is_symint ? nullptr : void_unboxed_fn, + is_symint ? void_unboxed_fn : nullptr); +} + +template +inline KernelFunction KernelFunction::makeFromBoxedFunctor( + std::unique_ptr kernelFunctor) { + return KernelFunction::makeFromBoxedKernel( + BoxedKernel::makeFromFunctor(std::move(kernelFunctor))); +} + +template +inline KernelFunction KernelFunction::makeFromUnboxedFunction( + FuncPtr func_ptr) { + static_assert( + is_compile_time_function_pointer::value, + "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN."); + static_assert( + !std::is_same_v, + "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. 
Please use KernelFunction::makeFromBoxedFunction instead."); +#if defined(__GNUC__) && defined(__SANITIZE_ADDRESS__) && !defined(__CUDACC__) + TORCH_INTERNAL_ASSERT( + FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); +#else + static_assert( + FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); +#endif + +#if !defined(C10_MOBILE) + (void)func_ptr; // Suppress unused variable warning + return makeFromUnboxedFunctor< + AllowLegacyTypes, + typename impl::WrapFunctionIntoFunctor::type>( + detail::make_unique_base< + OperatorKernel, + typename impl::WrapFunctionIntoFunctor::type>()); +#else + // On mobile, we rather want to optimize for binary size than for performance, + // so let's not inline the kernel into the wrapper but use + // makeFromUnboxedRuntimeFunction instead. + return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr()); +#endif +} + +template +inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction( + FuncType* func) { + static_assert( + guts::is_function_type::value, + "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type."); + static_assert( + !std::is_same_v, + "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. 
Please use KernelFunction::makeFromBoxedFunction instead."); + TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr"); + + return makeFromUnboxedFunctor< + AllowLegacyTypes, + impl::WrapFunctionIntoRuntimeFunctor>>( + detail::make_unique_base< + OperatorKernel, + impl::WrapFunctionIntoRuntimeFunctor>>(func)); +} + +template +inline std::enable_if_t< + guts::is_stateless_lambda>::value, + KernelFunction> +KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { + static_assert( + guts::is_functor>::value, + "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); + +#if !defined(C10_MOBILE) + return makeFromUnboxedFunctor< + AllowLegacyTypes, + impl::WrapFunctionIntoRuntimeFunctor>>( + detail::make_unique_base< + OperatorKernel, + impl::WrapFunctionIntoRuntimeFunctor>>( + std::forward(lambda))); +#else + // On mobile, we rather want to optimize for binary size than for performance, + // so let's not inline the kernel into the wrapper but use + // makeFromUnboxedRuntimeFunction instead. 
+ using FuncType = + typename guts::infer_function_traits_t>::func_type; + return makeFromUnboxedRuntimeFunction(lambda); +#endif +} + +template +inline std::enable_if_t< + !guts::is_stateless_lambda>::value, + KernelFunction> +KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { + static_assert( + guts::is_functor>::value, + "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); + + return makeFromUnboxedFunctor< + AllowLegacyTypes, + impl::WrapFunctionIntoRuntimeFunctor>>( + detail::make_unique_base< + OperatorKernel, + impl::WrapFunctionIntoRuntimeFunctor>>( + std::forward(lambda))); +} + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..2d0620fa0d9f1257bf36b3a4c6962e97a07e64c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h @@ -0,0 +1,27 @@ +#pragma once +#include + +namespace c10 { + +/** + * Inherit from OperatorKernel to implement a c10 kernel. + * + * Example: + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * + * The kernel class is allowed to have members but these are equivalent + * to global variables. The kernel implementation is responsible for + * preventing race conditions on them. + * + * See below for how to register this kernel with PyTorch. 
+ */ +struct TORCH_API OperatorKernel : public c10::intrusive_ptr_target { + ~OperatorKernel() override = default; +}; + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h new file mode 100644 index 0000000000000000000000000000000000000000..8b81ae4e517df706c96f286e49f40357c6abf7d3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +namespace c10::impl { +namespace detail { +template +class WrapFunctionIntoFunctor_ {}; +template +class WrapFunctionIntoFunctor_< + FuncPtr, + ReturnType, + guts::typelist::typelist> + final : public c10::OperatorKernel { + public: + C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... args) { + return (*FuncPtr::func_ptr())(std::forward(args)...); + } +}; +} // namespace detail + +// WrapFunctionIntoFunctor: Wraps a compile time function pointer into a kernel +// functor. Since it is a compile time function pointer, many compilers can +// inline it into the wrapper and you don't get any performance overhead for +// wrapping. 
+template +struct WrapFunctionIntoFunctor final { + static_assert( + c10::is_compile_time_function_pointer::value, + "WrapFunctionIntoFunctor can only wrap functions created with TORCH_FN."); + using type = detail::WrapFunctionIntoFunctor_< + FuncPtr, + typename guts::function_traits::return_type, + typename guts::function_traits< + typename FuncPtr::FuncType>::parameter_types>; +}; + +} // namespace c10::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h new file mode 100644 index 0000000000000000000000000000000000000000..c5609eada7af3e54c5f36002895a782c340ff51d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace c10::impl { + +namespace detail { +template +class WrapFunctionIntoRuntimeFunctor_ {}; +template +class WrapFunctionIntoRuntimeFunctor_< + FuncType, + ReturnType, + guts::typelist::typelist> + final : public c10::OperatorKernel { + public: + template + explicit WrapFunctionIntoRuntimeFunctor_(FuncType_&& kernel_func) + : kernel_func_(std::forward(kernel_func)) {} + + decltype(auto) operator()(Parameters... args) { + return kernel_func_(std::forward(args)...); + } + + private: + FuncType kernel_func_; +}; +} // namespace detail + +// WrapFunctionIntoRuntimeFunctor: Wraps any runtime functor into a functor that +// inherits from c10::OperatorKernel, so it can be used as a c10 kernel. +// This can, for example, be used for lambdas, functors or even function +// pointers. In the case of function pointers, since it is a runtime function +// pointer, there is an overhead for calling it whenever the kernel is invoked. 
+template +using WrapFunctionIntoRuntimeFunctor = detail::WrapFunctionIntoRuntimeFunctor_< + FuncType, + typename guts::infer_function_traits_t::return_type, + typename guts::infer_function_traits_t::parameter_types>; + +} // namespace c10::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/boxing.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/boxing.h new file mode 100644 index 0000000000000000000000000000000000000000..68e25cccd44c87f30de9052832c0ea0ceb5aa40c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/boxing.h @@ -0,0 +1,410 @@ +#pragma once + +// This file contains boxing (not unboxing) logic, +// i.e. how to make a vector from a set of concrete arguments. + +#include +#include +#include + +#include + +#include +#include + +namespace c10::impl { + +// +// utils +// + +// is_mutable_tensor_ref +template +struct is_mutable_tensor_ref : std::false_type {}; +template <> +struct is_mutable_tensor_ref : std::true_type {}; + +// is_tuple_of_mutable_tensor_refs +// +template +struct is_tuple_of_mutable_tensor_refs : std::false_type {}; + +template +struct is_tuple_of_mutable_tensor_refs< + T, + std::enable_if_t::value, void>> + : guts::typelist:: + all> {}; + +// has_ivalue_to tests the presence/absence of instance method +// IValue::to() +// +template +struct has_ivalue_to : std::false_type {}; + +template +struct ivalue_to_helper { + using type = decltype(std::declval().template to()); +}; +template +using ivalue_to_helper_t = typename ivalue_to_helper::type; + +template +struct has_ivalue_to>> : std::true_type {}; + +// +// boxing predicates +// + +// A boxable arg type is one that IValue has a constructor for. 
+template +using can_box = std::disjunction< + std::is_constructible>, + // TensorOptions are not directly constructible into IValue, + // but torch::jit::push knows how to handle them + std::is_same>>; + +template +using can_box_all = std::conjunction...>; + +// an unboxable result is one that can be extracted from an IValue +template +using can_unbox = std::conjunction< + std::disjunction< + has_ivalue_to, + // void returns are ok + std::is_same>, + std::negation>>; + +// +// boxArgs - utility for pushing unboxed args onto IValue stack +// +template +torch::jit::Stack boxArgs(Args... args) { + // TODO Reuse stack vector instead of allocating? + torch::jit::Stack stack; + stack.reserve(sizeof...(Args)); + torch::jit::push(stack, std::forward(args)...); + return stack; +} + +template +inline constexpr size_t boxed_size_one() { + static_assert( + !std::is_same_v, c10::TensorOptions>, + "need to patch this path to support TensorOptions passed by reference"); + return 1; +} + +// torch::jit::push pushes 4 values for a TensorOptions; this needs to +// be kept in sync. +template <> +inline constexpr size_t boxed_size_one() { + return 4; +} + +// NOTE: this could probably be simplified with C++17 fold expressions. 
+template +struct BoxedSize : std::integral_constant {}; +template +struct BoxedSize + : std::integral_constant< + size_t, + boxed_size_one() + BoxedSize::value> {}; + +template +static inline constexpr size_t boxed_size() { + return BoxedSize::value; +} + +template +C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValue*& dest, T& arg) { + new (dest++) IValue(arg); +} + +C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack( + IValue*& dest, + c10::TensorOptions options) { + new (dest++) IValue(c10::typeMetaToScalarType(options.dtype())); + new (dest++) IValue(options.layout()); + new (dest++) IValue(options.device()); + new (dest++) IValue(options.pinned_memory()); +} + +inline void boxArgsToStack(IValue*&) {} + +template +C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack( + IValue*& dest, + T& arg, + Args&... args) { + boxToStack(dest, arg); + boxArgsToStack(dest, args...); +} + +// +// PopResult is a helper class whose specializations handle popping single and +// multiple return values, respectively. +// +template +struct PopResult final { + static Result call(Stack& stack) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == 1, + "Boxed kernel was expected to return one value on the stack, ", + "but instead pushed ", + stack.size(), + " values."); + return std::move(stack[0]).to(); + } +}; + +template +struct PopResult> final { + using Result = std::tuple; + + static Result call(Stack& stack) { + // for tuple return types, boxed kernel has pushed multiple values onto the + // stack + constexpr int RetCount = sizeof...(Types); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == RetCount, + "Boxed kernel was expected to return ", + RetCount, + " values on the stack, ", + "but instead pushed ", + stack.size(), + " values."); + return pop_to_tuple_impl(stack, std::make_index_sequence()); + } + + private: + // note: this has been moved into its own helper only to avoid a parse error + // on `indices` otherwise. 
I'm sure there's an incantation that slips it past + // the parser but eh + template + static Result pop_to_tuple_impl( + Stack& stack, + std::index_sequence) { + return std::make_tuple((std::move(stack[indices]).template to())...); + } +}; + +// +// BoxedKernelWrapper +// +// For a given function type FT, BoxedKernelWrapper implements +// a `call` method that +// - takes a boxed kernel and unboxed arguments as specified by FT, +// - calls `boxArgs` to box the arguments +// - calls the boxed kernel +// - unboxes and returns the result +// +// The partial specializations below handle various cases: in +// particular, not all types appearing in op signatures are supported, +// and ops returning references have nonstandard wrapper implementations. +// + +// 1. The base specialization of BoxedKernelWrapper should never be +// instantiated. A "no call method defined on BoxedKernelWrapper" compile error +// means that an op signature has failed to trigger any of the partial +// specializations that follow this one. +// +template +struct BoxedKernelWrapper { + // The reason we're not just doing straight up static_assert(false, ...) here: + // Basically, the way to make sure a static_assert only fires if a template + // is actually instantiated (rather than every time the file is parsed) is to + // use template parameters in the expression, e.g. FuncType here. However, + // since `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the + // same effect. + static_assert( + sizeof(FuncType) != sizeof(FuncType), + "Function signature contains one or more unsupported parameter and/or return types. " + "Look for a nearby error like " + "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" " + "- (your function type) is the unsupported signature."); +}; + +// +// 2. Supported signatures, other than those involving non-const Tensor refs - +// i.e., "functional" ops. 
+// + +template +struct BoxedKernelWrapper< + Result(Args...), + std::enable_if_t< + can_box_all::value && can_unbox::value && + !is_tuple_of_mutable_tensor_refs::value, + void>> { + static Result call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Args... args) { + torch::jit::Stack stack = boxArgs(std::forward(args)...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + + if constexpr (!std::is_same_v) { + // op has pushed one or more values onto the stack. + return PopResult::call(stack); + } else { + // op returns void, boxed kernel has pushed nothing onto stack. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.empty(), + "Boxed kernel was expected to return no values on the stack, ", + "but instead returned ", + stack.size(), + " values."); + } + } +}; + +// +// 3. in-place ops take a single non-const Tensor reference +// as their first argument, and return it. +// +// Note: all signatures matching this pattern are assumed to be for such ops. +// Because of this, the generated BoxedKernelWrapper specializations simply +// return the in-place argument. +// + +template +struct BoxedKernelWrapper< + at::Tensor&(at::Tensor&, OtherArgs...), + std::enable_if_t::value, void>> { + static at::Tensor& call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + at::Tensor& outArg, + OtherArgs... otherArgs) { + torch::jit::Stack stack = boxArgs( + outArg, std::forward(otherArgs)...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == 1, + "Boxed kernel was expected to return a single value on the stack, ", + "but instead returned ", + stack.size(), + " values."); + + return outArg; + } +}; + +// +// 3.5. In-process migration to make in-place ops take and return +// const references instead. 
+template +struct BoxedKernelWrapper< + const at::Tensor&(const at::Tensor&, OtherArgs...), + std::enable_if_t::value, void>> { + static const at::Tensor& call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + const at::Tensor& outArg, + OtherArgs... otherArgs) { + torch::jit::Stack stack = boxArgs(outArg, otherArgs...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == 1, + "Boxed kernel was expected to return a single value on the stack, ", + "but instead returned ", + stack.size(), + " values."); + + return outArg; + } +}; + +// +// 4. out of place ops that take a single non-const Tensor reference as their +// final argument, and also return it. +// +// Note: all signatures matching this pattern are assumed to be for such ops. +// This assumption permits the generated BoxedKernelWrapper specializations to +// simply return out arguments. +// +template +struct BoxedKernelWrapper< + at::Tensor&(FirstArg, RestArgs...), + std::enable_if_t< + can_box_all::value + // this skips over in-place kernels with a non-const Tensor + // arg at the front, so those can unambiguously trigger the + // preceding specialization. + && !is_mutable_tensor_ref::value, + void>> { + static at::Tensor& call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + FirstArg firstArg, + RestArgs... restArgs) { + torch::jit::Stack stack = boxArgs( + std::forward(firstArg), std::forward(restArgs)...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == 1, + "Boxed kernel was expected to return a single value on the stack, ", + "but instead returned ", + stack.size(), + " values."); + + // reusing restArgs after it has been forwarded here is ok because we know + // that the last element is of type `Tensor&`. 
+ return std::get( + std::tuple{restArgs...}); + } +}; + +// +// 5. out of place ops that take multiple non-const Tensor references as their +// final arguments, and return them in a std::tuple. +// +// Note: all signatures matching this pattern are assumed to be for such ops. +// This assumption permits the generated BoxedKernelWrapper specializations to +// simply return the out arguments. +// +template +struct BoxedKernelWrapper< + Result(Args...), + std::enable_if_t< + can_box_all::value && + is_tuple_of_mutable_tensor_refs::value, + void>> { + static Result call( + const BoxedKernel& boxed_kernel_func, + const OperatorHandle& opHandle, + DispatchKeySet dispatchKeySet, + Args... args) { + using ArgTuple = std::tuple; + constexpr int RetCount = std::tuple_size(); + + torch::jit::Stack stack = boxArgs(std::forward(args)...); + boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + stack.size() == RetCount, + "Boxed kernel was expected to return ", + RetCount, + " values on the stack, ", + "but instead returned ", + stack.size(), + " values."); + + // reusing args after it has been forwarded here is ok because we know + // that the last RetCount elements are of type `Tensor&`. 
+ auto result = guts::tuple_take( + ArgTuple{std::forward(args)...}); + static_assert( + std::is_same_v, + "The parameter list of an op returning a tuple of Tensor references " + "must end with an equal number of Tensor reference parameters."); + return result; + } +}; + +} // namespace c10::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..e67d1badc9a46504c3fdfe080ae2591d3960d650 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -0,0 +1,785 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace c10 { + +using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack + // to the c10 namespace. +class OperatorHandle; + +/* + * [Note: Argument forwarding in the dispatcher] + * + * The dispatcher uses a somewhat unusual way to forward arguments through + * several layers of wrapper functions. This can be confusing because an + * experienced C++ programmer would look at this and think "oh this is supposed + * to be forwarding a universal reference but the && is missing. This is a + * bug.". It is not a bug. The common way in C++ to forward arguments is to use + * universal references: + * + * > template void func(T&& arg) { func2(std::forward(arg)); } + * + * but that relies on inferring the correct reference type (i.e. value vs & vs + * &&) from the argument. In our case, we cannot rely on the argument as + * supplied by the caller, because that could infer a different reference type + * than was used in the kernel function. 
The correct reference type is dictated + * by the kernel signature and must be identical since we cast function pointers + * through void* pointers and mismatches would be UB. So we need a forwarding + * pattern that determines the reference type to use by looking at the + * explicitly supplied operator signature, not by looking at the argument we're + * calling it with. + * + * What does std::forward do, exactly? + * ------------------------------------ + * std::forward(t) is a way to cast t to the reference type supplied in T. + * Let's assume decay_t == U and T is either U or some reference of U. + * - std::forward(t) will return U&, no matter what kind of reference t is. + * - std::forward(t) will return U&&, no matter what kind of reference t + * is. + * - std::forward(t) will return U&& (not U!), no matter what kind of + * reference t is. + * + * For universal references, that means that in the following function + * > template void func(T&& arg) { func2(std::forward(arg)); } + * + * - when called with arg being a rvalue reference or non-reference value, T + * gets inferred to be a non-reference U, and std::forward(t) will return + * U&&, correctly moving the argument. + * - when called with arg behind a lvalue reference, T gets inferred to be U& + * because that's the only way to match the signature (in C++, a type that is + * (T&)&& will collapse to T&). That means std::forward(t) will return U& and + * the value will not be moved but passed on as a lvalue reference. + * + * How do we use that? + * ------------------------------------ + * But std::forward can also be used outside of the common "universal + * forwarding" pattern to change reference types. So instead of following the + * common C++ pattern, we notice what std::forward() actually does, and that + * is it takes a value and changes its reference to the type of reference passed + * in as T. 
If we don't infer T but explicitly specify it, we can use this to + * forward based on an explicitly specified reference type instead of the + * inferred argument type. + * + * This is why many of the dispatcher functions look like + * > template func(T t) { func2(std::forward(t)); } + * instead of the common + * > template func(T&& t) { func2(std::forward(t)); } + * + * and are expected to be called by explicitly specifying the template + * parameters in a way that matches the expected operator signature at each call + * site. + */ + +namespace impl { +// supported_primitive_arg_types defines which primitive types we allow in +// kernel functions as arguments or returns. +// Additionally, we support lists, dicts and optionals containing these types. +using supported_primitive_arg_types = guts::typelist::typelist< + int64_t, + double, + bool, + std::string_view, + at::Tensor, + at::Scalar, + c10::QScheme, + c10::ScalarType, + c10::Device, + c10::DeviceIndex, + c10::Layout, + c10::MemoryFormat, + at::Dimname>; + +// We have an unboxed functor in hand that takes C++ arguments, and +// we're building a boxed functor wrapper for it that takes IValues. +// So "outside" is boxed and "inside" is unboxed. +// +// So a valid input type is one that our boxed functor wrapper can +// unbox from an IValue into a C++ value. +// +// Whereas a valid output type is one that our wrapper can recieve +// as a C++ value from the unboxed functor, and box into an IValue. + +// +// assert_is_valid_input_type +// checks that T can be unboxed from an IValue into a C++ value. +// + +template +struct assert_is_valid_input_type { + assert_is_valid_input_type() { + if constexpr (guts::typelist::contains:: + value) { + /* everything is ok, this is a primitive type */ + } else { + /* otherwise this must be an instance of a valid custom class, since it + can only have been created via IValue(x), which ensures this. 
*/ + } + } +}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type {}; + +template +struct TypeCheckHelper; + +template +struct TypeCheckHelper {}; + +template +struct TypeCheckHelper + : TypeCheckHelper { + assert_is_valid_input_type check; +}; + +template +struct assert_is_valid_input_type< + std::tuple, + AllowDeprecatedTypes> + : TypeCheckHelper {}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + guts::typelist::contains::value, + "You tried to register a kernel with an unsupported input type: Dict where Key is invalid. We only support int64_t, double, bool, and string."); +}; + +template +struct assert_is_valid_input_type< + std::unordered_map, + AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + AllowDeprecatedTypes, + "You tried to register a kernel with an unsupported input type: std::unordered_map. Please use Dict instead."); + static_assert( + guts::typelist::contains::value, + "You tried to register a kernel with an unsupported input type: std::unordered_map where Key is invalid. We only support int64_t, double, bool, and string."); +}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported input type: List. Please use List, List or Tensor instead."); +}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported input type: ArrayRef. Please use List, List or Tensor instead."); +}; + +template +struct assert_is_valid_input_type< + c10::OptionalArrayRef, + AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported input type: OptionalArrayRef. 
Please use List, List or Tensor instead."); +}; + +template +struct assert_is_valid_input_type, AllowDeprecatedTypes> + : assert_is_valid_input_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported input type: std::array. Please use std::array instead."); +}; + +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + // There is no reason to support float when we have double. Keep the API lean. + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported input type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string."); +}; +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported input type: const char*. Please use std::string_view instead."); +}; +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t, T>>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported input type: vector. Please use List instead."); +}; +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t< + std::is_integral_v && + !guts::typelist::contains::value>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string."); +}; +template +struct assert_is_valid_input_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel taking c10::SymInt by reference. 
Please accept it by value instead."); +}; + +// TODO: it probably would be good to tighten this up quite a bit more with +// an explicit list for everything + +// +// assert_is_valid_output_type +// + +template +struct assert_is_valid_output_type { + assert_is_valid_output_type() { + if constexpr (guts::typelist::contains:: + value) { + /* everything is ok, this is a primitive type */ + } else { + /* otherwise T is verified to be a registered custom class in the IValue + constructor, so no benefit in double-checking here */ + } + } +}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type {}; + +template +struct assert_is_valid_output_type< + c10::OptionalArrayRef, + AllowDeprecatedTypes> + : assert_is_valid_output_type {}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + guts::typelist::contains::value, + "You tried to register a kernel with an unsupported output type: Dict where Key is invalid. We only support int64_t, double, bool, and string."); + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: Dict. Please use Dict or Dict."); +}; + +template +struct assert_is_valid_output_type< + std::unordered_map, + AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + AllowDeprecatedTypes, + "You tried to register a kernel with an unsupported output type: std::unordered_map. Please use Dict instead."); + static_assert( + guts::typelist::contains::value, + "You tried to register a kernel with an unsupported output type: std::unordered_map where Key is invalid. We only support int64_t, double, bool, and string."); + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: std::unordered_map. 
Please use Dict or Dict."); +}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: List. Please use List, List or Tensor instead."); +}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: std::vector. Please use List, List or Tensor instead."); + // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel + // with an unsupported output type: std::vector. Please use List + // instead."); +}; + +template +struct assert_is_valid_output_type, AllowDeprecatedTypes> + : assert_is_valid_output_type { + static_assert( + !std::is_same_v, + "You tried to register a kernel with an unsupported output type: std::array. Please use std::array instead."); +}; + +// The following specialisations of assert_is_valid_output_type are technically +// not necessary since we would hit the base case and show an error message +// there if they didn't exist, but we can show a better error message +// in some common error scenarios. +template +struct assert_is_valid_output_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + // There is no reason to support float when we have double. Keep the API lean. + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported output type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string."); +}; +template +struct assert_is_valid_output_type< + T, + AllowDeprecatedTypes, + std::enable_if_t>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported output type: const char*. 
Please use std::string_view instead."); +}; +template +struct assert_is_valid_output_type< + T, + AllowDeprecatedTypes, + std::enable_if_t, T>>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported output type: vector. Please use List instead."); +}; +template +struct assert_is_valid_output_type< + T, + AllowDeprecatedTypes, + std::enable_if_t< + std::is_integral_v && + !guts::typelist::contains::value>> { + static_assert( + guts::false_t::value, + "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string."); +}; + +// ivalue_to_arg + +template +struct decay_if_not_tensor final { + using type = std::decay_t; +}; + +template <> +struct decay_if_not_tensor final { + using type = at::Tensor&; +}; + +template <> +struct decay_if_not_tensor final { + using type = const at::Tensor&; +}; + +template +struct ivalue_to_arg final { + static decltype(auto) call(IValue& v) { + assert_is_valid_input_type(); + return std::move(v).to(); + } +}; + +// The following two specializations take advantage of specialized +// `toTensor()` overloads on IValue to avoid copying. +template +struct ivalue_to_arg final { + // We cannot use the default implementation if they asked for a + // `at::Tensor&` because it moves from the IValue, so it can't get + // an lvalue reference. + static at::Tensor& call(IValue& v) { + // Tensor& is valid, don't bother asserting + return v.toTensor(); + } +}; + +template +struct ivalue_to_arg final { + // We should not use the default implementation if they asked for + // a `const at::Tensor&` because it moves from the IValue and they + // didn't ask for that. 
+ static const at::Tensor& call(IValue& v) { + // const Tensor& is valid, don't bother asserting + return v.toTensor(); + } +}; + +template +struct ivalue_to_arg final { + static List call(IValue& v) { + return v.toTensorList(); + } +}; + +template +struct ivalue_to_arg, AllowDeprecatedTypes> final { + // If an argument is ArrayRef, convert the IValue to a std::vector and + // pass that to the operator. std::vector is implicitly convertible to + // ArrayRef. + static std::vector call(IValue& v) { + return ivalue_to_arg, AllowDeprecatedTypes>::call(v); + } +}; +template +struct ivalue_to_arg final { + static std::vector call(IValue& v) { + if (v.isIntList()) { + std::vector r; + auto src = v.toIntList(); + std::transform( + src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { + return c10::SymInt(i); + }); + return r; + } else { + return ivalue_to_arg, AllowDeprecatedTypes>:: + call(v); + } + } +}; +template +struct ivalue_to_arg, AllowDeprecatedTypes> + final { + static OptionalArray call(IValue& v) { + if (v.isIntList()) { + std::vector r; + auto src = v.toIntList(); + std::transform( + src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { + return c10::SymInt(i); + }); + return OptionalArray(std::move(r)); + } else { + return std::move(v).to>(); + } + } +}; +template +struct ivalue_to_arg>, AllowDeprecatedTypes> final { + // If an argument is std::optional>, convert the IValue to an + // std::optional> and pass that to the operator. + // OptionalArray is basically a std::optional> but + // implicitly convertible to std::optional>. + static OptionalArray call(IValue& v) { + return ivalue_to_arg, AllowDeprecatedTypes>::call(v); + } +}; + +template +struct ivalue_to_arg, AllowDeprecatedTypes> final { + // If an argument is OptionalArrayRef, convert the IValue to an + // std::optional> and pass that to the operator. 
+ // OptionalArray is basically a std::optional> but + // implicitly convertible to OptionalArrayRef + static OptionalArray call(IValue& v) { + return ivalue_to_arg, AllowDeprecatedTypes>::call(v); + } +}; + +// return_to_ivalue +template +struct return_to_ivalue final {}; + +template +struct return_to_ivalue< + T, + AllowDeprecatedTypes, + std::enable_if_t>> + final { + static IValue call(T&& v) { + assert_is_valid_output_type(); + return c10::ivalue::from(std::move(v)); + } + static IValue copy(const T& v) { + assert_is_valid_output_type(); + return IValue(v); + } +}; + +// Special case to allow kernels to return `Tensor&`. +// TODO Delete this once kernels don't do that anymore +template +struct return_to_ivalue final { + static IValue call(at::Tensor& v) { + return c10::ivalue::from(v); + } + static IValue copy(at::Tensor& v) { + return IValue(v); + } +}; + +// wrap_kernel_functor_unboxed_ + +template +struct wrap_kernel_functor_unboxed_ final {}; + +// This specialization is for kernels with a first argument that is NOT of type +// DispatchKeySet This includes kernels with 0 arguments. +template +struct wrap_kernel_functor_unboxed_< + KernelFunctor, + ReturnType(ParameterTypes...)> + final { + static_assert( + std::is_same_v< + ReturnType, + typename guts::infer_function_traits_t::return_type>, + "Return type mismatch"); + static_assert( + std::is_same_v< + guts::typelist::typelist, + typename guts::infer_function_traits_t< + KernelFunctor>::parameter_types>, + "Parameter types mismatch"); + + // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes + // doesn't use && + static ReturnType call( + OperatorKernel* functor, + DispatchKeySet, + ParameterTypes... args) { + KernelFunctor* functor_ = static_cast(functor); + // Note [Plumbing Keys Through The Dispatcher 2] + // See Note [Plumbing Keys Through The Dispatcher] for the background. 
+ // This functor explicitly takes in a dispatchKeySet and drops it on the + // floor- it does not forward it to the registered kernel. + // + // This is due to the calling convention within the dispatcher, which + // expects all registered kernels to have a first argument of type + // DispatchKeySet. + // This is not the case for pretty much all manually written kernels, + // however- this functor serves to separate the calling convention of the + // dispatcher from the calling convention of manually written kernels. + return (*functor_)(std::forward(args)...); + } +}; + +// This specialization is for kernels with a first argument of type +// DispatchKeySet +template +struct wrap_kernel_functor_unboxed_< + KernelFunctor, + ReturnType(DispatchKeySet, ParameterTypes...)> + final { + static_assert( + std::is_same_v< + ReturnType, + typename guts::infer_function_traits_t::return_type>, + "Return type mismatch"); + static_assert( + std::is_same_v< + guts::typelist::typelist, + typename guts::infer_function_traits_t< + KernelFunctor>::parameter_types>, + "Parameter types mismatch"); + + // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes + // doesn't use && + static ReturnType call( + OperatorKernel* functor, + DispatchKeySet dispatchKeySet, + ParameterTypes... args) { + KernelFunctor* functor_ = static_cast(functor); + // We're explicitly taking in a dispatchKeySet and forwarding it to the + // registered kernel. See Note [Plumbing Keys Through The Dispatcher 2] for + // details. + return (*functor_)(dispatchKeySet, std::forward(args)...); + } +}; + +template +using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_< + KernelFunctor, + typename guts::infer_function_traits_t::func_type>; + +// call_functor_with_args_from_stack + +template < + class Functor, + bool AllowDeprecatedTypes, + size_t... ivalue_arg_indices, + typename... 
ArgTypes> +std::decay_t::return_type> +call_functor_with_args_from_stack_( + OperatorKernel* functor, + DispatchKeySet dispatchKeySet, + Stack* stack, + std::index_sequence, + guts::typelist::typelist*) { + (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would + // be unused and we have to silence the compiler warning. + + // We're explicitly filtering out DispatchKeySet from the argument list. + // Some kernels take a DispatchKeySet as their first argument in order to + // plumb keys through the dispatcher. We don't want to expose the + // DispatchKeySet type to jit, so we don't include this argument on the stack. + // See Note [Plumbing Keys Through The Dispatcher] for the background. + return wrap_kernel_functor_unboxed::call( + functor, + dispatchKeySet, + ivalue_to_arg< + typename decay_if_not_tensor::type, + AllowDeprecatedTypes>:: + call(torch::jit::peek( + *stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)))...); +} + +template +std::decay_t::return_type> +call_functor_with_args_from_stack( + OperatorKernel* functor, + DispatchKeySet dispatchKeySet, + Stack* stack) { + // We're explicitly filtering out DispatchKeySet from the argument list. + // Some kernels take a DispatchKeySet as their first argument in order to + // plumb keys through the dispatcher. We don't want to expose the + // DispatchKeySet type to jit, so we don't include this argument on the stack. + // See Note [Plumbing Keys Through The Dispatcher] for the background. 
+ using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func< + Functor>::parameter_types; + constexpr size_t num_ivalue_args = guts::typelist::size::value; + return call_functor_with_args_from_stack_( + functor, + dispatchKeySet, + stack, + std::make_index_sequence(), + static_cast(nullptr)); +} + +// push_outputs + +template +struct push_outputs final { + // Contrary to [Note: Argument forwarding in the dispatcher], we use + // OutputType&& here to avoid one extra call to the move constructor in this + // case. This is still not a universal reference though because OutputType is + // an explicitly specified class template parameter. + static void call(OutputType&& output, Stack* stack) { + torch::jit::push( + *stack, + return_to_ivalue::call( + std::forward(output))); + } + static void copy(const OutputType& output, Stack* stack) { + torch::jit::push( + *stack, + return_to_ivalue::copy(output)); + } +}; +template +struct push_outputs, AllowDeprecatedTypes> final { + static void call(std::tuple&& output, Stack* stack) { + call_( + std::move(output), + stack, + std::make_index_sequence()); + } + static void copy(const std::tuple& output, Stack* stack) { + copy_(output, stack, std::make_index_sequence()); + } + + private: + template + static void call_( + std::tuple&& output, + Stack* stack, + std::index_sequence) { + torch::jit::push( + *stack, + return_to_ivalue::call( + std::forward(std::get(output)))...); + } + template + static void copy_( + const std::tuple& output, + Stack* stack, + std::index_sequence) { + torch::jit::push( + *stack, + return_to_ivalue::copy( + std::get(output))...); + } +}; +template +struct push_outputs final { + static void call(int /*dummy*/, Stack* /*stack*/) {} + static void copy(int /*dummy*/, Stack* /*stack*/) {} +}; + +// make_boxed_from_unboxed_functor + +template +struct make_boxed_from_unboxed_functor final { + static_assert( + std::is_base_of_v, + "Tried to register a kernel functor using the kernel() API, but it doesn't 
inherit from c10::OperatorKernel. Please have the functor inherit from it."); + + static void call( + OperatorKernel* functor, + const OperatorHandle&, + DispatchKeySet dispatchKeySet, + Stack* stack) { + using ReturnType = + typename guts::infer_function_traits_t::return_type; + // We're explicitly filtering out DispatchKeySet from the argument list. + // Some kernels take a DispatchKeySet as their first argument in order to + // plumb keys through the dispatcher. We don't want to expose the + // DispatchKeySet type to jit, so we don't include this argument on the + // stack. See Note [Plumbing Keys Through The Dispatcher] for the + // background. + using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func< + KernelFunctor>::parameter_types; + constexpr bool has_outputs = !std::is_same_v; + constexpr size_t num_inputs = guts::typelist::size::value; + if constexpr (has_outputs) { + // Decay ReturnType to ReturnType_ so that if a reference gets returned, + // we actually store it by value and don't get a dangling reference. This + // is only required because some kernels still return `Tensor&`. 
[Note: + // VC++ and 'std': ambiguous symbol] + using ReturnType_ = ::std::decay_t; + ReturnType_ output = call_functor_with_args_from_stack< + KernelFunctor, + AllowDeprecatedTypes>(functor, dispatchKeySet, stack); + torch::jit::drop(*stack, num_inputs); + // See note [ VC++ and 'std': ambiguous symbol] + push_outputs::call( + ::std::move(output), stack); + } else { + call_functor_with_args_from_stack( + functor, dispatchKeySet, stack); + torch::jit::drop(*stack, num_inputs); + } + } +}; +} // namespace impl + +} // namespace c10 + +namespace torch { +using OperatorKernel = c10::OperatorKernel; +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..2c389ac4c75b5c54bf7b2259ac4db849d82f426a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h @@ -0,0 +1,140 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +template +inline std::vector makeStack(Inputs&&... inputs) { + return {std::forward(inputs)...}; +} + +inline at::Tensor dummyTensor( + c10::DispatchKeySet ks, + bool requires_grad = false) { + auto* allocator = c10::GetCPUAllocator(); + int64_t nelements = 1; + auto dtype = caffe2::TypeMeta::Make(); + int64_t size_bytes = nelements * dtype.itemsize(); + auto storage_impl = c10::make_intrusive( + c10::StorageImpl::use_byte_size_t(), + size_bytes, + allocator->allocate(size_bytes), + allocator, + /*resizable=*/true); + at::Tensor t = + at::detail::make_tensor(storage_impl, ks, dtype); + // TODO: We add this to simulate the ideal case where we only have Autograd + // backend keys + // on Tensor when it requires grad. But currently Autograd keys are + // added in TensorImpl constructor by default. 
+ if (!requires_grad) { + t.unsafeGetTensorImpl()->remove_autograd_key(); + } + return t; +} + +inline at::Tensor dummyTensor( + c10::DispatchKey dispatch_key, + bool requires_grad = false) { + return dummyTensor(c10::DispatchKeySet(dispatch_key), requires_grad); +} + +template +inline std::vector callOp( + const c10::OperatorHandle& op, + Args... args) { + auto stack = makeStack(std::forward(args)...); + op.callBoxed(&stack); + return stack; +} + +template +inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) { + return op.typed().call(std::forward(args)...); +} + +template +inline Result callOpUnboxedWithDispatchKey( + const c10::OperatorHandle& op, + c10::DispatchKey dispatchKey, + Args... args) { + return op.typed().callWithDispatchKey( + dispatchKey, std::forward(args)...); +} + +template +inline Result callOpUnboxedWithPrecomputedDispatchKeySet( + const c10::OperatorHandle& op, + c10::DispatchKeySet ks, + Args... args) { + return op.typed().redispatch( + ks, std::forward(args)...); +} + +inline void expectDoesntFindKernel( + const char* op_name, + c10::DispatchKey dispatch_key) { + auto op = c10::Dispatcher::singleton().findSchema({op_name, ""}); + EXPECT_ANY_THROW(callOp(*op, dummyTensor(dispatch_key), 5);); +} + +inline void expectDoesntFindOperator(const char* op_name) { + auto op = c10::Dispatcher::singleton().findSchema({op_name, ""}); + EXPECT_FALSE(op.has_value()); +} + +template +inline void expectThrows(Functor&& functor, const char* expectMessageContains) { + try { + std::forward(functor)(); + } catch (const Exception& e) { + EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains)); + return; + } + ADD_FAILURE() << "Expected to throw exception containing \"" + << expectMessageContains << "\" but didn't throw"; +} + +template +void expectListEquals(c10::ArrayRef expected, std::array actual) { + EXPECT_EQ(expected.size(), actual.size()); + for (const auto i : c10::irange(expected.size())) { + EXPECT_EQ(expected[i], 
actual[i]); + } +} + +template +void expectListEquals(c10::ArrayRef expected, c10::ArrayRef actual) { + EXPECT_EQ(expected.size(), actual.size()); + for (const auto i : c10::irange(expected.size())) { + EXPECT_EQ(expected[i], actual[i]); + } +} + +template +void expectListEquals(c10::ArrayRef expected, c10::List actual) { + EXPECT_EQ(expected.size(), actual.size()); + for (const auto i : c10::irange(expected.size())) { + EXPECT_EQ(expected[i], actual.get(i)); + } +} + +template +void expectListEquals(c10::ArrayRef expected, std::vector actual) { + EXPECT_EQ(expected.size(), actual.size()); + for (const auto i : c10::irange(expected.size())) { + EXPECT_EQ(expected[i], actual[i]); + } +} + +// NB: This is not really sound, but all of the type sets constructed here +// are singletons so it's fine +static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) { + return legacyExtractDispatchKey(t.key_set()); +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/builtin_function.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/builtin_function.h new file mode 100644 index 0000000000000000000000000000000000000000..5ab1ace1685f8b73f05ff3b647a10037b01d2644 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/builtin_function.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch::jit { + +struct BuiltinOpFunction : public Function { + BuiltinOpFunction( + c10::QualifiedName qualname, + c10::FunctionSchema schema, + std::function callable, + std::string doc_string = "") + : name_(std::move(qualname)), + callable_(std::move(callable)), + schema_(std::move(schema)), + doc_string_(std::move(doc_string)) { + TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1); + } + + std::string_view doc_string() const override { + return doc_string_; + } + + void run(Stack& stack) override { + callable_(stack); + } + + c10::intrusive_ptr runAsync( + Stack& stack, + 
TaskLauncher /* not used */) override { + run(stack); + auto res = c10::make_intrusive(stack.front().type()); + res->markCompleted(std::move(stack.front())); + return res; + } + + const c10::QualifiedName& qualname() const override { + return name_; + } + + // if this isn't yet defined, run its method_creator function + void ensure_defined() override { + // nop + } + + const c10::FunctionSchema& getSchema() const override { + return schema_; + } + + size_t num_inputs() const override { + return schema_.arguments().size(); + } + + Function& setSchema(c10::FunctionSchema schema) override { + schema_ = std::move(schema); + return *this; + } + + bool call( + Stack& stack, + std::optional, + c10::function_ref) override { + run(stack); + return false; + } + + bool call(Stack& stack, c10::function_ref) + override { + run(stack); + return false; + } + + ~BuiltinOpFunction() override = default; + + private: + c10::QualifiedName name_; + + std::function callable_; + + c10::FunctionSchema schema_; + + std::string doc_string_; +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/class_type.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/class_type.h new file mode 100644 index 0000000000000000000000000000000000000000..ea124fc6eb079558c4682fcb62726a3ab8d019ba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/class_type.h @@ -0,0 +1,441 @@ +#pragma once + +#include + +#include +#include +#include + + +namespace torch::jit { +struct CompilationUnit; +struct Function; +} // namespace torch::jit + + +namespace c10 { + +struct FunctionSchema; + +// This enumerator represents the 'kind' of an attribute - a buffer, a parameter, or neither. +// This state is mutually exclusive. Buffers and Parameters can only appear on modules. 
+enum class AttributeKind { + BUFFER, + PARAMETER, + REGULAR_ATTRIBUTE +}; + +// This structure represents all notional booking entities in a class attribute: name, kind (see: AttributeKind), and type (see: TypePtr). +// Note: This structure does not represent the value of the attribute. +struct TORCH_API ClassAttribute { + public: + ClassAttribute(AttributeKind kind, + TypePtr attributeType, + std::string attributeName) : + kind_(kind), + attributeType_(std::move(attributeType)), + attributeName_(std::move(attributeName)) {} + + AttributeKind getKind() const { + return kind_; + } + + const TypePtr& getType() const { + return attributeType_; + } + + const std::string& getName() const { + return attributeName_; + } + + private: + AttributeKind kind_; + TypePtr attributeType_; + std::string attributeName_; +}; + +/** + * User Defined Types + */ + +struct ClassType; +using ClassTypePtr = std::shared_ptr; +using ::torch::jit::CompilationUnit; + +// This represents a class in TorchScript. +struct TORCH_API ClassType : public NamedType { + // This represents an attribute of a class; a name associated with an attribute, and a + // getter and (optional) setter for that attribute. + struct Property { + std::string name; + torch::jit::Function* getter; + torch::jit::Function* setter; + }; + + // Create a class type with name `name` and its methods stored in `cu`. 
+ static ClassTypePtr create( + std::optional qualifiedName, + std::weak_ptr cu, + bool is_module = false, + std::string doc_string = "", + std::vector unresolved_class_attributes = {}); + + bool equals(const Type& rhs) const override { + if (this == &rhs) { + return true; + } + if (auto user_rhs = rhs.castRaw()) { + const auto& lhs_name = name(); + const auto& rhs_name = user_rhs->name(); + return lhs_name.has_value() && lhs_name == rhs_name && + this->compilation_unit() == user_rhs->compilation_unit(); + } + return false; + } + + std::string str() const override { + return annotation_str(); + } + + std::string repr_str() const override { + std::stringstream ss; + ss << str() + << " (of Python compilation unit at: " << compilation_unit().get() << ")"; + return ss.str(); + } + + const std::vector& methods() const; + + TypePtr findAttribute(const std::string& name) const { + size_t pos = 0; + for (const auto& attr : attributes_) { + if (name == attr.getName()) { + break; + } + ++pos; + } + + if (pos >= attributes_.size()) { + return nullptr; + } + return attributes_[pos].getType(); + } + + const TypePtr& getAttribute(const std::string& name) const { + auto slot = findAttributeSlot(name); + TORCH_CHECK( + slot, + repr_str(), + " does not have an attribute with name '", + name, + "'"); + return attributes_[*slot].getType(); + } + + size_t numAttributes() const { + return attributes_.size(); + } + + const TypePtr& getAttribute(size_t slot) const { + AT_ASSERT(slot < attributes_.size()); + return attributes_.at(slot).getType(); + } + + const std::string getAttributeName(size_t slot) const { + AT_ASSERT(slot < attributes_.size()); + return attributes_[slot].getName(); + } + + void checkNotExist(const std::string& name, const std::string& what) const; + + // Attributes are stored in a specific slot at runtime for effiency. 
+ // When emitting instructions we specify the slot so that attribute access is + // a constant lookup + std::optional findAttributeSlot(const std::string& name) const { + size_t slot = 0; + for (const auto& attr : attributes_) { + if (name == attr.getName()) { + return slot; + } + slot++; + } + return std::nullopt; + } + size_t getAttributeSlot(const std::string& name) const { + if (auto r = findAttributeSlot(name)) { + return *r; + } + TORCH_CHECK( + false, + repr_str(), + " does not have an attribute with name '", + name, + "'"); + } + + bool hasAttribute(const std::string& name) const { + return std::find_if( + attributes_.cbegin(), + attributes_.cend(), + [&](const ClassAttribute& attr) { return attr.getName() == name; }) != + attributes_.cend(); + } + + bool isUnresolvedClassAttribute(const std::string& name) const; + + at::ArrayRef containedTypes() const override { + return attributeTypes_; + } + + size_t addAttribute( + const std::string& name, + TypePtr type, + bool is_parameter = false, + bool is_buffer = false); + + // [Internal Only] Remove attribute from the ClassType, + // caller is responsible to make sure the modification is safe: + // it is unsafe to having existing allocations + // of this object around anymore, and any code that works on + // the attribute is now invalid. Only newly created code is + // valid again. + void unsafeRemoveAttribute(const std::string& name); + + // [Internal Only] Change the type of an attribute of the ClassType, + // The caller is responsible to make sure the modification is safe: + // it is unsafe to maintain uses of the old type of the attribute, + // and any code that works on the attribute is now invalid. + // Only newly created code is valid again. + void unsafeChangeAttributeType(const std::string& name, const TypePtr& new_ty); + + // Add attribute \p NAME if it doesn't exist or verify that it has a + // compatible type otherwise. 
+ size_t addOrCheckAttribute( + const std::string& name, + TypePtr ty, + bool is_parameter = false, + bool is_buffer = false) { + auto slot_idx = findAttributeSlot(name); + if (!slot_idx) { + return addAttribute(name, std::move(ty), is_parameter, is_buffer); + } + + TORCH_CHECK( + is_parameter == this->is_parameter(*slot_idx), + "Parameter field mismatch for the field '", + name, + "'"); + const TypePtr& atype = getAttribute(*slot_idx); + TORCH_CHECK( + ty->isSubtypeOf(*atype), + ty->repr_str(), + " is not compatible with the type ", + atype->repr_str(), + " for the field '", + name, + "'"); + return *slot_idx; + } + + // Get the property with the given \p name, if it exists on the class. + std::optional getProperty(const std::string& name); + // Add a property named \p name with \p getter and \p setter as its getter and setter. + void addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter); + // Get a list of all properties. + const std::vector& properties() const { + return properties_; + } + + bool hasConstant(const std::string& name) const { + return std::find_if( + constantNames_.cbegin(), + constantNames_.cend(), + [&](const std::string& constant) { return constant == name; }) != + constantNames_.cend(); + } + + size_t addConstant(const std::string& name, const IValue& value); + + std::optional findConstantSlot(const std::string& name) const; + + size_t getConstantSlot(const std::string& name) const { + if (auto r = findConstantSlot(name)) { + return *r; + } + TORCH_CHECK( + false, + repr_str(), + " does not have constant field with the name '", + name, + "'"); + } + + const std::string& getConstantName(size_t slot) const; + + const std::string& doc_string() const { + return doc_string_; + } + + IValue getConstant(const std::string& name) const; + + IValue getConstant(size_t slot) const; + + std::optional findConstant(const std::string& name) const; + + size_t numConstants() const; + + at::ArrayRef constantNames() 
const { + return constantNames_; + } + + at::ArrayRef constantValues() const; + + // [Internal Only] Remove constant from the ClassType + // caller is responsible to make sure the modification is safe: + // it is unsafe to having existing allocations + // of this object around anymore, and any code that works on + // the attribute is now invalid. Only newly created code is + // valid again. + void unsafeRemoveConstant(const std::string& name); + + TypePtr createWithContained(std::vector contained_types) const override { + auto ptr = ClassType::create(name(), compilation_unit_, is_module()); + AT_ASSERT(numAttributes() == contained_types.size()); + for(size_t i = 0; i < attributes_.size(); ++i) { + AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i])); + ptr->addAttribute(attributes_[i].getName(), std::move(contained_types[i])); + } + // Copy methods over + for (const auto& method : methods()) { + ptr->addMethod(method); + } + return ptr; + } + + bool is_module() const override { + return isModule_; + } + + const std::vector& getAttributes() const { + return attributes_; + } + + bool is_parameter(size_t slot) const { + TORCH_INTERNAL_ASSERT( + is_module(), "asking for parameterSlots of non-Module"); + return attributes_.at(slot).getKind() == AttributeKind::PARAMETER; + } + + bool is_buffer(size_t slot) const { + TORCH_INTERNAL_ASSERT( + is_module(), "asking for bufferWrittenSlots of non-Module"); + return attributes_.at(slot).getKind() == AttributeKind::BUFFER; + } + + void addForwardPreHook(torch::jit::Function* pre_hook_ptr); + void addForwardHook(torch::jit::Function* hook_ptr); + torch::jit::Function* findForwardPreHook(const std::string& name) const; + torch::jit::Function* findForwardHook(const std::string& name) const; + const std::vector& getForwardHooks() const; + const std::vector& getForwardPreHooks() const; + + void checkForwardPreHookSchema( + size_t pre_hook_idx, + const FunctionSchema& pre_hook_schema) const; + void 
checkForwardHookSchema( + size_t hook_idx, + const FunctionSchema& hook_schema) const; + + void addMethod(torch::jit::Function* method); + torch::jit::Function* findMethod(const std::string& name) const; + torch::jit::Function& getMethod(const std::string& name) const; + torch::jit::Function* findHook(const std::string& name) const; + torch::jit::Function& getHook(const std::string& name) const; + bool hasMethod(const std::string& name) const; + + torch::jit::Function* findStaticMethod(const std::string& name) const; + void addStaticMethod(torch::jit::Function* method); + + // [Internal Only] Remove method from the ClassType + // caller is responsible to make sure the modification is safe: + // it is unsafe to having existing allocations + // of this object around anymore, and any code that works on + // the attribute is now invalid. Only newly created code is + // valid again. + // Note this method is intended for freezing only. + void unsafeRemoveMethod(const std::string& name); + + std::shared_ptr compilation_unit(); + + std::shared_ptr compilation_unit() const; + + // generate a refined version of this class. + // It has the same name but the slot Types are subtypes of + // the original slots. It is only valid to refine a class type in a context + // where it is know that there are not assignments to the objects slots + // that would invalidate the refinement. + // These variants are not registered in the global class table. 
+ ClassTypePtr refine(at::ArrayRef refined_slots) const; + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + static const TypeKind Kind = TypeKind::ClassType; + + private: + ClassType( + std::optional name, + std::weak_ptr cu, + bool is_module = false, + std::string doc_string = "", + std::vector unresolved_class_attributes = {}); + + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return name()->qualifiedName(); + } + + void addAttribute(ClassAttribute classAttribute); + std::string getForwardPreHookErrorMessage(size_t pre_hook_idx) const; + std::string getForwardHookErrorMessage(size_t hook_idx) const; + + // Mapping of attribute names -> their type. + // NOTE: this does not contain methods, which are stored in the module + // TODO: once modules support arbitrary ivalue attributes, we don't need this + // anymore. + // TODO: This is better represented as an OrderedDict, but alas it is not yet + // available from c10 + + // Mapping of constant names -> their value. + std::vector constantNames_; + std::vector constantValues_; + // Holds method attributes + std::weak_ptr compilation_unit_; + + // Holds all atrributes, attribute details are found on ClassAttribute + std::vector attributes_; + // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef. + // Never fill this without using the appropriate provideNewClassAttribute method + std::vector attributeTypes_; + + // List of methods associated with this class. + std::vector methods_; + std::vector staticmethods_; + + // List of hooks to be run before/after forward. + std::vector forward_hooks_; + std::vector forward_pre_hooks_; + + // List of properties exposed by this class. + std::vector properties_; + + bool isModule_ = false; + + // Doc string of class. 
+ std::string doc_string_; + + // For error reporting accesses to class level attributes. + std::vector unresolved_class_attributes_; +}; + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/custom_class.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/custom_class.h new file mode 100644 index 0000000000000000000000000000000000000000..ff9bda981b2906e55449e93a582266888c2eb258 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/custom_class.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace c10 { + +struct ClassType; +using ClassTypePtr = std::shared_ptr; + +TORCH_API c10::ClassTypePtr getCustomClassTypeImpl(const std::type_index &tindex); + +template +const c10::ClassTypePtr& getCustomClassType() { + // Classes are never unregistered from getCustomClassTypeMap and the + // hash lookup can be a hot path, so just cache. + // For the same reason, it's fine If this ends up getting duplicated across + // DSO boundaries for whatever reason. + static c10::ClassTypePtr cache = getCustomClassTypeImpl( + std::type_index(typeid(T))); + return cache; +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/CppSignature.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/CppSignature.h new file mode 100644 index 0000000000000000000000000000000000000000..e7695aa5c21f4098b2f7c2971d7a20b03b51fa73 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/CppSignature.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10::impl { + +// A CppSignature object holds RTTI information about a C++ function signature +// at runtime and can compare them or get a debug-printable name. 
+class TORCH_API CppSignature final { + public: + CppSignature(const CppSignature&) = default; + CppSignature(CppSignature&&) noexcept = default; + CppSignature& operator=(const CppSignature&) = default; + CppSignature& operator=(CppSignature&&) noexcept = default; + + template + static CppSignature make() { + // Normalize functors, lambdas, function pointers, etc. into the plain + // function type The first argument of the schema might be of type + // DispatchKeySet, in which case we remove it. We do this to guarantee that + // all CppSignature's for an operator will match, even if they're registered + // with different calling conventions. + // See Note [Plumbing Keys Through The Dispatcher] + using decayed_function_type = + typename c10::remove_DispatchKeySet_arg_from_func< + std::decay_t>::func_type; + + return CppSignature(std::type_index(typeid(decayed_function_type))); + } + + std::string name() const { + return c10::demangle(signature_.name()); + } + + friend bool operator==(const CppSignature& lhs, const CppSignature& rhs) { + if (lhs.signature_ == rhs.signature_) { + return true; + } + // Without RTLD_GLOBAL, the type_index comparison could yield false because + // they point to different instances of the RTTI data, but the types would + // still be the same. Let's check for that case too. + // Note that there still is a case where this might not work, i.e. when + // linking libraries of different compilers together, they might have + // different ways to serialize a type name. That, together with a missing + // RTLD_GLOBAL, would still fail this. 
+ if (0 == strcmp(lhs.signature_.name(), rhs.signature_.name())) { + return true; + } + + return false; + } + + private: + explicit CppSignature(std::type_index signature) + : signature_(std::move(signature)) {} + std::type_index signature_; +}; + +inline bool operator!=(const CppSignature& lhs, const CppSignature& rhs) { + return !(lhs == rhs); +} + +} // namespace c10::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h new file mode 100644 index 0000000000000000000000000000000000000000..ecc4bc7b5d893ccd38e26165255f47eb59988556 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h @@ -0,0 +1,279 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +namespace impl { + +// Take a DispatchKeySet for a Tensor and determine what the actual dispatch +// DispatchKey should be, taking into account TLS, and skipping backends which +// fall through. +// +// Unlike Tensor::key_set(), the value of this on a tensor can change depending +// on TLS. +// +// NB: If there is no valid dispatch key, this will return Undefined +inline DispatchKeySet computeDispatchKeySet( + DispatchKeySet ks, + // The key mask lets us eliminate (by zero entries) keys which should not + // be considered for dispatch. There are two cases when we use this: + // + // - If an operator's dispatch table contains a fallthrough entry, we + // should bypass it entirely when finding the key + // - If a user invokes with redispatch, the mask lets us + // zero out the key the user asked us to stop. 
+ // + // These excluded backends are NOT tracked in the TLS, but must be applied + // AFTER TLS (since the backend may have been introduced for consideration + // by the included TLS), which is why you have to pass them in to this + // function (as opposed to just applying it to the input 'ks'). + DispatchKeySet key_mask) { + c10::impl::LocalDispatchKeySet local = + c10::impl::tls_local_dispatch_key_set(); + // TODO: It's a bit irritating that we have to do logical ORs here, it would + // be nice to only do one. Can always_included be folded into the TLS? Well, + // it's a bit troublesome, because fastpath TLS access requires the type of + // the TLS in question to be zero-initialized, so you don't actually win + // anything in that case. + return (((ks | local.included_) - local.excluded_) & key_mask); +} + +} // namespace impl + +namespace detail { +// A small gadget to extract the DispatchKeySet from types which are known +// to have it. Used to extract dispatch keys from unboxed calls. +struct MultiDispatchKeySet : at::IterArgs { + DispatchKeySet ts; + void operator()(const at::Tensor& x) { + ts = ts | x.key_set(); + } + void operator()(const std::optional& x) { + if (x.has_value()) { + ts = ts | x->key_set(); + } + } + void operator()(at::ArrayRef xs) { + for (const auto& x : xs) { + ts = ts | x.key_set(); + } + } + // Tensor?[] translates to this case. + void operator()(const c10::List>& xs) { + for (std::optional x : xs) { + if (x.has_value()) { + ts = ts | x.value().key_set(); + } + } + } + // Structured Tensor[] translates to this case + void operator()(const at::ITensorListRef& xs) { + for (const auto& x : xs) { + ts = ts | x.key_set(); + } + } + [[noreturn]] void operator()(at::ArrayRef>) { + // Just checking that the handling of Tensor?[] didn't change. 
+ TORCH_INTERNAL_ASSERT(false); + } + void operator()(const at::Generator& gen) { + if (gen.defined()) { + ts = ts | gen.key_set(); + } + } + void operator()(const std::optional& gen) { + if (gen.has_value() && gen->defined()) { + ts = ts | gen->key_set(); + } + } + template + void operator()(const T&) { + // do nothing + } +}; + +// NB: take by const reference (Don't do universal forwarding here! You +// don't want to move into this function!) +template +DispatchKeySet multi_dispatch_key_set(const Args&... args) { + return MultiDispatchKeySet().apply(args...).ts; +} +} // namespace detail + +/** + * An instance of DispatchKeyExtractor knows how to get a dispatch key given + * a list of arguments for an operator call. + * + * The instance is specific for a certain operator as: + * - In boxed dispatch, different operators have different ways to extract + * the dispatch key (e.g. different numbers of arguments), and we precompute + * the stack locations we should look at; and + * - In all dispatch, some backends should be excluded from dispatch because + * they have been registered as fallthrough. The set of excluded backends + * varies from operator, as some operators may have overridden the + * fallthrough with custom behavior. 
+ * + * Note - this should maintain identical impl to the py dispatcher key + * extraction logic at pytorch/torch/dispatcher.py + */ +struct TORCH_API DispatchKeyExtractor final { + public: + static DispatchKeyExtractor make(const FunctionSchema& schema) { + return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema)); + } + + static DispatchKeyExtractor makeUninitialized() { + return DispatchKeyExtractor(c10::utils::bitset()); + } + + void registerSchema(const FunctionSchema& schema) { + TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset()); + dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema); + } + void deregisterSchema() { + dispatch_arg_indices_reverse_ = c10::utils::bitset(); + } + + DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const { + DispatchKeySet ks; + dispatch_arg_indices_reverse_.for_each_set_bit([&](size_t + reverse_arg_index) { + const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1); + if (C10_LIKELY(ivalue.isTensor())) { + // NB: Take care not to introduce a refcount bump (there's + // no safe toTensorRef method, alas) + ks = ks | ivalue.unsafeToTensorImpl()->key_set(); + } else if (C10_UNLIKELY(ivalue.isTensorList())) { + // NB: use toListRef as it doesn't induce refcount bumps + // (toTensorListRef is not a thing) + for (const auto& nv : ivalue.toListRef()) { + auto* tensor = nv.unsafeToTensorImpl(); + ks = ks | tensor->key_set(); + } + } + // Tensor?[] translates to a c10::List so we need to peek inside + else if (C10_UNLIKELY(ivalue.isList())) { + for (const auto& elt : ivalue.toListRef()) { + if (elt.isTensor()) { + ks = ks | elt.toTensor().key_set(); + } + } + } + }); + // Keys that are fallthrough should be skipped + if (requiresBitsetPerBackend_) { + c10::impl::LocalDispatchKeySet tls = + c10::impl::tls_local_dispatch_key_set(); + auto backend_idx = + ((ks | tls.included_) - tls.excluded_).getBackendIndex(); + return impl::computeDispatchKeySet( + ks, 
nonFallthroughKeysPerBackend_[backend_idx]); + } else { + return impl::computeDispatchKeySet(ks, nonFallthroughKeys_); + } + } + + template + DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const { + auto ks = detail::multi_dispatch_key_set(args...); + // Keys that are fallthrough should be skipped + if (requiresBitsetPerBackend_) { + c10::impl::LocalDispatchKeySet tls = + c10::impl::tls_local_dispatch_key_set(); + auto backend_idx = + ((ks | tls.included_) - tls.excluded_).getBackendIndex(); + return impl::computeDispatchKeySet( + ks, nonFallthroughKeysPerBackend_[backend_idx]); + } else { + return impl::computeDispatchKeySet(ks, nonFallthroughKeys_); + } + } + + void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough); + + std::string dumpState() const; + void checkInvariants(const FunctionSchema& schema) const; + + private: + static bool isDispatchType(const Type& type) { + // Checking isSubtypeOf on a DynamicType heap-allocates a + // DynamicType version of the argument if it's not a DynamicType + // already, and this has measurable overhead during startup. 
+#ifdef C10_MOBILE + struct CachedTypes { + DynamicTypePtr listOfTensors; + DynamicTypePtr listOfOptionalTensors; + DynamicTypePtr optionalOfTensor; + }; + static const CachedTypes ct = { + DynamicType::create(*ListType::ofTensors()), + DynamicType::create(*ListType::ofOptionalTensors()), + DynamicType::create(*OptionalType::ofTensor())}; + return type.isSubtypeOf(c10::TypeFactory::get()) || + type.isSubtypeOf(ct.listOfTensors) || + type.isSubtypeOf(ct.listOfOptionalTensors) || + type.isSubtypeOf(ct.optionalOfTensor); +#else // C10_MOBILE + return type.isSubtypeOf(*TensorType::get()) || + type.isSubtypeOf(*ListType::ofTensors()) || + type.isSubtypeOf(*ListType::ofOptionalTensors()) || + type.isSubtypeOf(*OptionalType::ofTensor()); +#endif // C10_MOBILE + } + static c10::utils::bitset makeBitsetForDispatchArgs( + const FunctionSchema& schema) { + TORCH_CHECK( + schema.arguments().size() <= c10::utils::bitset::NUM_BITS(), + "The function schema has ", + schema.arguments().size(), + " arguments but this PyTorch build only supports ", + c10::utils::bitset::NUM_BITS()); + c10::utils::bitset dispatch_arg_indices_reverse; + for (const auto index : c10::irange(schema.arguments().size())) { + if (isDispatchType(*schema.arguments()[index].type())) { + dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index); + } + } + return dispatch_arg_indices_reverse; + } + + explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse) + : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse), + nonFallthroughKeys_(DispatchKeySet::FULL) { + for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { + nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL; + } + } + + // this is a bitset that has ones for each argument index which has to be + // considered for dispatch. This avoids having to iterate over the stack + // to find all the tensors. The bits are stored in reverse order, i.e. 
+ // dispatch_arg_indices_reverse_[i] == true, then the i-th argument from + // the top of the stack (i.e. the i-th last argument of the function) + // is relevant for dispatch. + // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just + // means you must do the fallthrough + c10::utils::bitset dispatch_arg_indices_reverse_; + + // Set of functionality keys for which the operator does NOT have fallthrough + // kernel. + DispatchKeySet nonFallthroughKeys_; + // Set of functionality keys for which the operator does NOT have fallthrough + // kernel, defined PER BACKEND. This is only needed if we know that the + // operator has a different set of fallthroughs defined for some backends. + std::array nonFallthroughKeysPerBackend_; + // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast + // path), or if we need to fall back to the slower path and check + // nonFallthroughKeysPerBackend_ + bool requiresBitsetPerBackend_{false}; +}; + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h new file mode 100644 index 0000000000000000000000000000000000000000..590c9ba8d11d82965d66afe4ed01c909eb6208a3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h @@ -0,0 +1,937 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifndef NDEBUG +#include +#endif + +namespace c10 { + +TORCH_API bool show_dispatch_trace(); +TORCH_API void dispatch_trace_nesting_incr(); +TORCH_API void dispatch_trace_nesting_decr(); +TORCH_API int64_t dispatch_trace_nesting_value(); + +struct DispatchTraceNestingGuard { + DispatchTraceNestingGuard() { + dispatch_trace_nesting_incr(); + } + ~DispatchTraceNestingGuard() { + 
dispatch_trace_nesting_decr(); + } +}; + +class TORCH_API OperatorHandle; +template +class TypedOperatorHandle; + +/** + * Implement this interface and register your instance with the dispatcher + * to get notified when operators are registered or deregistered with + * the dispatcher. + * + * NB: registration events only occur when a 'def' occurs; we don't trigger + * on 'impl' or 'fallback' calls. + */ +class TORCH_API OpRegistrationListener { + public: + virtual ~OpRegistrationListener(); + + virtual void onOperatorRegistered(const OperatorHandle& op) = 0; + virtual void onOperatorDeregistered(const OperatorHandle& op) = 0; +}; + +namespace detail { +class RegistrationListenerList; +} +class SchemaRegistrationHandleRAII; + +/** + * Top-level dispatch interface for dispatching via the dynamic dispatcher. + * Most end users shouldn't use this directly; if you're trying to register + * ops look in op_registration + */ +class TORCH_API Dispatcher final { + private: + // For direct access to backend fallback information + friend class impl::OperatorEntry; + + struct OperatorDef final { + explicit OperatorDef(OperatorName&& op_name) : op(std::move(op_name)) {} + + impl::OperatorEntry op; + + // These refer to the number of outstanding RegistrationHandleRAII + // for this operator. def_count reflects only def() registrations + // (in the new world, this should only ever be 1, but old style + // registrations may register the schema multiple times, which + // will increase this count). def_and_impl_count reflects the number + // of combined def() and impl() registrations. 
When the last def() gets + // unregistered, we must immediately call the Deregistered listeners, but we + // must not actually delete the handle as there are other outstanding RAII + // destructors which will try to destruct and they had better still have a + // working operator handle in this case + size_t def_count = 0; + size_t def_and_impl_count = 0; + }; + friend class OperatorHandle; + template + friend class TypedOperatorHandle; + + struct Guard final { + Guard() : alive(true), mutex() {} + std::atomic alive; + std::mutex mutex; + }; + + public: + ~Dispatcher(); + + // Implementation note: this class abstracts over the fact that we have + // per-operator dispatch tables. This could be easily adjusted to have a + // single global hash table. + static Dispatcher& realSingleton(); + + C10_ALWAYS_INLINE static Dispatcher& singleton() { +#if !defined C10_MOBILE + // Implemented inline so that steady-state code needn't incur + // function-call overhead. We can't just inline `realSingleton` + // because the function-local static would get duplicated across + // all DSOs that include & use this header, leading to multiple + // singleton instances. + static Dispatcher& s = realSingleton(); + return s; +#else + // For C10_MOBILE, we should never inline a static function that + // has a static member, since the generated code calls + // __cxa_guard_acquire and __cxa_guard_release which help + // implement exactly once semantics for the initialization of the + // static Dispatcher& s above (for the non-mobile case). That + // additional code when duplicated across all operator stubs + // for every backend results in a lot of additional code + // being generated by the compiler. 
+ return realSingleton(); +#endif + } + + // ------------------------------------------------------------------------ + // + // Accessing operators by schema + // + // ------------------------------------------------------------------------ + + /** + * Looks for an operator schema with the given name and overload name + * and returns it if it is registered WITH A SCHEMA. + * Returns nullopt otherwise. + */ + std::optional findSchema(const OperatorName& operator_name); + + /** + * Variant of findSchema that results in less code generated at the call site. + * It (1) takes const char* pointer rather than OperatorName (so we skip + * generating std::string constructor calls at the call site), and (2) + * it raises an exception if the operator is not found (so we skip + * generating exception raising code at the call site) + * + * Irritatingly, we still have to generate the handful of instructions + * for dealing with an exception being thrown during static initialization + * (e.g. __cxa_guard_abort). If we could annotate this method noexcept we + * could avoid this code too, but as the name of the function suggests, + * it does throw exceptions. + */ + OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name); + + // Like findSchema, but also returns OperatorHandle even if there is no schema + std::optional findOp(const OperatorName& operator_name); + + // Returns a list of all operator names present in the operatorLookupTable_ + const std::vector getAllOpNames(); + + // Returns a list of all operator names present in the operatorLookupTable_ + // for a given dispatch key + const std::vector getAllOpNamesForDispatchKey(DispatchKey k); + + // ------------------------------------------------------------------------ + // + // Invoking operators + // + // ------------------------------------------------------------------------ + + template + Return call(const TypedOperatorHandle& op, Args... 
args) + const; + + template + static Return callWithDispatchKeySlowPath( + const TypedOperatorHandle& op, + at::StepCallbacks& stepCallbacks, + DispatchKeySet dispatchKeySet, + const KernelFunction& kernel, + Args... args); + + // Like call, but intended for use in a redispatch in kernels that have + // explicitly performed the DispatchKey update calculatulation. This will take + // the DispatchKeySet completely as is and dispatch to the kernel of the + // corresponding highest priority key in the set. Note that this version of + // redispatch treats the inputted DispatchKeySet *as is*, and does NOT mask + // out the highest priority key. See Note [Plumbing Keys Through The + // Dispatcher] + template + Return redispatch( + const TypedOperatorHandle& op, + DispatchKeySet currentDispatchKeySet, + Args... args) const; + + // Invoke an operator via the boxed calling convention using an IValue stack + void callBoxed(const OperatorHandle& op, Stack* stack) const; + void callBoxedForDispatchKey( + const OperatorHandle& op, + DispatchKey dk, + Stack* stack) const; + + // TODO: This will only be useful if we write a backend fallback that plumbs + // dispatch keys (currently there are none) See Note [Plumbing Keys Through + // The Dispatcher] + void redispatchBoxed( + const OperatorHandle& op, + DispatchKeySet dispatchKeySet, + Stack* stack) const; + + bool hasBackendFallbackForDispatchKey(DispatchKey dk) { + auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk); + if (dispatch_ix < 0) + return false; + return backendFallbackKernels_[dispatch_ix].kernel.isValid(); + } + + // Used by torchdeploy/multipy for multiple interpreters racing. 
+ void waitForDef(const FunctionSchema& schema); + void waitForImpl( + const OperatorName& op_name, + std::optional dispatch_key); + + // ------------------------------------------------------------------------ + // + // Performing registrations (NON user public; use op_registration) + // + // ------------------------------------------------------------------------ + + /** + * Register a new operator schema. + * + * If a schema with the same operator name and overload name already exists, + * this function will check that both schemas are exactly identical. + */ + RegistrationHandleRAII registerDef( + FunctionSchema schema, + std::string debug, + std::vector tags = {}); + + /** + * Register a kernel to the dispatch table for an operator. + * If dispatch_key is nullopt, then this registers a fallback kernel. + * + * @return A RAII object that manages the lifetime of the registration. + * Once that object is destructed, the kernel will be deregistered. + */ + // NB: steals the inferred function schema, as we may need to hold on to + // it for a bit until the real schema turns up + RegistrationHandleRAII registerImpl( + OperatorName op_name, + std::optional dispatch_key, + KernelFunction kernel, + std::optional cpp_signature, + std::unique_ptr inferred_function_schema, + std::string debug); + + /** + * Given an operator, tells the Dispatcher that we have implemented a fake + * impl for this op in the given Python module. Call this a "pystub". + */ + RegistrationHandleRAII registerPythonModule( + const OperatorName& op_name, + const char* pymodule, + const char* context); + + /** + * Given an operator, throws if we have a pystub. + */ + void throwIfHasPythonModule(OperatorName op_name); + + std::optional> getPyStub( + OperatorName op_name); + + /** + * Register a new operator by name. + */ + RegistrationHandleRAII registerName(OperatorName op_name); + + /** + * Register a fallback kernel for a backend. 
+ * If an operator is called but there is no concrete kernel for the dispatch + * key of the given operator arguments, it will check if there is such a + * fallback kernel for the given dispatch key and, if yes, call that one. + */ + RegistrationHandleRAII registerFallback( + DispatchKey dispatch_key, + KernelFunction kernel, + std::string debug); + + /** + * Use to register whenever we had a TORCH_LIBRARY declaration in the frontend + * API. These invocations are only permitted once per program, so we raise + * an error if this is called again for the same namespace. + */ + RegistrationHandleRAII registerLibrary(std::string ns, std::string debug); + + // ------------------------------------------------------------------------ + // + // Listeners on registrations + // + // ------------------------------------------------------------------------ + + /** + * Add a listener that gets called whenever a new op is registered or an + * existing op is deregistered. Immediately after registering, this listener + * gets called for all previously registered ops, so it can be used to keep + * track of ops registered with this dispatcher. + */ + RegistrationHandleRAII addRegistrationListener( + std::unique_ptr listener); + + void checkInvariants() const; + + // + // ------------------------------------------------------------------------ + // + // Assertions + // + // ------------------------------------------------------------------------ + + /** + * For testing purposes. + * Returns a list of all operators that were created through calls to + * registerImpl(), without any corresponding calls to registerDef(). 
After + * static initialization is done this is almost certainly a bug, as the + * created OperatorHandle won't have any schema associated with it and users + * calling the op through the dispatcher won't be able to access it + * + * Note that we cannot enforce this invariant "as we go" during static + * initialization, due to undefined static initialization order- we have no + * guarantees over the order in which .def() and .impl() calls are registered + * in the dispatcher at static initialization time. So this function should + * only be called after static initialization. + */ + std::vector findDanglingImpls() const; + + /** + * Useful for inspecting global Dispatcher registration state. + * Returns the names of all operators with a kernel registered for the + * specified DispatchKey. If no DispatchKey is specified, it returns all + * registered operators. + */ + std::vector getRegistrationsForDispatchKey( + std::optional k) const; + + private: + Dispatcher(); + + static int64_t sequenceNumberForRunningRecordFunction( + DispatchKey dispatchKey, + DispatchKeySet dispatchKeySet); + static void runRecordFunction( + at::RecordFunction& guard, + at::RecordFunction::schema_ref_t schema_ref, + DispatchKey dispatchKey, + DispatchKeySet dispatchKeySet); + static void runRecordFunction( + at::RecordFunction& guard, + at::RecordFunction::schema_ref_t schema_ref, + DispatchKey dispatchKey, + DispatchKeySet dispatchKeySet, + c10::ArrayRef args); + +#ifdef FBCODE_CAFFE2 + static bool profilingOperatorEvents(); + static void fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref); + static void fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref); +#endif // FBCODE_CAFFE2 + + OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema); + OperatorHandle findOrRegisterName_(const OperatorName& op_name); + + void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name); + void deregisterImpl_( + const OperatorHandle& op, + const OperatorName& op_name, + 
std::optional dispatch_key, + impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle); + void deregisterName_(const OperatorHandle& op, const OperatorName& op_name); + void deregisterFallback_(DispatchKey dispatchKey); + void deregisterLibrary_(const std::string& ns); + void cleanup(const OperatorHandle& op, const OperatorName& op_name); + void checkSchemaCompatibility( + const OperatorHandle& op, + const FunctionSchema& schema, + const std::string& debug); + + std::list operators_; +#if !defined(C10_MOBILE) + LeftRight> + operatorLookupTable_; +#else + RWSafeLeftRightWrapper> + operatorLookupTable_; +#endif + // Map from namespace to debug string (saying, e.g., where the library was + // defined) + ska::flat_hash_map libraries_; + + std::array + backendFallbackKernels_; + + std::unique_ptr listeners_; + + // This condition variable gets notified whenever we add a new def/impl to the + // dispatch table. This is primarily used by multipy/torchdeploy, when + // we have multiple interpreters trying to register to the dispatch table. + // In this situation, whenever the non-primary interpreter would have tried + // to register to the dispatch table, instead it will check to see if the + // expected registration has already been made, and if it hasn't, wait on + // this condition variable to see if it was just racing with the primary + // interpreter. + // + // We expect it to be rare for there to be any waiters on this condition + // variable. This is mostly just to help give better diagnostics if + // something goes horribly wrong + std::condition_variable cond_var_; + + // Protect concurrent access to the dispatcher. We store this in a + // `shared_ptr` as we return callbacks that call back into dispatcher methods, + // and we need to be able to handle and guard against the event when the + // `Dispatcher` has been destroyed before the callbacks fire. 
+ std::shared_ptr guard_; +}; + +/** + * This is a handle to an operator schema registered with the dispatcher. + * This handle can be used to register kernels with the dispatcher or + * to lookup a kernel for a certain set of arguments. + */ +class TORCH_API OperatorHandle { + template + friend struct std::hash; + + public: + OperatorHandle(OperatorHandle&&) noexcept = default; + OperatorHandle& operator=(OperatorHandle&&) noexcept = default; + OperatorHandle(const OperatorHandle&) = default; + OperatorHandle& operator=(const OperatorHandle&) = default; + // NOLINTNEXTLINE(performance-trivially-destructible) + ~OperatorHandle(); + + const OperatorName& operator_name() const { + return operatorDef_->op.operator_name(); + } + + bool hasSchema() const { + return operatorDef_->op.hasSchema(); + } + + const FunctionSchema& schema() const { + return operatorDef_->op.schema(); + } + + const std::string& debug() const { + return operatorDef_->op.debug(); + } + + std::string dumpState() const { + return operatorDef_->op.dumpState(); + } + + bool hasKernelForDispatchKey(DispatchKey k) const { + return operatorDef_->op.hasKernelForDispatchKey(k); + } + + bool isKernelFallthroughKernel(DispatchKey k) const { + return operatorDef_->op.kernelForDispatchKey(k).isFallthrough(); + } + + bool hasKernelForAnyDispatchKey(DispatchKeySet k) const { + return operatorDef_->op.hasKernelForAnyDispatchKey(k); + } + + bool hasComputedKernelForDispatchKey(DispatchKey k) const { + return operatorDef_->op.hasComputedKernelForDispatchKey(k); + } + + std::string dumpComputedTable() const { + return operatorDef_->op.dumpComputedTable(); + } + + void checkInvariants() const { + return operatorDef_->op.checkInvariants(); + } + + c10::ArrayRef getTags() const { + return operatorDef_->op.getTags(); + } + + void setReportErrorCallback_(std::unique_ptr callback) { + operatorDef_->op.setReportErrorCallback_(std::move(callback)); + } + + bool hasTag(const at::Tag& tag) const { + for (const auto& tag_ : 
getTags()) { + if (tag == tag_) { + return true; + } + } + return false; + } + + template + TypedOperatorHandle typed() const { + // NB: This assert is not 100% sound: you can retrieve a typed() operator + // handle prior to ANY C++ signature being registered on the operator + // and the check will say everything is OK (at which point you can then + // smuggle in a kernel that is typed incorrectly). For everything + // in core library this won't happen, because all the static registrations + // will be done by the time a typed() handle is acquired. +#if !defined C10_MOBILE + operatorDef_->op.assertSignatureIsCorrect(); + if (fn_has_symint::value) { + operatorDef_->op.assertSignatureIsCorrect< + typename fn_remove_symint::type>(); + } +#endif + return TypedOperatorHandle(operatorIterator_); + } + + void callBoxed(Stack* stack) const { + c10::Dispatcher::singleton().callBoxed(*this, stack); + } + + void callBoxed(Stack& stack) const { + callBoxed(&stack); + } + + void callBoxedForDispatchKey(DispatchKey dk, Stack& stack) const { + c10::Dispatcher::singleton().callBoxedForDispatchKey(*this, dk, &stack); + } + + void redispatchBoxed(DispatchKeySet ks, Stack* stack) const { + c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack); + } + + template + PyObject* getPythonOp( + c10::impl::PyInterpreter* self_interpreter, + F slow_accessor) const { + return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor); + } + + bool operator==(const OperatorHandle& other) const { + return operatorDef_ == other.operatorDef_; + } + + bool operator!=(const OperatorHandle& other) const { + return operatorDef_ != other.operatorDef_; + } + + private: + explicit OperatorHandle( + std::list::iterator operatorIterator) + : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {} + friend class Dispatcher; + template + friend class TypedOperatorHandle; + + // Storing a direct pointer to the OperatorDef even though we + // already have the iterator saves an 
instruction in the critical + // dispatch path. The iterator is effectively a + // pointer-to-std::list-node, and (at least in libstdc++'s + // implementation) the element is at an offset 16 bytes from that, + // because the prev/next pointers come first in the list node + // struct. So, an add instruction would be necessary to convert from the + // iterator to an OperatorDef*. + Dispatcher::OperatorDef* operatorDef_; + + // We need to store this iterator in order to make + // Dispatcher::cleanup() fast -- it runs a lot on program + // termination (and presuambly library unloading). + std::list::iterator operatorIterator_; +}; + +/** + * This is a handle to an operator schema registered with the dispatcher. + * It holds the same information as an OperatorHandle, but it is templated + * on the operator arguments and allows calling the operator in an + * unboxed way. + */ +template +class TypedOperatorHandle final { + static_assert( + guts::false_t(), + "FuncType in OperatorHandle::typed was not a valid function type"); +}; +template +class TypedOperatorHandle final : public OperatorHandle { + public: + TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default; + TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default; + TypedOperatorHandle(const TypedOperatorHandle&) = default; + TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default; + + // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use + // && + C10_ALWAYS_INLINE Return call(Args... args) const { + return c10::Dispatcher::singleton().call( + *this, std::forward(args)...); + } + + // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use + // && + C10_ALWAYS_INLINE Return + redispatch(DispatchKeySet currentDispatchKeySet, Args... 
args) const { + return c10::Dispatcher::singleton().redispatch( + *this, currentDispatchKeySet, std::forward(args)...); + } + + private: + explicit TypedOperatorHandle( + std::list::iterator operatorIterator) + : OperatorHandle(operatorIterator) {} + friend class OperatorHandle; +}; + +namespace detail { +template +inline void unused_arg_(const Args&...) {} + +// CaptureKernelCall is intended to capture return values from Dispatcher +// unboxed kernel calls. A record function may request to get outputs from the +// kernel calls. For boxed kernels, it's straightforward, the returned values +// are in the stack object. The stack can be passed to record functions. For +// unboxed kernels, we need to handle different kinds of return values, cache +// them temporarily, then release the values for the actual function call +// return. +template +struct CaptureKernelCall { + template + CaptureKernelCall( + const F& kernel, + const TypedOperatorHandle& op, + const DispatchKeySet& dispatchKeySet, + Args&&... args) + // Calls the kernel and capture the result in output_. + : output_{kernel.template call( + op, + dispatchKeySet, + std::forward(args)...)} {} + // Wraps the return values in a Stack. + Stack getOutputs() { + Stack stack; + impl::push_outputs::copy(output_, &stack); + return stack; + } + // Since we are returning the output_, we don't expect the output_ to be used + // afterward. Copy elision and RVO do not apply to class data members. Using + // move semantic to avoid copies when possible. + ReturnType release() && { + return std::move(output_); + } + + private: + ReturnType output_; +}; + +// Handle the lvalue reference differently since it should not be moved. +template <> +inline at::Tensor& CaptureKernelCall::release() && { + return output_; +} + +// Handle case where the kernel returns void. 
+template <> +struct CaptureKernelCall { + template + CaptureKernelCall( + const F& kernel, + const TypedOperatorHandle& op, + const DispatchKeySet& dispatchKeySet, + Args&&... args) { + // Calling the kernel and no need to capture void. + kernel.template call( + op, dispatchKeySet, std::forward(args)...); + } + Stack getOutputs() { + return Stack(); + } + void release() && {} +}; + +TORCH_API void _print_dispatch_trace( + const std::string& label, + const std::string& op_name, + const DispatchKeySet& dispatchKeySet); + +} // namespace detail + +// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use && +template +inline Return Dispatcher::callWithDispatchKeySlowPath( + const TypedOperatorHandle& op, + at::StepCallbacks& stepCallbacks, + DispatchKeySet dispatchKeySet, + const KernelFunction& kernel, + Args... args) { + // If callbacks need inputs, we box the arguments and pass them to the guard. + // Note: For perf reasons we wouldn't want to prematurely box the arguments. + at::RecordFunction guard(std::move(stepCallbacks)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved()); + auto dispatchKey = dispatchKeySet.highestPriorityTypeId(); + auto& schema = op.schema(); + auto schema_ref = std::reference_wrapper(schema); + constexpr auto num_boxed_args = impl::boxed_size(); + if constexpr (num_boxed_args != 0) { + if (guard.needsInputs()) { + // If we used std::array here, we would + // have to spend time default constructing the IValues in + // boxedArgs. aligned_storage has no such requirement. + // NOLINTNEXTLINE(*array*) + alignas(IValue) std::byte boxedArgs[num_boxed_args * sizeof(IValue)]; + // For debugging only; could be removed (but the compiler will do + // that for us and it's nice to have the extra assurance of + // correctness from our debug builds). 
+ IValue* boxedArgsPtr = reinterpret_cast(boxedArgs); + impl::boxArgsToStack(boxedArgsPtr, args...); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + reinterpret_cast(boxedArgsPtr) == + boxedArgs + num_boxed_args * sizeof(IValue)); + // I don't *think* we need std::launder here, because IValue has + // no subclasses and no const or reference fields. + runRecordFunction( + guard, + schema_ref, + dispatchKey, + dispatchKeySet, + c10::ArrayRef( + reinterpret_cast(boxedArgs), num_boxed_args)); + boxedArgsPtr = reinterpret_cast(boxedArgs); + for (size_t ii = 0; ii < num_boxed_args; ++ii) { + (boxedArgsPtr + ii)->~IValue(); + } + } else { + runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet); + } + } else { + runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet); + } + + if (C10_UNLIKELY(guard.needsOutputs())) { + // Calls the kernel and capture the output temporarily to pass to + // RecordFunction. + detail::CaptureKernelCall captureKernelCall( + kernel, op, dispatchKeySet, std::forward(args)...); + guard.setOutputs(captureKernelCall.getOutputs()); + // Releases the captured output to return to caller. + return std::move(captureKernelCall).release(); + } + + // keeping the guard alive while executing the kernel + return kernel.template call( + op, dispatchKeySet, std::forward(args)...); +} + +// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use && +template +C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call( + const TypedOperatorHandle& op, + Args... 
args) const { + auto dispatchKeySet = + op.operatorDef_->op.dispatchKeyExtractor() + .template getDispatchKeySetUnboxed(args...); +#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG) + DispatchTraceNestingGuard debug_guard; + if (show_dispatch_trace()) { + detail::_print_dispatch_trace( + "[call]", toString(op.operator_name()), dispatchKeySet); + } +#endif + const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet); +#ifndef PYTORCH_DISABLE_PER_OP_PROFILING + auto step_callbacks = + at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); + if (C10_UNLIKELY( + step_callbacks.has_value() && op.operatorDef_->op.isObserved())) { + return callWithDispatchKeySlowPath( + op, + *step_callbacks, + dispatchKeySet, + kernel, + std::forward(args)...); + } +#endif // PYTORCH_DISABLE_PER_OP_PROFILING + +#ifdef FBCODE_CAFFE2 + if (profilingOperatorEvents()) { + struct FireOpRAII { + FireOpRAII(at::RecordFunction::schema_ref_t schema_ref) + : schema_ref_(schema_ref) { + fireOpStartUSDT(schema_ref); + } + ~FireOpRAII() { + fireOpEndUSDT(schema_ref_); + } + at::RecordFunction::schema_ref_t schema_ref_; + } event(op.schema()); + return kernel.template call( + op, dispatchKeySet, std::forward(args)...); + } else { + return kernel.template call( + op, dispatchKeySet, std::forward(args)...); + } +#else + return kernel.template call( + op, dispatchKeySet, std::forward(args)...); +#endif // FBCODE_CAFFE2 +} + +// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use && +template +inline Return Dispatcher::redispatch( + const TypedOperatorHandle& op, + DispatchKeySet currentDispatchKeySet, + Args... 
args) const { + // do not use RecordFunction on redispatch +#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG) + DispatchTraceNestingGuard debug_guard; + if (show_dispatch_trace()) { + detail::_print_dispatch_trace( + "[redispatch]", toString(op.operator_name()), currentDispatchKeySet); + } +#endif + const KernelFunction& kernel = + op.operatorDef_->op.lookup(currentDispatchKeySet); + return kernel.template call( + op, currentDispatchKeySet, std::forward(args)...); +} + +inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) + const { + // note: this doesn't need the mutex because write operations on the list keep + // iterators intact. + const auto& entry = op.operatorDef_->op; + auto dispatchKeySet = + entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); +#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG) + DispatchTraceNestingGuard debug_guard; + if (show_dispatch_trace()) { + detail::_print_dispatch_trace( + "[callBoxed]", toString(op.operator_name()), dispatchKeySet); + } +#endif + const auto& kernel = entry.lookup(dispatchKeySet); +#ifndef PYTORCH_DISABLE_PER_OP_PROFILING + auto step_callbacks = + at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) { + at::RecordFunction guard(std::move(*step_callbacks)); + auto dispatchKey = dispatchKeySet.highestPriorityTypeId(); + auto& schema = op.schema(); + auto schema_ref = std::reference_wrapper(schema); + guard.needsInputs() + ? 
runRecordFunction( + guard, + schema_ref, + dispatchKey, + dispatchKeySet, + c10::ArrayRef(stack->data(), stack->size())) + : runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet); + + // keeping the guard alive while executing the kernel + kernel.callBoxed(op, dispatchKeySet, stack); + + if (C10_UNLIKELY(guard.needsOutputs())) { + guard.setOutputs(*stack); + } + return; + } +#endif // PYTORCH_DISABLE_PER_OP_PROFILING + kernel.callBoxed(op, dispatchKeySet, stack); +} + +// NB: this doesn't count as a "true" dispatcher jump, so no instrumentation +inline void Dispatcher::callBoxedForDispatchKey( + const OperatorHandle& op, + DispatchKey dk, + Stack* stack) const { + // note: this doesn't need the mutex because write operations on the list keep + // iterators intact. + const auto& entry = op.operatorDef_->op; + // We still compute this as we're obligated to pass it on to the internal + // kernel, if it is a boxed fallback + auto dispatchKeySet = + entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); + const auto& kernel = ([&]() { + if (op.hasKernelForDispatchKey(dk)) { + return entry.kernelForDispatchKey(dk); + } else { + auto idx = getDispatchTableIndexForDispatchKey(dk); + TORCH_INTERNAL_ASSERT(idx >= 0); + return backendFallbackKernels_[idx].kernel; + } + })(); + kernel.callBoxed(op, dispatchKeySet, stack); +} + +inline void Dispatcher::redispatchBoxed( + const OperatorHandle& op, + DispatchKeySet dispatchKeySet, + Stack* stack) const { + // note: this doesn't need the mutex because write operations on the list keep + // iterators intact. 
+ const auto& entry = op.operatorDef_->op; +#if defined(HAS_TORCH_SHOW_DISPATCH_TRACE) || !defined(NDEBUG) + DispatchTraceNestingGuard debug_guard; + if (show_dispatch_trace()) { + detail::_print_dispatch_trace( + "[redispatchBoxed]", toString(op.operator_name()), dispatchKeySet); + } +#endif + const auto& kernel = entry.lookup(dispatchKeySet); + return kernel.callBoxed(op, dispatchKeySet, stack); +} + +} // namespace c10 + +namespace std { + +template <> +struct hash { + size_t operator()(const c10::OperatorHandle& op) const noexcept { + return std::hash{}(static_cast(op.operatorDef_)); + } +}; + +} // namespace std diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h new file mode 100644 index 0000000000000000000000000000000000000000..1741171fbf00412647178b2210071cee36928e54 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +struct TORCH_API ObservedOperators { + ObservedOperators() = delete; + + static bool isObserved(const OperatorName& name); + + static std::unordered_set& getUnobservedOperatorList(); +}; + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h new file mode 100644 index 0000000000000000000000000000000000000000..83200ff9c94ffcf4325a3fc8cb9844c699f9c1b7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h @@ -0,0 +1,335 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#ifdef C10_MOBILE +#define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY 
+#endif + +namespace c10 { + +class Dispatcher; + +namespace impl { + +// This data structure represents a kernel that was registered to us from a +// user. Unlike KernelFunction, AnnotatedKernel contains some extra metadata +// about the kernel that isn't necessary for actual dispatching (this is why +// we don't put AnnotatedKernel in the actual DispatchTable), but is useful for +// giving good error messages. +struct AnnotatedKernel final { + AnnotatedKernel( + KernelFunction k, + std::unique_ptr s, + std::string d) + : kernel(std::move(k)), + inferred_function_schema(std::move(s)), + debug(std::move(d)) {} + AnnotatedKernel() = default; + KernelFunction kernel; + std::unique_ptr inferred_function_schema; + // A little debug string to help us identify the kernel in question. + // Most importantly it records the TORCH_LIBRARY block that did the + // registration. + std::string debug; +}; + +// This data structure represents operator schema, with metadata specifying +// where the registration of this schema occurred +struct AnnotatedSchema final { + AnnotatedSchema(FunctionSchema s, std::string d) + : schema(std::move(s)), debug(std::move(d)) {} + FunctionSchema schema; + std::string debug; +}; + +// Internal data structure that records information about a specific operator. +// It's not part of the public API; typically, users will interact with +// OperatorHandle instead. 
+// +// Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher +// lock (this is important because some methods in OperatorEntry access +// dispatcher state) +class TORCH_API OperatorEntry final { + public: + explicit OperatorEntry(OperatorName&& operator_name); + + OperatorEntry(const OperatorEntry&) = delete; + OperatorEntry(OperatorEntry&&) noexcept = delete; + OperatorEntry& operator=(const OperatorEntry&) = delete; + OperatorEntry& operator=(OperatorEntry&&) noexcept = delete; + + const FunctionSchema& schema() const { + TORCH_INTERNAL_ASSERT( + schema_.has_value(), + "Tried to access the schema for ", + name_, + " which doesn't have a schema registered yet"); + return schema_->schema; + } + const std::string& debug() const { + TORCH_INTERNAL_ASSERT(schema_.has_value()); + return schema_->debug; + } + bool hasSchema() const { + return schema_.has_value(); + } + + bool isObserved() const { + return is_observed_; + } + + // We may allocate an OperatorEntry for an operator even when we don't + // have a schema. When we receive the schema registration, we post + // facto register a schema. + // + // NB: registerSchema/deregisterSchema are not idempotent; if you + // attempt to register a schema when one is already present or vice + // versa that is an error. (Refcounting for the registrations is + // handled in the OperatorHandle in Dispatcher) + void registerSchema( + FunctionSchema&&, + std::string&& debug, + std::vector tags = {}); + void deregisterSchema(); + + const OperatorName& operator_name() const { + return name_; + } + +#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY + using AnnotatedKernelContainer = std::array; +#else + using AnnotatedKernelContainer = std::list; +#endif + using AnnotatedKernelContainerIterator = AnnotatedKernelContainer::iterator; + + // Why are kernels and fallback asymmetric? It has to do with ownership. 
+ // Kernels and the computed dispatch tables for them are canonically + // owned by OperatorEntry, but backend fallbacks are specified once + // and apply for all operators, so they should be owned by Dispatcher. + // However, the registration of a backend fallback affects the + // state of the computed dispatch table, so when a backend fallback + // is updated, we need to update the operator tables too. Thus, + // registerKernel is the mechanism by which we give kernels to + // operator entry to own (and update dispatch table), but we only + // need a non-owning mechanism to update fallback. + + // Precondition: Dispatcher::mutex_ is held + // Postcondition: caller is responsible for disposing of the kernel + AnnotatedKernelContainerIterator registerKernel( + const Dispatcher& dispatcher, + std::optional dispatch_key, + KernelFunction kernel, + std::optional cpp_signature, + std::unique_ptr inferred_function_schema, + std::string debug); + + // Precondition: Dispatcher::mutex_ is held + void deregisterKernel_( + const Dispatcher& dispatcher, + std::optional dispatch_key, + AnnotatedKernelContainerIterator kernel); + + // Precondition: Dispatcher::mutex_ is held + void updateFallback(const Dispatcher& dispatcher, DispatchKey dispatch_key); + + // Precondition: Dispatcher::mutex_ is held + void updateSchemaAliasAnalysis(AliasAnalysisKind a) { + TORCH_INTERNAL_ASSERT(schema_.has_value()); + schema_->schema.setAliasAnalysis(a); + } + + std::string dumpComputedTable() const; + std::string dumpState() const; + void checkInvariants() const; + + const DispatchKeyExtractor& dispatchKeyExtractor() const { + return dispatchKeyExtractor_; + } + + // Asserts that the given FuncType is correct for calling this operator in an + // unboxed way. 
+ template + inline void assertSignatureIsCorrect() { + assertSignatureIsCorrect( + CppSignature::make(), fn_has_symint::value); + } + + void assertSignatureIsCorrect( + const CppSignature& call_signature, + bool has_symint) const; + + [[noreturn]] void reportError(DispatchKey dispatchKey) const; + + const KernelFunction& lookup(DispatchKeySet ks) const { + const auto idx = ks.getDispatchTableIndexForDispatchKeySet(); + if (C10_UNLIKELY(idx == -1)) { + reportError(ks.highestPriorityTypeId()); + } + const auto& kernel = dispatchTable_[idx]; + // A valid kernel *always* has a boxed kernel and *may* have an + // unboxed kernel. However, we typically do unboxed calls in at:: + // APIs, where the kernel 1) will very likely be valid and 2) + // should have an unboxed kernel. Checking the unboxed kernel + // first will allow us to avoid touching the boxed kernel at all + // in the common case. + if (C10_UNLIKELY(!kernel.isValidUnboxed())) { + if (!kernel.isValid()) { + reportError(ks.highestPriorityTypeId()); + } + } + return kernel; + } + + std::string listAllDispatchKeys() const; + + // Returns true if kernel_ has entry for any key in ks. + // + // Invariant: There are no alias keys in the passed-in dispatch key set. + // Note [No Alias Keys in DispatchKeySet] + // Alias keys should be checked using `hasKernelForDispatchKey` + // Alias keys shouldn't go inside of a DispatchKeySet, since they can + // technically have a value > 63 (causing overflow). + bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const; + // Returns true if kernel_ has entry for a particular key. + bool hasKernelForDispatchKey(DispatchKey k) const; + // Retrieves the kernel entry at a particular key. Symmetric with + // hasKernelForDispatchKey. To get the AnnotatedKernel, see + // getKernelForDispatchKey (private) + const KernelFunction& kernelForDispatchKey(DispatchKey k) const; + // Returns true if the "computed table" has an entry for a particular key. 
+ bool hasComputedKernelForDispatchKey(DispatchKey k) const; + // Returns all the operator tags added at the time of registration + const std::vector& getTags() const; + void setReportErrorCallback_(std::unique_ptr callback); + + template + PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) + const { + return py_cache_.ptr_or(self_interpreter, slow_accessor); + } + + private: + OperatorName name_; + std::optional schema_; +#ifndef C10_MOBILE + std::vector tags_; +#endif + std::array dispatchTable_; + DispatchKeyExtractor dispatchKeyExtractor_; + // Pointer to the torch.ops.ns.op.overload object for speed + c10::PyHandleCache py_cache_; + + // kernels_ stores all registered kernels for the corresponding dispatch key + // and catchAllKernels_ stores the catch-all kernels. + // If an operator library gets loaded that overwrites an already existing + // kernel, both kernels will be in that list but only the newer one will be in + // dispatchTable. If any of the kernels go away (say the library gets + // unloaded), we remove the kernel from this list and update the + // dispatchTable if necessary. + // Kernels in the list are ordered by registration time descendingly, + // newer registrations are before older registrations. + // We do not combine dispatchTable and kernels into one hash map because + // kernels is a larger data structure and accessed quite infrequently + // while dispatchTable is accessed often and should be kept small to fit + // into CPU caches. + // Invariants: + // - dispatchTable[dispatch_key] == kernels_[dispatch_key].front() + // - dispatchTable[dispatch_key] does not exist if and only if + // kernels_[dispatch_key] does not exist + // - If kernels_[dispatch_key] exists, then it has elements. + // It is never an empty list. + // + // Why do we do that? 
+ // ----- + // We mostly do this to enable Jupyter notebooks where a cell registering + // a kernel could be executed multiple times and the later execution + // should overwrite the earlier one. Note that this still fails when the + // function schema changed between the executions, but it works as long + // as the function schema didn't change. A better solution would be to + // unload the old extension library from the Jupyter cell when the cell is + // re-executed and then only allow one kernel here, i.e. error if a kernel + // is already registered, but that's a lot of effort to implement and + // currently not high-pri. + ska::flat_hash_map< + DispatchKey, +#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY + // On mobile, we needn't worry about Jupyter notebooks. + std::array +#else + std::list +#endif + > + kernels_; + + const AnnotatedKernel& missingKernel() const; + const AnnotatedKernel& ambiguousAutogradOtherKernel() const; + + // cpp_signature_ stores function signature if any of + // the kernels was created in a way that allowed us to know the function + // signature (i.e. by supplying an unboxed C++ kernel function). + // If this is set, it will be used to check that future kernel + // registrations match and it will be used in unboxed function calls + // to verify their arguments against the known function signature. 
+ struct CppSignatureWithDebug { + CppSignature signature; + std::string debug; + std::optional dispatch_key; + }; + std::optional cpp_signature_; + std::optional sym_cpp_signature_; + + // A Python custom error handler for OperatorEntry::reportError + std::unique_ptr report_error_callback_; + + // Whether this operator needs to be observed with RecordFunction + const bool is_observed_; + + [[noreturn]] void reportSignatureError( + const CppSignature& call_signature, + const CppSignatureWithDebug& saved_signature) const; + const KernelFunction& computeDispatchTableEntry( + const c10::Dispatcher& dispatcher, + DispatchKey dispatch_key) const; + std::pair + computeDispatchTableEntryWithDebug( + const c10::Dispatcher& dispatcher, + DispatchKey dispatch_key) const; + // This function re-establishes the invariant that dispatchTable + // contains the front element from the kernels list for a given runtime + // dispatch key. + void updateDispatchTableEntry_( + const c10::Dispatcher& dispatcher, + DispatchKey dispatch_key); + // Like above, but also handles alias dispatch keys. + void updateDispatchTable_( + const c10::Dispatcher& dispatcher, + DispatchKey dispatch_key); + // Like above, but for ALL entries in the dispatch table. + void updateDispatchTableFull_(const c10::Dispatcher& dispatcher); + // Retrieves a pointer to AnnotatedKernel at + // kernels_.at(dispatch_key).front(). 
+ const AnnotatedKernel* getKernelForDispatchKey( + DispatchKey dispatch_key) const; +}; + +} // namespace impl +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..d66686c1bb4690a3b39d1ed959167e0a375d0319 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +namespace c10 { + +enum class AliasAnalysisKind : uint8_t { + INTERNAL_SPECIAL_CASE, + CONSERVATIVE, // The most conservative alias analysis type, assumes + // side-effects. This is the default analysis. + FROM_SCHEMA, + PURE_FUNCTION +}; + +#if !defined(_MSC_VER) +constexpr // Our current MSVC version has a bug that doesn't allow this to be + // constexpr. +#endif + inline const char* + toString(AliasAnalysisKind aliasAnalysisKind) { + return (aliasAnalysisKind == AliasAnalysisKind::CONSERVATIVE) ? "CONSERVATIVE" + : (aliasAnalysisKind == AliasAnalysisKind::FROM_SCHEMA) ? "FROM_SCHEMA" + : (aliasAnalysisKind == AliasAnalysisKind::PURE_FUNCTION) + ? "PURE_FUNCTION" + : (aliasAnalysisKind == AliasAnalysisKind::INTERNAL_SPECIAL_CASE) + ? 
"INTERNAL_SPECIAL_CASE" + : "UNKNOWN"; +} + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h new file mode 100644 index 0000000000000000000000000000000000000000..a5a88aafed631a88efa35f8526ace8d8b1f293fa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +namespace c10 { + +class RegistrationHandleRAII final { + public: + explicit RegistrationHandleRAII(std::function onDestruction) + : onDestruction_(std::move(onDestruction)) {} + + ~RegistrationHandleRAII() { + if (onDestruction_) { + onDestruction_(); + } + } + + RegistrationHandleRAII(const RegistrationHandleRAII&) = delete; + RegistrationHandleRAII& operator=(const RegistrationHandleRAII&) = delete; + + RegistrationHandleRAII(RegistrationHandleRAII&& rhs) noexcept + : onDestruction_(std::move(rhs.onDestruction_)) { + rhs.onDestruction_ = nullptr; + } + + RegistrationHandleRAII& operator=(RegistrationHandleRAII&& rhs) noexcept { + onDestruction_ = std::move(rhs.onDestruction_); + rhs.onDestruction_ = nullptr; + return *this; + } + + private: + std::function onDestruction_; +}; + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dynamic_type.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dynamic_type.h new file mode 100644 index 0000000000000000000000000000000000000000..b33e7ce0c5495f7c25ca4306e5abe536cb745977 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/dynamic_type.h @@ -0,0 +1,246 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace c10 { + +using DynamicTypeBits = std::uint32_t; +#define DYNAMIC_TYPE_BIT(x) (1u << x) + +constexpr DynamicTypeBits kDynamicCovariantTypeBit = DYNAMIC_TYPE_BIT(31); +constexpr 
DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30); + +constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1); +constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3); +constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4); +constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5); +constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7); +constexpr DynamicTypeBits kDynamicTupleTypeBit = DYNAMIC_TYPE_BIT(8); +constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10); + +#define FORALL_DYNAMIC_TYPES(_) \ + _(Tensor, DYNAMIC_TYPE_BIT(0), 1) \ + _(None, kDynamicNoneTypeBit, 1) \ + _(Bool, DYNAMIC_TYPE_BIT(2), 1) \ + _(Int, kDynamicIntTypeBit, 1) \ + _(Float, kDynamicFloatTypeBit, 1) \ + _(Complex, kDynamicComplexTypeBit, 1) \ + _(Number, \ + (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), \ + 1) \ + _(String, DYNAMIC_TYPE_BIT(6), 1) \ + _(List, kDynamicListTypeBit, 0) \ + _(Tuple, (kDynamicTupleTypeBit | kDynamicCovariantTypeBit), 0) \ + _(Dict, DYNAMIC_TYPE_BIT(9), 0) \ + _(Class, kDynamicClassTypeBit, 0) \ + _(Optional, \ + (DYNAMIC_TYPE_BIT(11) | kDynamicNoneTypeBit | kDynamicCovariantTypeBit), \ + 0) \ + _(AnyList, (kDynamicListTypeBit | kDynamicAnyTypeBit), 1) \ + _(AnyTuple, \ + (kDynamicTupleTypeBit | kDynamicCovariantTypeBit | kDynamicAnyTypeBit), \ + 1) \ + _(DeviceObj, DYNAMIC_TYPE_BIT(12), 1) \ + _(StreamObj, DYNAMIC_TYPE_BIT(13), 1) \ + _(Capsule, DYNAMIC_TYPE_BIT(14), 1) \ + _(Generator, DYNAMIC_TYPE_BIT(15), 1) \ + _(Storage, DYNAMIC_TYPE_BIT(16), 1) \ + _(Var, DYNAMIC_TYPE_BIT(17), 0) \ + _(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \ + _(QScheme, DYNAMIC_TYPE_BIT(18), 1) \ + _(Quantizer, DYNAMIC_TYPE_BIT(19), 1) \ + _(AnyEnum, DYNAMIC_TYPE_BIT(20), 1) \ + _(RRef, DYNAMIC_TYPE_BIT(21), 0) \ + _(Future, DYNAMIC_TYPE_BIT(22), 0) \ + _(Await, DYNAMIC_TYPE_BIT(23), 0) \ + _(Any, 0xffffffff, 1) + +#define FORALL_DYNAMIC_TYPES_FAKE(_) \ + 
_(ScalarType, kDynamicIntTypeBit, 1) \ + _(Layout, kDynamicIntTypeBit, 1) \ + _(SymInt, kDynamicIntTypeBit, 1) \ + _(MemoryFormat, kDynamicIntTypeBit, 1) + +#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type; + FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE) + FORALL_DYNAMIC_TYPES_FAKE(FORWARD_DECL_TYPE) +#undef FORWARD_DECL_TYPE + +class DynamicType; +using DynamicTypePtr = std::shared_ptr; + +/** + * DynamicType is designed as a low dependency type system for TorchScript. The + * existing JIT types are used for both compilation and runtime, which makes + * sense for server contexts because we often compile and run the model in + * the same process, however this doesn't hold for mobile devices where we + * always compiles a model ahead of time, therefore there will be dependencies + * which are not needed, but built with mobile runtime causing binary size + * bloat, by design. Every basic type like Int, Bool or String will bring their + * vtable, typeinfo, constructor, destructor and even more data from their + * specializations for STL types to the binary causing a long tail bloat. + * + * The core problem is about the complexity to implement and maintain a single + * type system for both analysis and execution purposes. Although they should + * have the exactly same semantics, in practice implement a unified abstraction + * adds conceptual and representational overhead for both sides of the world. + * + * To address the issues, DynamicType implements a minimal subset of JIT types + * and uses a generic algorithm to test all subtyping relations. To achieve + * this, we assign each dynamic type a single integer tag to represent its + * semantics. More specifically, a dynamic type is defined as a set of "control + * bits" and "data bits", where control bits describe the special behavior when + * testing a type and data bits map to identity of each nominal type. We use bit + * operations to perform all the tests. 
+ * + * For example, a "covariant bit" is a control bit used to describe if a type + * is covariant, right now the most used one is tuple type, and in addition to + * the control bit, tuple type's data bit is the 8th bit from the LSB. Control + * bits start from MSB and data bits start from LSB. + * + * If two types are equal, then they are subtype of each other, also if the bits + * from one type tag is subset of the other tag, it automatically becomes a + * subtype of the other. This simplifies the subtyping logic a lot, and over the + * long term it is possible to adopt this scheme on the server side as well. + * Special cases can be added but they generally should not take too much code + * size. + * + * DynamicType may or may not inherit from c10::Type because it's not the core + * requirement of DynamicType to interface with existing JIT types, but we might + * want to inherit from c10::Type to reduce the migration cost. + */ +class DynamicType : public SharedType { + using ClassTypePtr = std::shared_ptr; + + /** + * A implementation detail to support NamedTuple. + */ + struct LabeledDynamicType { + std::optional label; + DynamicTypePtr ty; + explicit LabeledDynamicType(DynamicTypePtr t) : ty(std::move(t)) {} + + bool equals(const LabeledDynamicType& other) const; + bool isSubtypeOf(const LabeledDynamicType& other) const; + }; + + public: + // TODO Change Ptr to DynamicTypePtr when all migrations are done. 
+ using Ptr = TypePtr; + using ElementType = DynamicType; + ~DynamicType() override; + + struct Arguments { + Arguments() = default; + Arguments(c10::ArrayRef); + Arguments(const std::vector&, c10::ArrayRef); + std::vector elems; + }; + + enum class Tag : DynamicTypeBits { +#define DYNAMIC_TYPE_ITEM(NAME, VAL, _) NAME = VAL, + FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_ITEM) + FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_ITEM) +#undef DYNAMIC_TYPE_ITEM + }; + + bool equals(const Type& rhs) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + std::string str() const override; + static const TypeKind Kind = TypeKind::DynamicType; + static TORCH_API DynamicTypePtr create(Type& ty); + + explicit DynamicType(Tag, Arguments); + explicit DynamicType(Tag, std::string_view, Arguments); + + DynamicType(DynamicType&& other) = delete; + DynamicType(const DynamicType&) = delete; + DynamicType& operator=(const DynamicType&) = delete; + DynamicType& operator=(DynamicType&&) = delete; + + TypePtr containedType(size_t) const override; + size_t containedTypeSize() const override; + Tag tag() const { + return tag_; + } + const std::optional& name() const { + return name_; + } + const Arguments& arguments() const { + return arguments_; + } + TORCH_API TypeKind dynamicKind() const; + + // Should be used only on the server side to restore static type information. +#ifndef C10_MOBILE + TORCH_API +#endif + TypePtr fallback() const; + + private: + bool symmetric() const override { + return false; + } + friend struct Type; + // NOTE: Here we are using SingletonOrSharedTypePtr to mean + // "original-type-because-it-was-actually-a-DynamicType or shared". 
+ static SingletonOrSharedTypePtr create(const Type& ty); + DynamicType(const Type& other); + bool equals(const DynamicType& other) const; + + template + bool compareArguments(const DynamicType& other, const F& f) const { + if (arguments_.elems.size() != other.arguments_.elems.size()) { + return false; + } + for (size_t i = 0; i < arguments_.elems.size(); i++) { + if (!f(arguments_.elems[i], other.arguments_.elems[i])) { + return false; + } + } + return true; + } + + Tag tag_; + std::optional name_; + union { + Arguments arguments_; + ClassTypePtr class_; + }; +}; + +template +struct DynamicTypeTrait { + C10_NOINLINE static auto tagValue() { + TORCH_CHECK(false); + return DynamicType::Tag::Any; + } +}; + +namespace detail { +C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag); +} + +#define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \ + template <> \ + struct TORCH_API DynamicTypeTrait { \ + C10_ERASE static auto tagValue() { \ + return DynamicType::Tag::NAME; \ + } \ + static constexpr bool isBaseType = IS_BASE_TYPE; \ + template \ + static std::enable_if_t getBaseType() { \ + static auto type = detail::makeBaseType(tagValue()); \ + return type; \ + } \ + }; // namespace c10 +FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE) +FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_TAG_VALUE) +#undef DYNAMIC_TYPE_TAG_VALUE + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/enum_tag.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/enum_tag.h new file mode 100644 index 0000000000000000000000000000000000000000..77ded3278a8addc70611195bff948f4cac533cab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/enum_tag.h @@ -0,0 +1,25 @@ +#pragma once + +// @generated by torchgen/gen.py from enum_tag.h + +namespace at { + // Enum of valid tags obtained from the entries in tags.yaml + enum class Tag { + core, + cudagraph_unsafe, + data_dependent_output, + dynamic_output_shape, + flexible_layout, + 
generated, + inplace_view, + maybe_aliasing_or_mutating, + needs_contiguous_strides, + needs_exact_strides, + needs_fixed_stride_order, + nondeterministic_bitwise, + nondeterministic_seeded, + pointwise, + pt2_compliant_tag, + view_copy + }; +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/enum_type.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/enum_type.h new file mode 100644 index 0000000000000000000000000000000000000000..e292f58487fbdaea129765db3ff13f4dba3f5299 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/enum_type.h @@ -0,0 +1,102 @@ +#pragma once + +#include + +#include + +namespace c10 { + +struct EnumType; +using EnumTypePtr = std::shared_ptr; +using EnumNameValue = std::pair; +struct TORCH_API EnumType : public NamedType { + friend struct Type; + static const TypeKind Kind = TypeKind::EnumType; + + static EnumTypePtr create( + const c10::QualifiedName& qualified_class_name, + TypePtr value, + std::vector enum_names_values, + std::weak_ptr<::torch::jit::CompilationUnit> cu) { + switch (value->kind()) { + case TypeKind::IntType: + case TypeKind::FloatType: + case TypeKind::StringType: + return EnumTypePtr(new EnumType( + qualified_class_name, + std::move(value), + std::move(enum_names_values), + std::move(cu))); + default: + TORCH_CHECK( + false, + "Cannot create Enum with value type '", + value->str(), + "', only int, float and string are supported"); + } + } + + std::string str() const override { + return "Enum<" + annotation_str() + ">"; + } + + std::string repr_str() const override { + return str(); + } + + const TypePtr& getValueType() const { + return value_type_; + } + + bool equals(const Type& rhs) const override { + if (auto* enum_rhs = rhs.castRaw()) { + return name().has_value() && name() == enum_rhs->name() && + *getValueType() == *(enum_rhs->getValueType()) && + this->compilation_unit() == enum_rhs->compilation_unit(); + } + return false; + } + + bool 
isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + std::shared_ptr compilation_unit() + const { + auto cu = cu_.lock(); + return cu; + } + + const QualifiedName& qualifiedClassName() const { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return name().value(); + } + + at::ArrayRef containedTypes() const override { + return value_type_; + } + + const at::ArrayRef enumNamesValues() const { + return enum_names_values_; + } + + private: + EnumType( + c10::QualifiedName qualified_class_name, + TypePtr value_type, + std::vector enum_names_values, + std::weak_ptr cu) + : NamedType(TypeKind::EnumType, std::move(qualified_class_name)), + value_type_(std::move(value_type)), + enum_names_values_(std::move(enum_names_values)), + cu_(std::move(cu)) {} + + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return qualifiedClassName().qualifiedName(); + } + + TypePtr value_type_; + std::vector enum_names_values_; + std::weak_ptr<::torch::jit::CompilationUnit> cu_; +}; + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/function.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/function.h new file mode 100644 index 0000000000000000000000000000000000000000..7e8a765a05abc1c26addac095316834ee026b4e3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/function.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { +struct FunctionSchema; +} + +namespace at { +TORCH_API void launch(std::function func); +} + +namespace torch::jit { + +struct Graph; +struct Code; + +namespace mobile { +struct Code; +} + +using Stack = std::vector; +using Kwargs = std::unordered_map; +struct RecursiveMethodCallError : public std::exception {}; +using TaskLauncher = std::function)>; + +TORCH_API void preoptimizeGraph( + std::shared_ptr& graph, + bool disable_autocast = false); + +// A 
Function is a pure Graph with no implicit `self` object bound. +// It contains schema information and the executor that manages the +// execution of the function. Method is a wrapper around an +// underlying Function that also provides a `self` object. +struct TORCH_API Function { + Function() = default; + Function(const Function&) = default; + Function& operator=(const Function&) = default; + Function(Function&&) noexcept = default; + Function& operator=(Function&&) noexcept = default; + virtual std::string_view doc_string() const { + static constexpr std::string_view no_doc_string; + return no_doc_string; + } + + virtual bool isGraphFunction() const { + return false; + } + + virtual void run(Stack& stack) = 0; + + virtual c10::intrusive_ptr runAsync( + Stack& /*stack*/, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + [[maybe_unused]] TaskLauncher taskLauncher = at::launch) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); + return {}; + } + + at::IValue operator()(Stack stack, const Kwargs& kwargs = Kwargs()) { + getSchema().checkAndNormalizeInputs(stack, kwargs); + run(stack); + return stack.front(); + } + + virtual const c10::QualifiedName& qualname() const = 0; + + const std::string& name() const { + return qualname().name(); + } + + // if this isn't yet defined, run its method_creator function + virtual void ensure_defined() = 0; + + virtual const c10::FunctionSchema& getSchema() const = 0; + + virtual size_t num_inputs() const = 0; + + virtual Function& setSchema(c10::FunctionSchema schema) = 0; + + // call() defines how different interpreter implementations interacts with + // Function objects. Basically interpreters need to provide a callback to + // communicate to Functions what to do if provided a Code object. 
+ // Alternatively we could design the signature to return an optional Code + // object, but that requires special handling the null case in interpreter + // and the fallback behavior is not well defined by interpreter but rather + // Function themselves, so a callback approach is more reasonable than + // returning values. + // If call() returns true, then callback completes successfully, otherwise + // call() returns false. + + // Overload for server interpreter, a bailout size is needed for graph + // executor. + virtual bool call( + Stack&, + std::optional, + c10::function_ref) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); + return false; + } + + // Overload for mobile interpreter. + virtual bool call(Stack&, c10::function_ref) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); + return false; + } + + virtual ~Function() = default; +}; +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/function_schema.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/function_schema.h new file mode 100644 index 0000000000000000000000000000000000000000..c3e1520dc9868a7d7849169b895fc786a1f7d55b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/function_schema.h @@ -0,0 +1,690 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// schema as used in the compiler for resolving function calls and reporting +// errors. These objects should be constructed from C10 schema once those +// are available. 
+ +struct Argument; +struct FunctionSchema; + +using AliasTypeSet = std::vector; + +bool operator==(const Argument& lhs, const Argument& rhs); + +struct TORCH_API Argument { + Argument( + std::string name = "", + const TypePtr& type = nullptr, + std::optional N = std::nullopt, + std::optional default_value = std::nullopt, + bool kwarg_only = false, + std::optional alias_info = std::nullopt) + : Argument(std::move(name), type, type, N, std::move(default_value), kwarg_only, std::move(alias_info)) {} + + Argument( + std::string name, + TypePtr fake_type, + TypePtr real_type, + std::optional N = std::nullopt, + std::optional default_value = std::nullopt, + bool kwarg_only = false, + std::optional alias_info = std::nullopt) + : name_(std::move(name)), + type_(fake_type ? std::move(fake_type) : TensorType::get()), + real_type_(real_type ? std::move(real_type) : type_), + N_(N), + default_value_(std::move(default_value)), + alias_info_(alias_info ? std::make_unique(std::move(*alias_info)) : nullptr), + kwarg_only_(kwarg_only) { + // this is an softly-enforced invariant for out arguments. + bool is_alias = alias_info_ != nullptr && alias_info_->isWrite(); + is_out_ = kwarg_only_ && is_alias; + } + + Argument(Argument&& rhs) noexcept = default; + + Argument(const Argument& rhs) + : name_(rhs.name_), + type_(rhs.type_), + real_type_(rhs.real_type_), + N_(rhs.N_), + default_value_(rhs.default_value_), + alias_info_(rhs.alias_info_ ? std::make_unique(*rhs.alias_info_) : nullptr), + kwarg_only_(rhs.kwarg_only_), + is_out_(rhs.is_out_) {} + + Argument& operator=(Argument&& rhs) = default; + + Argument& operator=(const Argument& rhs) { + if (this != &rhs) { + name_ = rhs.name_; + type_ = rhs.type_; + real_type_ = rhs.real_type_; + N_ = rhs.N_; + default_value_ = rhs.default_value_; + alias_info_ = rhs.alias_info_ ? 
std::make_unique(*rhs.alias_info_) : nullptr; + kwarg_only_ = rhs.kwarg_only_; + is_out_ = rhs.is_out_; + } + return *this; + } + ~Argument() = default; + + const std::string& name() const { + return name_; + } + const TypePtr& type() const { + return type_; + } + // if type() is non-null, this is guaranteed to be non-null (if no real + // type was provided, this takes on type()'s value) + const TypePtr& real_type() const { + return real_type_; + } + const std::optional& N() const { + return N_; + } + const std::optional& default_value() const { + return default_value_; + } + bool kwarg_only() const { + return kwarg_only_; + } + + bool is_out() const { + return is_out_; + } + + [[nodiscard]] const AliasInfo* alias_info() const { + return alias_info_.get(); + } + + bool is_inferred_type() const { + bool is_inferred_type = false; + TORCH_INTERNAL_ASSERT(type_); + if (auto pt = type_->cast()) { + if (pt->isInferredType()) { + is_inferred_type = true; + } + } + return is_inferred_type; + } + + std::string formatTypeMismatchMsg(const std::string& actual_type) const { + std::string inferred_type_hint; + if (is_inferred_type()) { + inferred_type_hint = c10::str( + "Inferred '", + name(), + "' to be of type 'Tensor' ", + "because it was not annotated with an explicit type.\n"); + } + return c10::str( + "Expected a value of type '", + type()->repr_str(), + "' for argument '", + name(), + "' but instead found type '", + actual_type, + "'.\n", + inferred_type_hint); + } + + Argument cloneWithType(const TypePtr& new_type) const { + return Argument( + name_, + new_type, + N_, + default_value_, + kwarg_only_, + alias_info_ ? std::optional(*alias_info_) : std::nullopt); + } + + // this function checks whether this Argument is backward compatible with + // the old one. 
we consider the following cases are backward compatible: + // 1) two arguments are equal + // 2) this arg's type should be subtype of old + // 3) this arg must provide the same default value if old arg has one, + bool isBackwardCompatibleWith( + const Argument& old, + std::ostream* why_not=nullptr) const; + + // this function checks whether this Argument is forward compatible with + // the old one. we consider the following cases are forward compatible: + // 1) two arguments are equal + // 2) this arg's type should be subtype of old + // 3) this arg must provide the same default value if old arg has one, + bool isForwardCompatibleWith( + const Argument& old, + std::ostream* why_not = nullptr) const; + + private: + std::string name_; + TypePtr type_; + TypePtr real_type_; // this is ScalarType, not int, e.g. + // for list types, an optional statically known length for the list + // e.g. for int[3]: type = ListType::ofInts(), N = 3 + // If present, this will allow scalars to be broadcast to this length to + // become a list. + std::optional N_; + + std::optional default_value_; + // AliasInfo is huge, so let's only allocate memory for it if + // necessary (which it isn't during schema parsing on startup, to + // give a pertinent example). + std::unique_ptr alias_info_; + // is this only specifiable as a keyword argument? 
+ bool kwarg_only_; + // marks if the argument is out variant of the schema + bool is_out_; +}; + +inline bool operator==(const Argument& lhs, const Argument& rhs) { + return lhs.name() == rhs.name() + && *lhs.type() == *rhs.type() + && lhs.N() == rhs.N() + && lhs.default_value() == rhs.default_value() + && lhs.kwarg_only() == rhs.kwarg_only() + && (lhs.alias_info() == rhs.alias_info() + || (lhs.alias_info() != nullptr && rhs.alias_info() != nullptr + && *lhs.alias_info() == *rhs.alias_info())); +} + +inline bool operator!=(const Argument& lhs, const Argument& rhs) { + return !(lhs == rhs); +} + +enum struct TORCH_API SchemaArgType { input, output }; + +/** + * struct SchemaArgument + * + * Structure used to represent arguments or returns for a schema. + */ +struct TORCH_API SchemaArgument { + SchemaArgType type; + size_t index; + SchemaArgument(SchemaArgType tpe, size_t idx) : type(tpe), index(idx) {} + bool operator==(const SchemaArgument& rhs) const { + return type == rhs.type && index == rhs.index; + } +}; + +bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs); + +struct TORCH_API FunctionSchema { + FunctionSchema( + std::string name, + std::string overload_name, + std::vector arguments, + std::vector returns, + bool is_vararg = false, + bool is_varret = false) + : name_({std::move(name), std::move(overload_name)}), + arguments_(std::move(arguments)), + returns_(std::move(returns)), + is_vararg_(is_vararg), + is_varret_(is_varret) { + checkSchema(); + } + + FunctionSchema( + Symbol name, + std::string overload_name, + std::vector arguments, + std::vector returns, + bool is_vararg = false, + bool is_varret = false) + : FunctionSchema( + name.toQualString(), + std::move(overload_name), + std::move(arguments), + std::move(returns), + is_vararg, + is_varret) { + checkSchema(); + } + + // Checks whether this schema is backward compatible with the old one. 
+ // The following conditions must be true: + // [Function structure] The new schema's name, overload-name, varargs, and + // return arity are the same. + // [Output Narrowing] The new schema's output type must be the same class + // or inherit from the old schema's output type. + // [Argument count] The new schema must have at least as many arguments as + // the old schema (considering the list of positional and kwargs). + // [Arg Compatibility] Every argument in the old schema has a corresponding + // argument in the new schema that: + // * is at the same position. + // * has the same name. + // * is either positional, or kwarg and the old argument was kwarg. + // * has the same type, or the old argument's type inherits from the + // new argument's type. + // [Default Values] Every new argument must have a default value. + // E.g. + // OK f_new(a, b, c=1) => f_old(a, b) + // NOK f_new(a, c=1, *, b) => f_old(a, *, b) + // OK f_new(a, b, *, c) => f_old(a, *, b, c) + // NOK f_new(a, *, b, c) -> f_old(a, b, *, c) + // NOK f_new(a, *, c, b) => f_old(a, *, b, c) + // OK f_new(a, *, b, c, d=1) => f_old(a, *, b, c) + bool isBackwardCompatibleWith( + const FunctionSchema& old, + std::ostream* why_not = nullptr) const; + + // Checks whether this schema is forward compatible with the old one. + // The following conditions must be true: + // [Function structure] The new schema's name, overload-name, varargs, and + // return arity are the same. + // [Output Narrowing] The new schema's output type must be the same class + // or inherit from the old schema's output type. + // [Arg Compatibility] Every argument in the old schema has a corresponding + // argument in the new schema that: + // * is at the same position. + // * has the same name. + // * is either positional, or kwarg and the old argument was kwarg. + // * has the same type, or the old argument's type inherits from the + // new argument's type. + // [Default Values] Every new argument must have a default value. 
+ // Each default value type should NOT be a container type. + // [Positioning] All defaults arguments MUST go after either old + // default arguments or the end of positional arguments + // and right BEFORE all out arguments + bool isForwardCompatibleWith( + const FunctionSchema& old, + std::ostringstream& why_not) const; + + private: + OperatorName name_; + std::vector arguments_; + std::vector returns_; + // if true then this schema takes an arbitrary number of additional arguments + // after the argument specified in arguments + // currently this is used primarily to represent 'primitive' operators whose + // arguments are not checked by schema + bool is_vararg_; + bool is_varret_; + + // if no alias information is directly specified, what kind of "default" + // alias information should we infer? + // NB: due to alias analysis kind merging, this may be nullopt. Eventually + // this should always be set no matter what + std::optional alias_kind_; + + template + void checkArg(const IValue& value, const Argument& argument, std::optional pos) const; + + void checkSchema() const { + bool seen_default_arg = false; + for (const auto& arg : arguments()) { + if (arg.default_value()) { + seen_default_arg = true; + } else { + // we have historically serialized broadcasting lists wo/default values, + // so to not break BC allow lists here + if (arg.type()->kind() == ListType::Kind) { + continue; + } + TORCH_INTERNAL_ASSERT( + !seen_default_arg || arg.kwarg_only(), + "Non-default positional argument follows default argument. 
Parameter ", + arg.name(), + " in ", + *this); + } + } + } + + public: + + void dump() const; + + const OperatorName& operator_name() const { + return name_; + } + const std::string& name() const { + return name_.name; + } + const std::string& overload_name() const { + return name_.overload_name; + } + const std::vector& arguments() const { + return arguments_; + } + const std::vector& returns() const { + return returns_; + } + bool is_vararg() const { + return is_vararg_; + } + bool is_varret() const { + return is_varret_; + } + bool is_aliasing(const c10::SchemaArgument &argument) const { + TORCH_INTERNAL_ASSERT( + argument.index < getCorrectList(argument.type).size(), + "Invalid index for schema."); + const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info(); + return aliasInfo; + } + bool is_mutable() const { + return std::any_of( + arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) { + const AliasInfo* aliasInfo = arg.alias_info(); + return aliasInfo && aliasInfo->isWrite(); + }); + } + bool is_mutable(const c10::SchemaArgument &argument) const { + TORCH_INTERNAL_ASSERT( + argument.index < getCorrectList(argument.type).size(), + "Invalid index for schema."); + const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info(); + return aliasInfo && aliasInfo->isWrite(); + } + bool is_mutable(std::string_view name) const { + std::optional index = argumentIndexWithName(name); + TORCH_INTERNAL_ASSERT( + index.has_value(), "Schema has no argument named ", name); + + return is_mutable({c10::SchemaArgType::input, static_cast(*index)}); + } + + // Returns whether lhs and rhs may alias directly. + // This does not account for cases where lhs or rhs are a container that + // may contain elements that alias the other argument. + // FunctionSchema::may_contain_alias will include that functionality. 
+ bool may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const; + + // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a container + // that may contain elements that alias the other argument. + // bidirectional = false only returns whether lhs may contain an alias of rhs + // while bidirectional = true returns both directions. + bool may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional = true) const; + + // Returns whether the two AliasTypeSets contain any similarities + // ie: whether the two type sets can alias. + bool canAliasTypeSetsAlias(const std::optional &lhs, const std::optional &rhs) const; + + // Recursively Finds all contained types within the AliasTypeSet. + std::optional getAliasTypeSetContainedTypes(const std::optional &aliasTypeSet) const; + + // Similar to mapTypeToAliasTypeSet defined in alias_analysis.cpp. + // Used to map types to a type such that all types that can alias will be mapped to the same type. + // For example, calling this method on 'Optional[List[int]]' is the same as calling this method + // on 'List[int]'. 
+ std::optional mapTypeToAliasTypeSet(const TypePtr& type) const; + + // Returns either arguments() or returns() depending on the SchemaArgType + // output => returns(), input => arguments() + const std::vector& getCorrectList(SchemaArgType type) const; + + std::optional argumentIndexWithName(std::string_view name) const { + for (const auto i : c10::irange(arguments().size())) { + if(name == arguments()[i].name()) + return i; + } + return std::nullopt; + } + FunctionSchema cloneWithName(std::string name, std::string overload_name) const { + return FunctionSchema( + std::move(name), + std::move(overload_name), + arguments(), + returns(), + is_vararg(), + is_varret() + ); + } + FunctionSchema cloneWithArguments(std::vector new_arguments) const { + return FunctionSchema( + name(), + overload_name(), + std::move(new_arguments), + returns(), + is_vararg(), + is_varret()); + } + FunctionSchema cloneWithReturns(std::vector new_returns) const { + return FunctionSchema( + name(), + overload_name(), + arguments(), + std::move(new_returns), + is_vararg(), + is_varret()); + } + + std::string formatTypeMismatchMsg( + const Argument& expected, + const std::string& actual_type, + std::optional position = std::nullopt, + std::optional value = std::nullopt) const; + + FunctionSchema cloneWithRemappedTypes( + const std::function type_map) const; + + FunctionSchema cloneWithRealTypes(bool with_symint=true) const; + + // Check that inputs have the correct types and appends any missing default + // values. 
+ template + void checkAndNormalizeInputs( + std::vector& inputs, + const std::unordered_map& kwargs = + std::unordered_map{}) const; + + std::string findErrorInKwargs(const std::vector& kwargs) const; + + bool hasAnyAliasInfo() const { + for (const auto& arg : arguments_) { + if (arg.alias_info() != nullptr) { + return true; + } + } + for (const auto& ret : returns_) { + if (ret.alias_info() != nullptr) { + return true; + } + } + return false; + } + + + // TODO remove the mutation here + bool isDefaultAliasAnalysisKind() const { + return !alias_kind_; + } + AliasAnalysisKind aliasAnalysis() const { + return alias_kind_.value_or(AliasAnalysisKind::CONSERVATIVE); + } + void setAliasAnalysis(AliasAnalysisKind v) { + alias_kind_ = v; + } + + std::optional getNamespace() const { + return name_.getNamespace(); + } + + // Returns true if we successfully set the namespace (as there + // was none set, and false otherwise) + bool setNamespaceIfNotSet(const char* ns) { + return name_.setNamespaceIfNotSet(ns); + } + + // can a function with this schema be substituted for a function of rhs's + // schema and have the program typecheck? + // as_method - if true, treat this schema as a method and ignore + // the first argument, which will be the object in both cases + bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const; +}; + +inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) { + return lhs.name() == rhs.name() + && lhs.overload_name() == rhs.overload_name() + && lhs.arguments() == rhs.arguments() + && lhs.returns() == rhs.returns() + && lhs.is_vararg() == rhs.is_vararg() + && lhs.is_varret() == rhs.is_varret(); +} + +inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) { + return !(lhs == rhs); +} + +// print out Argument, which is compatible with FunctionSchema parser +// full format: Type(alias)? 
name=default_value +inline std::ostream& operator<<(std::ostream& out, const Argument& arg) { + + // for adjusting the ? position. + // in schema, we have Tensor?(a!) input, and t(a!)?. + // however, t?(a!) doesn't work with schema parser. + // so we always use Type(alias)? format + // real_type versus fake_type: in order to be compatible with FunctionSchema + // parser, printing an argument with either MemoryFormat or Layout type should + // give us the original schema string, hence printing out real_type. + auto type = arg.real_type(); + bool is_opt = type->kind() == OptionalType::Kind; + auto unopt_type = is_opt ? type->castRaw()->getElementType() : type; + + if (unopt_type->kind() == ListType::Kind) { + // sized lists get size N from arg, not type + auto list = unopt_type->cast(); + out << list->getElementType()->str(); + if (arg.alias_info() && !arg.alias_info()->containedTypes().empty()){ + out << arg.alias_info()->containedTypes()[0]; + } + std::string N; + if (arg.N()) { + N = std::to_string(*arg.N()); + } + out << "[" << N << "]"; + } else { + out << unopt_type->str(); + } + + // print alias info if it has beforeSets. + if (arg.alias_info() && !arg.alias_info()->beforeSets().empty()) { + out << *arg.alias_info(); + } + + if (is_opt) { + out << "?"; + } + + if (!arg.name().empty()) { + out << " " << arg.name(); + } + + if (arg.default_value()) { + out << "="; + if ((type->kind() == c10::TypeKind::StringType || + unopt_type->kind() == c10::TypeKind::StringType) && + arg.default_value().value().isString()) { + printQuotedString(out, arg.default_value().value().toStringRef()); + } else if (type->kind() == TypeKind::ListType && type->castRaw()->getElementType()->kind() == c10::TypeKind::IntType) { + // We want to faithfully replicate JIT schema. 
+ // in native_functions.yaml defaults for int arrays with a single value always look like + // int[2] stride=1 + // instead of + // int[2] stride=[1, 1] + auto default_val = arg.default_value().value().toIntList(); + if (default_val.size() > 1) { + auto all_defaults_the_same = true; + for (const auto i : c10::irange(1, default_val.size())) { + if (default_val[0] != default_val[i]) all_defaults_the_same = false; + } + if (all_defaults_the_same) { + out << default_val[0]; + } else { + out << arg.default_value().value(); + } + } else { + out << arg.default_value().value(); + } + } else { + out << arg.default_value().value(); + } + } + + return out; +} + +TORCH_API std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema); + +inline std::string toString(const FunctionSchema& schema) { + std::ostringstream str; + str << schema; + return str.str(); +} + +} // namespace c10 + +namespace std { +template<> + struct hash { + size_t operator()(const c10::SchemaArgument& arg) const + { + return c10::hash_combine(std::hash()(arg.index), std::hash()(static_cast(arg.type))); + } + }; +template<> + struct hash { + size_t operator()(const c10::Argument& arg) const + { + auto hash = std::hash{}(arg.name()); + auto type_hash = std::hash{}(arg.type()); + auto kwarg_only_hash = std::hash{}(arg.kwarg_only()); + hash = c10::hash_combine(hash, type_hash); + hash = c10::hash_combine(hash, kwarg_only_hash); + // hashing optional fields if they exist + if (arg.default_value().has_value()) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + auto default_value_hash = c10::hash{}(*arg.default_value()); + hash = c10::hash_combine(hash, default_value_hash); + } + if (arg.N().has_value()) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + auto N_hash = std::hash{}(*arg.N()); + hash = c10::hash_combine(hash, N_hash); + } + if (arg.alias_info()) { + auto alias_info_hash = std::hash{}(*arg.alias_info()); + hash = c10::hash_combine(hash, alias_info_hash); + } + return 
hash; + } + }; +template<> + struct hash { + size_t operator()(const c10::FunctionSchema& schema) const + { + auto hash = std::hash{}(schema.operator_name()); + auto args_hash = c10::hash>{}(schema.arguments()); + auto returns_hash = c10::hash>{}(schema.returns()); + auto is_vararg_hash = std::hash{}(schema.is_vararg()); + auto is_varret_hash = std::hash{}(schema.is_varret()); + hash = c10::hash_combine(hash, args_hash); + hash = c10::hash_combine(hash, returns_hash); + hash = c10::hash_combine(hash, is_vararg_hash); + hash = c10::hash_combine(hash, is_varret_hash); + return hash; + } + }; +} // namespace std + + +#include // IWYU pragma: keep diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/function_schema_inl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/function_schema_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..0c0715c29abad054d6699903ce0b94885bb8558b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/function_schema_inl.h @@ -0,0 +1,78 @@ +#pragma once +#include +#include + +namespace c10 { + +template +inline void FunctionSchema::checkArg( + const IValue& value, + const Argument& argument, + std::optional pos) const { + if (value.isTensor() && argument.type() == TensorType::get()) { + // Fast-path for the common case + return; + } + if (value.isGenericDict() && value.toGenericDict().empty()) { + return; + } + if (!value.type()->isSubtypeOf(*argument.type())) { + TORCH_CHECK( + false, + formatTypeMismatchMsg( + argument, value.type()->repr_str(), pos)); + } +} + +template +inline void FunctionSchema::checkAndNormalizeInputs( + std::vector& inputs, + const std::unordered_map& kwargs) const { + // Do we have more inputs than the schema accepts? + TORCH_CHECK( + inputs.size() <= arguments().size(), + "Expected at most ", + arguments().size(), + " argument(s) for operator '", + name(), + "', but received ", + inputs.size(), + " argument(s). 
Declaration: ", + *this); + + size_t consumed_kwargs = 0; + for (const auto pos : c10::irange(arguments().size())) { + const auto& argument = arguments()[pos]; + if (pos < inputs.size()) { + checkArg(inputs[pos], argument, pos); + continue; + } + auto it = kwargs.find(argument.name()); + if (it != kwargs.end()) { + checkArg(it->second, argument, std::nullopt); + inputs.push_back(it->second); + consumed_kwargs++; + continue; + } + if (argument.default_value()) { + inputs.push_back(*argument.default_value()); + continue; + } + TORCH_CHECK(false, + name(), + "() is missing value for argument '", + argument.name(), + "'. Declaration: ", + *this); + } + if (consumed_kwargs != kwargs.size()) { + std::vector names; + names.reserve(kwargs.size()); + for(const auto& k : kwargs) { + names.emplace_back(k.first); + } + TORCH_CHECK(false, findErrorInKwargs(names)); + } +} + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/functional.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/functional.h new file mode 100644 index 0000000000000000000000000000000000000000..1ddc67418201024f6f94907340689e81f0493c35 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/functional.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +namespace c10 { + +// The passed in function must take T by value (T), or by +// const reference (const T&); taking T by non-const reference +// will result in an error like: +// +// error: no type named 'type' in 'class std::invoke_result' +// +// No explicit template parameters are required. + +// Overload for explicit function and ArrayRef +template +inline auto fmap(const T& inputs, const F& fn) -> std::vector { + std::vector r; + r.reserve(inputs.size()); + for(const auto & input : inputs) + r.push_back(fn(input)); + return r; +} + +// C++ forbids taking an address of a constructor, so here's a workaround... 
+// Overload for constructor (R) application +template +inline std::vector fmap(const T& inputs) { + std::vector r; + r.reserve(inputs.size()); + for(auto & input : inputs) + r.push_back(R(input)); + return r; +} + +template +inline std::vector filter(at::ArrayRef inputs, const F& fn) { + std::vector r; + r.reserve(inputs.size()); + for(auto & input : inputs) { + if (fn(input)) { + r.push_back(input); + } + } + return r; +} + +template +inline std::vector filter(const std::vector& inputs, const F& fn) { + return filter(static_cast>(inputs), fn); +} + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/grad_mode.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/grad_mode.h new file mode 100644 index 0000000000000000000000000000000000000000..47051525c59beece9e8e11accacd926d9c5e587e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/grad_mode.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +namespace at { + using GradMode = c10::GradMode; + using AutoGradMode = c10::AutoGradMode; + using NoGradGuard = c10::NoGradGuard; +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/interned_strings.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/interned_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..38942031befcd62ed216002f549b67fb547088b0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/interned_strings.h @@ -0,0 +1,355 @@ +#pragma once + +#include + +#include +#include + +namespace c10 { + +#define FORALL_NS_SYMBOLS(_) \ + _(namespaces, prim) \ + _(namespaces, prims) \ + _(namespaces, nvprims) \ + _(namespaces, aten) \ + _(namespaces, cuda) \ + _(namespaces, onnx) \ + _(namespaces, attr) \ + _(namespaces, scope) \ + _(namespaces, user) \ + _(namespaces, _caffe2) \ + _(namespaces, dimname) \ + _(namespaces, namespaces) \ + _(prim, Assign) \ + _(prim, BroadcastingChunk) \ + _(prim, 
BroadcastSizes) \ + _(prim, ReductionSizes) \ + _(prim, Constant) \ + _(prim, ChunkSizes) \ + _(prim, ConstantMKLDNNTensor) \ + _(prim, BroadcastMKLDNNTensors) \ + _(prim, MKLDNNGroup) \ + _(prim, MKLDNNHardSwish) \ + _(prim, MKLDNNHardSigmoid) \ + _(prim, MKLDNNHardTanh) \ + _(prim, MKLDNNClamp) \ + _(prim, StaticRuntimeCopyOuts) \ + _(prim, Drop) \ + _(prim, Eval) \ + _(prim, Expand) /* onnx */ \ + _(prim, FusionGroup) \ + _(prim, CudaFusionGroup) \ + _(prim, CudaFusionGuard) \ + _(prim, oneDNNFusionGroup) \ + _(prim, oneDNNFusionGuard) \ + _(prim, FunctionalGraph) \ + _(prim, add_optional) \ + _(prim, view_copy) \ + _(prim, permute_copy) \ + _(prim, reshape_copy) \ + _(prim, squeeze_copy) \ + _(prim, t_copy) \ + _(prim, transpose_copy) \ + _(prim, unsqueeze_copy) \ + _(prim, flatten_copy) \ + _(prim, expand_copy) \ + _(prim, expand_as_copy) \ + _(prim, DifferentiableGraph) \ + _(prim, TensorExprGroup) \ + _(prim, TensorExprDynamicGroup) \ + _(prim, StaticSubgraph) \ + _(prim, If) \ + _(prim, Jump) /* debug */ \ + _(prim, JumpNZ) /* debug */ \ + _(prim, JumpZ) /* debug */ \ + _(prim, Load) \ + _(prim, Loop) \ + _(prim, Param) \ + _(prim, PackPadded) /* onnx */ \ + _(prim, PadPacked) /* onnx */ \ + _(prim, Placeholder) /* debug */ \ + _(prim, Print) \ + _(prim, EmptyListLiteral) \ + _(prim, LegacyTypedConstructor) \ + _(prim, PythonOp) \ + _(prim, IgnoredPythonOp) \ + _(prim, Reverse) \ + _(prim, Return) \ + _(prim, ReturnStmt) \ + _(prim, BreakStmt) \ + _(prim, ContinueStmt) \ + _(prim, ComprehensionScope) \ + _(prim, Store) \ + _(prim, AutogradZero) \ + _(prim, AutogradAnyNonZero) \ + _(prim, AutogradAllNonZero) \ + _(prim, AutogradAllZero) \ + _(prim, Starred) \ + _(prim, TupleConstruct) \ + _(prim, TupleUnpack) \ + _(prim, TupleIndex) \ + _(prim, TupleSlice) \ + _(prim, ListConstruct) \ + _(prim, ListUnpack) \ + _(prim, DictConstruct) \ + _(prim, ModuleContainerIndex) \ + _(prim, EnumName) \ + _(prim, EnumValue) \ + _(prim, StringIndex) \ + _(prim, 
NumToTensor) \ + _(prim, Uninitialized) \ + _(prim, VarConcat) \ + _(prim, VarStack) \ + _(prim, With) \ + _(prim, Enter) \ + _(prim, Exit) \ + _(prim, IfThenElse) \ + _(aten, Bool) \ + _(aten, Int) \ + _(aten, FloatImplicit) \ + _(aten, ComplexImplicit) \ + _(aten, IntImplicit) \ + _(aten, ScalarImplicit) \ + _(aten, Float) \ + _(aten, Complex) \ + _(aten, str) \ + _(aten, Delete) \ + _(prim, device) \ + _(prim, dtype) \ + _(prim, layout) \ + _(prim, id) \ + _(prim, requires_grad) \ + _(prim, MakeTestTensor) /* test */ \ + _(prim, AutogradAdd) \ + _(prim, GradOf) \ + _(aten, grad) \ + _(aten, backward) \ + _(prim, Guard) \ + _(prim, BailOut) \ + _(prim, TypeCheck) \ + _(prim, RequiresGradCheck) \ + _(prim, FallbackGraph) \ + _(prim, FusedConcat) \ + _(prim, ConstantChunk) \ + _(prim, MMTreeReduce) \ + _(prim, MMBatchSide) \ + _(prim, list) \ + _(prim, dict) \ + _(prim, min) \ + _(prim, max) \ + _(prim, abs) \ + _(aten, divmod) \ + _(prim, zip) \ + _(prim, enumerate) \ + _(prim, range) \ + _(prim, rangelist) \ + _(prim, isinstance) \ + _(prim, tolist) \ + _(prim, unchecked_cast) \ + _(aten, _grad_sum_to_size) \ + _(aten, _size_if_not_equal) \ + _(aten, _ncf_unsqueeze) \ + _(aten, warn) \ + _(aten, sorted) \ + _(aten, floordiv) \ + _(aten, __range_length) \ + _(aten, __derive_index) \ + _(aten, __round_to_zero_floordiv) \ + _(aten, is_scripting) \ + _(aten, _unwrap_optional) \ + _(prim, fork) \ + _(prim, awaitable) \ + _(prim, forkClosure) \ + _(prim, awaitableClosure) \ + _(prim, awaitable_nowait) \ + _(prim, awaitable_wait) \ + _(prim, RaiseException) \ + _(prim, Closure) \ + _(prim, CreateObject) \ + _(prim, SetAttr) \ + _(prim, GetAttr) \ + _(prim, HasAttr) \ + _(prim, profile) \ + _(prim, profile_ivalue) \ + _(prim, AddStatValue) \ + _(prim, TimePoint) \ + _(prim, CallFunction) \ + _(prim, CallMethod) \ + _(prim, LoopContinuation) \ + _(prim, annotate) \ + _(prim, TracedModuleForward) \ + _(prim, TracedFork) \ + _(prim, TracedAttr) \ + _(prim, rpc_async) \ + 
_(prim, rpc_sync) \ + _(prim, rpc_remote) \ + _(prim, is_cuda) \ + _(aten, append) \ + _(aten, as_tensor) \ + _(aten, adaptive_avg_pool2d_backward) \ + _(aten, dim) \ + _(aten, format) \ + _(aten, percentFormat) \ + _(aten, __not__) \ + _(aten, __is__) \ + _(aten, __isnot__) \ + _(aten, _ger) \ + _(aten, __getitem__) \ + _(aten, _set_item) \ + _(aten, manual_seed) \ + _(aten, device) \ + _(aten, hash) \ + _(aten, len) \ + _(aten, list) \ + _(aten, dict) \ + _(aten, wait) \ + _(aten, save) \ + _(aten, keys) \ + _(aten, ord) \ + _(aten, chr) \ + _(aten, hex) \ + _(aten, oct) \ + _(aten, clear) \ + _(aten, setdefault) \ + _(aten, bin) \ + _(aten, pop) \ + _(aten, insert) \ + _(aten, tensor) \ + _(prim, unchecked_unwrap_optional) \ + _(aten, __contains__) \ + _(prim, BailoutTemplate) \ + _(prim, grad) \ + _(cuda, _set_device) \ + _(cuda, set_stream) \ + _(cuda, _current_device) \ + _(cuda, synchronize) \ + _(aten, has_torch_function) \ + _(aten, is_autocast_enabled) \ + _(aten, is_autocast_cpu_enabled) \ + _(aten, is_autocast_xla_enabled) \ + _(aten, get_autocast_dtype) \ + _(aten, is_autocast_mps_enabled) \ + FORALL_ATEN_BASE_SYMBOLS(_) \ + _(onnx, Add) \ + _(onnx, Concat) \ + _(onnx, Constant) \ + _(onnx, ConstantFill) \ + _(onnx, Div) \ + _(onnx, GRU) \ + _(onnx, Gather) \ + _(onnx, Gemm) \ + _(onnx, LSTM) \ + _(onnx, MatMul) \ + _(onnx, Min) \ + _(onnx, Max) \ + _(onnx, Mul) \ + _(onnx, Pow) \ + _(onnx, RNN) \ + _(onnx, Shape) \ + _(onnx, Size) \ + _(onnx, Slice) \ + _(onnx, Softmax) \ + _(onnx, Squeeze) \ + _(onnx, Sub) \ + _(onnx, Transpose) \ + _(onnx, Unsqueeze) \ + _(onnx, Loop) \ + _(onnx, If) \ + _(onnx, Reshape) \ + _(onnx, Expand) \ + _(onnx, Equal) \ + _(onnx, Greater) \ + _(onnx, GreaterOrEqual) \ + _(onnx, Less) \ + _(onnx, LessOrEqual) \ + _(onnx, Not) \ + _(aten, ATen) \ + _(onnx, Split) \ + _(onnx, ConstantOfShape) \ + _(onnx, Cast) \ + _(onnx, Mod) \ + _(onnx, Sqrt) \ + _(onnx, SplitToSequence) \ + _(onnx, SequenceAt) \ + _(onnx, SequenceConstruct) 
\ + _(onnx, SequenceEmpty) \ + _(onnx, SequenceInsert) \ + _(onnx, SequenceErase) \ + _(onnx, ConcatFromSequence) \ + _(onnx, Identity) \ + _(onnx, SoftmaxCrossEntropyLoss) \ + _(onnx, NegativeLogLikelihoodLoss) \ + _(onnx, LogSoftmax) \ + _(onnx, ReduceL1) \ + _(onnx, ReduceL2) \ + _(onnx, Conv) \ + _(onnx, BatchNormalization) \ + _(onnx, ReduceMean) \ + _(onnx, ReduceProd) \ + _(onnx, Relu) \ + _(onnx, Neg) \ + _(onnx, NonZero) \ + _(onnx, Range) \ + _(onnx, Tile) \ + _(onnx, Where) \ + _(onnx, Optional) \ + _(onnx, OptionalGetElement) \ + _(onnx, OptionalHasElement) \ + FORALL_ATTR_BASE_SYMBOLS(_) \ + _(attr, Subgraph) \ + _(attr, ReverseSubgraph) \ + _(attr, f_real_outputs) \ + _(attr, df_input_vjps) \ + _(attr, df_input_captured_inputs) \ + _(attr, df_input_captured_outputs) \ + _(attr, df_output_vjps) \ + _(attr, axes) \ + _(attr, symbolic_shape_inputs) \ + _(attr, allow_stack_outputs) \ + _(attr, striding_inputs_desc) \ + _(attr, striding_outputs_desc) \ + _(attr, broadcast) \ + _(attr, direction) \ + _(attr, ends) \ + _(attr, inplace) \ + _(attr, input_as_shape) \ + _(attr, is_zero) \ + _(attr, num_none) \ + _(attr, num_present) \ + _(attr, perm) \ + _(attr, starts) \ + _(attr, profiled_type) \ + _(attr, transA) \ + _(attr, transB) \ + _(attr, name) \ + _(attr, module) \ + _(attr, beg) \ + _(attr, idx) \ + _(attr, split) \ + _(attr, slot) \ + _(attr, kinds) \ + _(attr, types) \ + _(attr, scope) \ + _(attr, keepdims) \ + _(attr, cache_id) \ + _(attr, new_axis) \ + _(attr, warn_id) \ + _(attr, output_layouts) \ + _(attr, allowzero) \ + _(attr, seen_none) \ + _(attr, overload_name) \ + _(attr, node_stack_idx) + +enum class _keys : unique_t { + #define DEFINE_KEY(ns, s) ns##_##s, + FORALL_NS_SYMBOLS(DEFINE_KEY) + #undef DEFINE_KEY + num_symbols +}; + +#define DEFINE_SYMBOL(ns, s) \ + namespace ns { constexpr Symbol s(static_cast(_keys::ns##_##s)); } +FORALL_NS_SYMBOLS(DEFINE_SYMBOL) +#undef DEFINE_SYMBOL + +} // namespace c10 diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/interned_strings_class.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/interned_strings_class.h new file mode 100644 index 0000000000000000000000000000000000000000..a215fa62c7e91827425a86df4c556428bf249bea --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/interned_strings_class.h @@ -0,0 +1,32 @@ +#include +#include +#include +#include +#include +#include + +namespace c10 { + +struct TORCH_API InternedStrings { + InternedStrings(); + Symbol symbol(const std::string& s); + std::pair string(Symbol sym); + Symbol ns(Symbol sym); + + private: + // prereq - holding mutex_ + Symbol _symbol(const std::string& s); + std::pair customString(Symbol sym); + std::unordered_map string_to_sym_; + + struct SymbolInfo { + Symbol ns; + std::string qual_name; + std::string unqual_name; + }; + std::vector sym_to_info_; + + std::mutex mutex_; +}; + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ivalue.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ivalue.h new file mode 100644 index 0000000000000000000000000000000000000000..175860dc99a7ccbe3742d27d24346b136eee4e12 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ivalue.h @@ -0,0 +1,1589 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {}; +namespace jit { +using ::torch::CustomClassHolder; +struct Function; +struct CompilationUnit; +struct Module; +} // namespace jit +} // namespace torch +namespace c10 { +template +class Dict; +template +class List; +template +class IListRef; +struct IValue; +struct ClassType; +struct Type; +class RRefInterface; + +struct ClassType; +using ClassTypePtr = std::shared_ptr; + +TORCH_API bool 
_fastEqualsForContainer(const IValue& lhs, const IValue& rhs); + +TORCH_API torch::jit::Function* checkObjectSortSchema( + const c10::ClassTypePtr& t, + std::stringstream& why_not); + +// A comparator that checks ordering of two IValues of same type. +typedef std::function IValueComparator; + +TORCH_API IValueComparator getLessThanComparator(const IValue& v); +TORCH_API IValueComparator getGreaterThanComparator(const IValue& v); + +namespace ivalue { +struct Tuple; +struct Future; +struct Await; +struct ConstantString; +struct GenericDict; +struct Object; +struct PyObjectHolder; +struct EnumHolder; +// We need a ComplexHolder because currently the payloads in the Union +// only take 64 bits. Since ComplexDouble takes up 128 bits, and is too big +// to fit in the IValue directly, we indirect complex numbers through an +// intrusive pointer to ComplexHolder (which contains a c10::complex). +struct ComplexHolder : c10::intrusive_ptr_target { + public: + template + ComplexHolder(c10::complex c) { + val = convert>(c); + } + ComplexHolder() = default; + c10::complex val; +}; + +// Similar to ComplexHolder, for StreamData3 +struct StreamData3Holder : c10::intrusive_ptr_target { + public: + StreamData3Holder(struct c10::StreamData3 d) : val(d) {} + StreamData3Holder() = delete; + struct c10::StreamData3 val; +}; + +} // namespace ivalue + +// This is an owning wrapper for a std::optional> +// that can be implicitly converted to a (non-owning) std::optional>. +// Its purpose is to be used in generated code to keep the vector alive +// either until the end of a statement (as a temporary), or as a saved arg +// in autograd. +template +struct OptionalArray { + std::optional> list; + + OptionalArray() = default; + OptionalArray(std::vector val) : list(std::move(val)) {} + + // Used when saving an argument for the backwards pass. 
+ OptionalArray& operator=(std::optional> ref) { + if (ref) { + list = std::vector(ref->begin(), ref->end()); + } else { + list = std::nullopt; + } + return *this; + } + + // Used when saving an argument for the backwards pass. + OptionalArray& operator=(c10::OptionalArrayRef ref) { + if (ref) { + list = std::vector(ref->begin(), ref->end()); + } else { + list = std::nullopt; + } + return *this; + } + + operator std::optional>() { + if (!list) { + return std::nullopt; + } + return *list; + } + + operator c10::OptionalArrayRef() { + if (!list) { + return std::nullopt; + } + return *list; + } +}; + +// Capsule is an internal implementation detail of custom C++ classes. We +// define it as an owning wrapper for +// c10::intrusive_ptr This wrapper is here to serve as +// an abstraction of the type erased custom class object pointer. It also allow +// pybind11 to treat this as a standalone class to register as a separate type +// caster, instead of a custom pointer holder which the pointer holder type +// caster try to "unwrap" it automatically. +struct Capsule { + c10::intrusive_ptr obj_ptr; + explicit Capsule(c10::intrusive_ptr ptr) + : obj_ptr(std::move(ptr)) {} +}; + +// IValue is the generic tagged union used by the interpreter to hold +// all value types. +// It is a 16-byte object with an 8-byte payload and an 8-byte tag. +// The tag is currently 4 bytes to determine the type, and 1 byte +// to mark whether that type is a subtype of c10::intrusive_ptr_target and needs +// retain/release calls. 
+ +#define TORCH_FORALL_TAGS(_) \ + _(None) \ + _(Tensor) \ + _(Storage) \ + _(Double) \ + _(ComplexDouble) \ + _(Int) \ + _(SymInt) \ + _(SymFloat) \ + _(SymBool) \ + _(Bool) \ + _(Tuple) \ + _(String) \ + _(Blob) \ + _(GenericList) \ + _(GenericDict) \ + _(Future) \ + _(Await) \ + _(Device) \ + _(Stream) \ + _(Object) \ + _(PyObject) \ + _(Uninitialized) \ + _(Capsule) \ + _(RRef) \ + _(Quantizer) \ + _(Generator) \ + _(Enum) + +// [doxygen private] +// These methods are not actually private but we don't want to document them, so +// they are marked `@private`, which hides them on the doxygen documentation for +// this page. + +/// IValue (Interpreter Value) is a tagged union over the types +/// supported by the TorchScript interpreter. IValues contain their +/// values as an `IValue::Payload`, which holds primitive types +/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values, +/// and all other types as a `c10::intrusive_ptr`. In order to +/// optimize performance of the destructor and related operations by +/// making the `Tensor` and `c10::intrusive_ptr` paths generate the +/// same code, we represent a null `c10::intrusive_ptr` as +/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`. +/// +/// IValues are used as inputs to and outputs from the TorchScript interpreter. +/// To retrieve the value contained within an IValue, use the `.toX()` methods, +/// where `X` is the type you are trying to get. Note that neither the `.toX()` +/// methods nor the templated `.to` functions do any kind of casting, they +/// only unwrap the contained value. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// // Make the IValue +/// torch::IValue my_ivalue(26); +/// std::cout << my_ivalue << "\n"; +/// +/// // Unwrap the IValue +/// int64_t my_int = my_ivalue.toInt(); +/// std::cout << my_int << "\n"; +/// +/// // This will throw an error! 
+/// // `my_ivalue` is tagged as an int and cannot be used as another type +/// torch::Tensor my_tensor = my_ivalue.toTensor(); +/// \endrst +struct TORCH_API IValue final { + IValue(const IValue& rhs) : IValue(rhs.payload, rhs.tag) { + if (isIntrusivePtr() && + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); + } + } + + IValue(IValue&& rhs) noexcept : tag(rhs.tag) { + moveFrom(std::move(rhs)); + } + + /// @private [doxygen private] + ~IValue() { + destroy(); + } + + C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept { + if (&rhs == this) { + return *this; + } + + destroy(); + moveFrom(std::move(rhs)); + return *this; + } + + IValue& operator=(IValue const& rhs) & { + *this = IValue(rhs); + return *this; + } + + void dump() const; + + /** + * Equality comparison. The semantics are the same as Python's `==`: + * 1. Numerical types are compared by value. + * 2. Tensors compute element-wise equality, returning a BoolTensor (see: + * `torch.eq()`) + * 3. Strings are compared by value. + * 4. Sequence types (list, tuple) are compared lexicographically by + * comparing their elements. Different sequence types never compare equal. + * 5. Mappings (dict) must have equal (key, value) pairs. + * 6. If not listed above, the default behavior for is to test identity + * equality (e.g. pointer equality). + * + * Why does this return an IValue instead of a bool? Because in PyTorch, + * `tensor1 == tensor2` returns a `BoolTensor`, not a bool. + * + * NOTE: we (like Python) assume that identity equality implies value equality + * for efficiency. + * TODO: need to support customizing equality + */ + IValue equals(const IValue& rhs) const; + /** + * This implements the same semantics as `bool(lhs == rhs)` in Python. which + * is the same as `equals()` except for Tensor types. 
+ */ + TORCH_API friend bool operator==(const IValue& lhs, const IValue& rhs); + TORCH_API friend bool operator!=(const IValue& lhs, const IValue& rhs); + + /** + * Identity comparison. Checks if `this` is the same object as `rhs`. The + * semantics are the same as Python's `is` operator. + * + * NOTE: Like in Python, this operation is poorly defined for primitive types + * like numbers and strings. Prefer to use `==` unless you really want to + * check identity equality. + */ + bool is(const IValue& rhs) const; + + /** + * Hashing for IValues. Returns an IValue-boxed int. + * + * Some notes: + * - Like eager, Tensors are hashed by looking at the pointer. This is not + * strictly correct because two value-equal tensors with different tensor + * pointers will hash differently, but we choose to reproduce the eager + * semantics. + * - Hashing is not defined on all built-in IValue types (e.g. list and + * dict), following Python. Calling `hash()` on these types will throw. + */ + IValue hash() const { + return (int64_t)IValue::hash(*this); + } + // This is defined because `c10::hash` dispatches to a function of this + // signature. See the member function `hash()`. + static size_t hash(const IValue& iv); + + /** + * @private [doxygen private] + * [container equality] + * This is an equality implementation that assumes objects with the same + * identity equal themselves, for efficiency reasons. We primarily have this + * for consistency, because Python does the same thing. 
This actually + * provokes user-visible changes in behavior due to quirks in torch: + * [tensor1] == [tensor1] -> True (because container equality will first + * compare identity) [tensor1] == [tensor1_copy] -> RuntimeError: + * Boolean value of Tensor with more than one value is ambiguous + */ + TORCH_API friend bool _fastEqualsForContainer( + const IValue& lhs, + const IValue& rhs); + + private: + static bool isAliasOf(const at::Tensor& a, const at::Tensor& b) { + if (a.is_sparse()) { + return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b); + } + if (b.is_sparse()) { + return isAliasOf(a, b._values()) || isAliasOf(a, b._indices()); + } + if (a.is_sparse_csr()) { + return isAliasOf(a.values(), b) || isAliasOf(a.crow_indices(), b) || + isAliasOf(a.col_indices(), b); + } + if (b.is_sparse_csr()) { + return isAliasOf(a, b.values()) || isAliasOf(a, b.crow_indices()) || + isAliasOf(a, b.col_indices()); + } + + // Opaque tensors such as the ones constructed by the MKL-DNN backend + // don't have storage so we just compare their TensorImpls. + // TODO: Find way to expose alias info for opaque tensors. 
+ if (!a.has_storage() || !b.has_storage()) { + return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl(); + } + + return a.is_alias_of(b); + } + + template + bool isListOf() const; + + public: + /// @private [doxygen private] + bool isAliasOf(const IValue& rhs) const { + if (this->tag != rhs.tag) { + // Trivially don't alias if the type is different + return false; + } + + // Tensors should be compared based on internal storage + if (this->isTensor()) { + return isAliasOf(this->toTensor(), rhs.toTensor()); + } + + if (!isIntrusivePtr()) { + // Primitive types don't alias anything + return false; + } + + AT_ASSERT(rhs.isIntrusivePtr()); + + // Other types can be compared by their ptr value + return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; + } + + /// @private [doxygen private] + size_t use_count() const noexcept { + if (isTensor()) { + return payload.as_tensor.use_count(); + } + + if (!isIntrusivePtrLegacyBehavior()) { + return 1; + } + + if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { + return 0; + } + return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr); + } + + /// @private [doxygen private] + void swap(IValue& rhs) noexcept { + if (isTensor() && rhs.isTensor()) { + std::swap(payload.as_tensor, rhs.payload.as_tensor); + } else if (isTensor()) { + at::Tensor t = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. 
+ // + // payload.as_tensor.~Tensor(); + payload.u = rhs.payload.u; + new (&rhs.payload.as_tensor) at::Tensor(std::move(t)); + } else if (rhs.isTensor()) { + rhs.swap(*this); + return; + } else { + std::swap(payload.u, rhs.payload.u); + } + std::swap(tag, rhs.tag); + } + + // Accessors for subtypes are arranged together below + // While some of these accessors could be generated through templates, + // we prefer to write them manually for clarity + + IValue(at::TensorBase t) : tag(Tag::Tensor) { + new (&payload.as_tensor) at::Tensor(std::move(t)); + } + bool isTensor() const { + return Tag::Tensor == tag; + } + + private: + // Outlined error path so that toTensor() can be inlined. + [[noreturn]] void reportToTensorTypeError() const; + + public: + at::Tensor toTensor() &&; + at::Tensor& toTensor() &; + const at::Tensor& toTensor() const&; + at::TensorImpl* unsafeToTensorImpl() const { + TORCH_INTERNAL_ASSERT(isTensor()); + return payload.as_tensor.unsafeGetTensorImpl(); + } + + IValue(at::Storage s) : tag(Tag::Storage) { + payload.u.as_intrusive_ptr = + null_to_undefined_tensor(s.unsafeReleaseStorageImpl()); + } + bool isStorage() const { + return Tag::Storage == tag; + } + c10::Storage toStorage() &&; + c10::Storage toStorage() const&; + + const IValue& toIValue() const { + return *this; + } + IValue& toIValue() { + return *this; + } + + /// @private [doxygen private] + IValue(intrusive_ptr blob) : tag(Tag::Blob) { + // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract + // and store it as a Tensor instead. + payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); + } + + /// @private [doxygen private] + bool isBlob() const { + return Tag::Blob == tag; + } + + /// @private [doxygen private] + c10::intrusive_ptr toBlob() &&; + + /// @private [doxygen private] + c10::intrusive_ptr toBlob() const&; + + // Capsule. No new callsites of these APIs should + // be introduced. 
+ static inline IValue make_capsule( + intrusive_ptr blob); + bool isCapsule() const { + return Tag::Capsule == tag; + } + c10::intrusive_ptr toCapsule() &&; + c10::intrusive_ptr toCapsule() const&; + + // Custom C++ classes + template < + typename T, + std::enable_if_t, int> = 0> + IValue(intrusive_ptr custom_class); + bool isCustomClass() const; + template + c10::intrusive_ptr toCustomClass() &&; + template + c10::intrusive_ptr toCustomClass() const&; + + // Tuple + IValue(c10::intrusive_ptr v); + + template < + typename... Args, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t> = nullptr> + IValue(const std::tuple& t); + template < + typename... Args, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t> = nullptr> + IValue(std::tuple&& t); + bool isTuple() const { + return Tag::Tuple == tag; + } + c10::intrusive_ptr toTuple() &&; + c10::intrusive_ptr toTuple() const&; + [[nodiscard]] ivalue::Tuple& toTupleRef() const; + + // Double + IValue(double d) : tag(Tag::Double) { + payload.u.as_double = d; + } + bool isDouble() const { + return Tag::Double == tag; + } + double toDouble() const { + if (isDouble()) { + return payload.u.as_double; + } else if (isSymFloat()) { + return toSymFloat().guard_float(__FILE__, __LINE__); + } else { + TORCH_INTERNAL_ASSERT(0, "expected double"); + } + } + + // ComplexDouble + template + IValue(c10::complex c); + bool isComplexDouble() const { + return Tag::ComplexDouble == tag; + } + c10::complex toComplexDouble() const; + + // Future + IValue(c10::intrusive_ptr v); + bool isFuture() const { + return Tag::Future == tag; + } + c10::intrusive_ptr toFuture() &&; + c10::intrusive_ptr toFuture() const&; + + IValue(c10::intrusive_ptr v); + bool isAwait() const { + return Tag::Await == tag; + } + c10::intrusive_ptr toAwait() &&; + c10::intrusive_ptr toAwait() const&; + + // RRef + IValue(c10::intrusive_ptr v); + bool 
isRRef() const { + return Tag::RRef == tag; + } + c10::intrusive_ptr toRRef() &&; + c10::intrusive_ptr toRRef() const&; + + // Quantizer + IValue(c10::intrusive_ptr v); + bool isQuantizer() const { + return Tag::Quantizer == tag; + } + c10::intrusive_ptr toQuantizer() &&; + c10::intrusive_ptr toQuantizer() const&; + + // Int + IValue(int64_t i) : tag(Tag::Int) { + payload.u.as_int = i; + } + + IValue(const c10::SymInt& i) { + if (auto mi = i.maybe_as_int()) { + tag = Tag::Int; + payload.u.as_int = *mi; + } else { + tag = Tag::SymInt; + payload.u.as_intrusive_ptr = i.toSymNode().release(); + } + } + + bool isSymInt() const { + return Tag::SymInt == tag; + } + + c10::SymInt toSymInt() &&; + c10::SymInt toSymInt() const&; + + IValue(const c10::SymFloat& i) { + if (i.is_symbolic()) { + tag = Tag::SymFloat; + payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); + } else { + tag = Tag::Double; + payload.u.as_double = i.as_float_unchecked(); + } + } + + bool isSymFloat() const { + return Tag::SymFloat == tag; + } + + c10::SymFloat toSymFloat() &&; + c10::SymFloat toSymFloat() const&; + + IValue(const c10::SymBool& i) { + if (auto mi = i.maybe_as_bool()) { + tag = Tag::Bool; + payload.u.as_int = *mi; + } else { + tag = Tag::SymBool; + payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); + } + } + + bool isSymBool() const { + return Tag::SymBool == tag; + } + + c10::SymBool toSymBool() &&; + c10::SymBool toSymBool() const&; + + // allow you to pass literals (3, 4) without ambiguity + IValue(int32_t i) : IValue(static_cast(i)) {} + + bool isInt() const { + return Tag::Int == tag; + } + + int64_t toInt() const { + if (isInt()) { + return payload.u.as_int; + } else if (isSymInt()) { + return toSymInt().guard_int(__FILE__, __LINE__); + } else { + TORCH_INTERNAL_ASSERT(0, "expected int"); + } + } + + // Bool + IValue(bool b) : tag(Tag::Bool) { +#if defined(__clang__) && defined(__x86_64__) + // Initializing entire payload stops valgrind's from reporting + // "jump or 
move depends on uninitialised value" in IValue copy constructor + // See https://github.com/pytorch/pytorch/issues/37117 + payload.u.as_int = b; +#else + payload.u.as_bool = b; +#endif + } + bool isBool() const { + return Tag::Bool == tag; + } + bool toBool() const { + if (isBool()) { + return payload.u.as_bool; + } else if (isSymBool()) { + return toSymBool().guard_bool(__FILE__, __LINE__); + } else { + TORCH_INTERNAL_ASSERT(0, "expected bool"); + } + } + + // IntList + bool isIntList() const; + bool isSymIntList() const; + c10::List toIntList() &&; + c10::List toIntList() const&; + std::vector toIntVector() const; + c10::List toSymIntList() &&; + c10::List toSymIntList() const&; + std::vector toSymIntVector() const; + at::DimVector toDimVector() const; + + // ConstantString + IValue(c10::intrusive_ptr v); + IValue(std::string v); + IValue(const char* v) : IValue(std::string(v)) {} + IValue(std::string_view v) : IValue(std::string(v)){} + bool isString() const { + return Tag::String == tag; + } + c10::intrusive_ptr toString() &&; + c10::intrusive_ptr toString() const&; + const std::string& toStringRef() const; + std::optional> toOptionalStringRef() + const; + std::string_view toStringView() const; + + // DoubleList + bool isDoubleList() const; + c10::List toDoubleList() &&; + c10::List toDoubleList() const&; + std::vector toDoubleVector() const; + + // ComplexDoubleList + bool isComplexDoubleList() const; + c10::List> toComplexDoubleList() &&; + c10::List> toComplexDoubleList() const&; + std::vector> toComplexDoubleVector() const; + + // BoolList + bool isBoolList() const; + c10::List toBoolList() &&; + c10::List toBoolList() const&; + + // TensorList + bool isTensorList() const; + c10::List toTensorList() &&; + c10::List toTensorList() const&; + std::vector toTensorVector() const; + + // OptionalTensorList + bool isOptionalTensorList() const; + c10::List> toOptionalTensorList() &&; + c10::List> toOptionalTensorList() const&; + std::vector> 
toOptionalTensorVector() const; + + // GenericList + IValue(c10::List v); + bool isList() const { + return Tag::GenericList == tag; + } + c10::List toList() &&; + c10::List toList() const&; + c10::ArrayRef toListRef() const; + + // Some template constructors of IValue calls another constructor recursively. + // This SFINAEs the called constructor exists. + template + using enable_if_ivalue_constructible = + std::enable_if_t, std::nullptr_t>; + + // The rule for lists is more complicated; the generic constructor is only + // acceptable if your element isn't SymInt. If you do have a SymInt element, + // then you must also, at construction time, check if you can decay the list + // into an int list (this is MANDATORY, as at a use site we may expect + // toIntList to work even if at the call site you had a SymIntArrayRef + // argument). In practice, only SymIntArrayRef is used this way, so we + // didn't bother making it work for the other constructors, we just make sure + // they're not selectable. + template + using enable_if_list_is_ivalue_constructible = std::enable_if_t< + std::is_constructible_v && !std::is_same_v, + std::nullptr_t>; + + template = nullptr> + IValue(c10::List&& v); + template = nullptr> + IValue(const c10::List& v); + template = nullptr> + IValue(at::ArrayRef v); + template = nullptr> + IValue(const std::vector& v); + template = nullptr> + IValue(std::vector&& v); + template + IValue(std::array v); + + // Manual constructors for lists of symints, which decay to int list if + // possible. 
To avoid ambiguous overload situations, we template them + // to prevent implicit conversions + template + using enable_if_symint = + std::enable_if_t, std::nullptr_t>; + + template = nullptr> + IValue(at::ArrayRef v); + template = nullptr> + IValue(at::OptionalArrayRef v); + template = nullptr> + IValue(const std::vector& v); + template = nullptr> + IValue(std::vector&& v); + + template + using enable_if_ilist_is_ivalue_constructible = std::enable_if_t< + std::is_constructible_v && + std::is_constructible_v::boxed_type> && + !std::is_same_v, + std::nullptr_t>; + + template = nullptr> + IValue(c10::IListRef v); + + // GenericDict + IValue(c10::Dict v); + bool isGenericDict() const { + return Tag::GenericDict == tag; + } + c10::Dict toGenericDict() &&; + c10::Dict toGenericDict() const&; + + template + IValue(c10::Dict v); + + template + /// \cond + /// DOXYGEN_CANNOT_HANDLE_CONSTRUCTORS_WITH_MACROS_SO_EXCLUDE_THIS_LINE_FROM_DOXYGEN + C10_DEPRECATED_MESSAGE( + "IValues based on std::unordered_map are slow and deprecated. 
Please use c10::Dict instead.") + /// \endcond + IValue(std::unordered_map v); + + template = nullptr> + IValue(std::optional v); + template = nullptr> + IValue(c10::OptionalArrayRef v); + IValue(std::nullopt_t); + + // ClassType + IValue(c10::intrusive_ptr v); + bool isObject() const { + return tag == Tag::Object; + } + c10::intrusive_ptr toObject() &&; + c10::intrusive_ptr toObject() const&; + ivalue::Object& toObjectRef() const; + + torch::jit::Module toModule() const; + bool isModule() const; + + // PyObject + IValue(c10::intrusive_ptr v); + bool isPyObject() const { + return tag == Tag::PyObject; + } + c10::intrusive_ptr toPyObjectHolder() &&; + c10::intrusive_ptr toPyObjectHolder() const&; + PyObject* toPyObject() const; + + // Enum + explicit IValue(c10::intrusive_ptr v); + bool isEnum() const { + return tag == Tag::Enum; + } + c10::intrusive_ptr toEnumHolder() &&; + c10::intrusive_ptr toEnumHolder() const&; + + // None + IValue() = default; + bool isNone() const { + return Tag::None == tag; + } + std::string toNone() const { + AT_ASSERT(isNone()); + return "None"; + } + + static IValue uninitialized() { + auto i = IValue(); + i.tag = Tag::Uninitialized; + return i; + } + + // Scalar, which gets encoded as either an Int, a Double or a ComplexDouble + IValue(const at::Scalar& s) : IValue() { + // NB: do the symbolic versions first, as isFloatingPoint is true + // for both SymFloat and double + if (s.isSymInt()) { + tag = Tag::SymInt; + payload.u.as_intrusive_ptr = s.toSymInt().toSymNode().release(); + } else if (s.isSymFloat()) { + tag = Tag::SymFloat; + payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release(); + } else if (s.isSymBool()) { + tag = Tag::SymBool; + payload.u.as_intrusive_ptr = s.toSymBool().toSymNodeImpl().release(); + } else if (s.isFloatingPoint()) { + tag = Tag::Double; + payload.u.as_double = s.toDouble(); + } else if (s.isComplex()) { + *this = s.toComplexDouble(); + } else if (s.isBoolean()) { + tag = Tag::Bool; + 
payload.u.as_bool = s.toBool(); + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + s.isIntegral(false), "Unknown type in Scalar"); + tag = Tag::Int; + payload.u.as_int = s.toLong(); + } + } + + bool isScalar() const { + return isDouble() || isInt() || isComplexDouble() || isBool() || + isSymInt() || isSymFloat() || isSymBool(); + } + + at::Scalar toScalar() const { + if (isDouble()) + return toDouble(); + else if (isInt()) + return toInt(); + else if (isComplexDouble()) + return toComplexDouble(); + else if (isBool()) + return toBool(); + else if (isSymInt()) + return toSymInt(); + else if (isSymFloat()) + return toSymFloat(); + else if (isSymBool()) + return toSymBool(); + TORCH_CHECK(false, "IValue is not a Scalar"); + } + + // Device + IValue(c10::Device d) : tag(Tag::Device) { + payload.u.as_device.type = d.type(); + payload.u.as_device.index = d.index(); + } + bool isDevice() const { + return Tag::Device == tag; + } + c10::Device toDevice() const { + AT_ASSERT(isDevice()); + return c10::Device(payload.u.as_device.type, payload.u.as_device.index); + } + + // Stream + IValue(c10::Stream s) : tag(Tag::Stream) { + auto v = c10::make_intrusive(s.pack3()); + payload.u.as_intrusive_ptr = v.release(); + } + c10::Stream toStream() &&; + c10::Stream toStream() const&; + bool isStream() const { + return Tag::Stream == tag; + } + + // ScalarType + IValue(ScalarType t) + : IValue(static_cast>(t)) {} + at::ScalarType toScalarType() const { + return static_cast(toInt()); + } + + // Layout + IValue(Layout l) : IValue(static_cast>(l)) {} + at::Layout toLayout() const { + return static_cast(toInt()); + } + + // MemoryFormat + IValue(MemoryFormat m) + : IValue(static_cast>(m)) {} + at::MemoryFormat toMemoryFormat() const { + return static_cast(toInt()); + } + + // QScheme + IValue(at::QScheme qscheme) : tag(Tag::Int) { + payload.u.as_int = static_cast(qscheme); + } + + at::QScheme toQScheme() const { + return static_cast(toInt()); + } + + // Dimname + IValue(at::Dimname dimname) : 
IValue(dimname.symbol().toQualString()) {} + + at::Dimname toDimname() const { + return at::Dimname::fromSymbol(Symbol::fromQualString(toStringRef())); + } + + // Generator + IValue(at::Generator g) : tag(Tag::Generator) { + payload.u.as_intrusive_ptr = + null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl()); + } + bool isGenerator() const { + return Tag::Generator == tag; + } + at::Generator toGenerator() &&; + at::Generator toGenerator() const&; + + // for debugging + std::string tagKind() const { + switch (tag) { +#define DEFINE_CASE(x) \ + case Tag::x: \ + return #x; + TORCH_FORALL_TAGS(DEFINE_CASE) +#undef DEFINE_CASE + } + return "InvalidTag(" + std::to_string(static_cast(tag)) + ")"; + } + + // generic v.to() implementations + // that can be used in special functions like pop/push + // that use template meta-programming. + // prefer the directly named methods when you can, + // since they are simpler to understand + + // Note: if you get linker errors saying one of these is missing, + // change it to ... && = delete; and you will see better error messages for + // why However, we cannot commit this because some compiler versions barf on + // it. + template + T to() &&; + template + typename c10::detail::ivalue_to_const_ref_overload_return::type to() + const&; + + // ToOptional: convert a IValue to the Optional obj that accepts both T and + // None + template + std::optional toOptional(); + template + std::optional toOptional() const; + + /// @private [doxygen private] + /// this is a shallow comparison of two IValues to test the object identity + bool isSameIdentity(const IValue& rhs) const; + + // Computes the "official" string representation of an IValue. This produces a + // TorchScript expression that can be used to recreate an IValue with the same + // value (e.g. when we are printing constants in the serializer). + // + // Callers can use `customFormatter` to override how `repr()` prints out an + // IValue. 
This is useful if you have some other environment where you can + // look up values, and you want to print a reference to that environment (like + // the serializer's constant table). + // + // repr() is not necessarily defined on all objects! + std::ostream& repr( + std::ostream& stream, + std::function customFormatter) + const; + + // Computes an "informal" string representation of an IValue. This should be + // used for debugging, or servicing `print()`-like functions. + // This is different from `repr()` in that there is no expectation that we can + // exactly reconstruct an IValue from the output; feel free to use a + // concise/pretty form + TORCH_API friend std::ostream& operator<<(std::ostream& out, const IValue& v); + + bool isPtrType() const { + if (isTensor()) { + return payload.as_tensor.defined(); + } + return isIntrusivePtrLegacyBehavior(); + } + + /// @private [doxygen private] + const void* internalToPointer() const { + TORCH_INTERNAL_ASSERT( + isPtrType(), "Can only call internalToPointer() for pointer types"); + if (isTensor()) { + return payload.as_tensor.unsafeGetTensorImpl(); + } else { + return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton() + ? payload.u.as_intrusive_ptr + : nullptr; + } + } + + template + TypePtr type() const; + + // Detect aliased tensors. + struct HashAliasedIValue { + size_t hashTensor(const at::Tensor& ten) const { + if (ten.is_sparse()) { + // COO sparse tensors have a "values" tensor and an "indices" tensor + // so this will detect overlap of sparse tensors that share a values + // tensor, but not sparse tensors that share an indices tensor. + return hashTensor(ten._values()); + } else if (ten.is_sparse_csr()) { + // COO sparse tensors have a "values" tensor and an "indices" tensor + // so this will detect overlap of sparse tensors that share a values + // tensor, but not sparse tensors that share an indices tensor. 
+ return hashTensor(ten.values()); + } else if (!ten.has_storage()) { + // Opaque tensors such as the ones constructed by the MKL-DNN backend + // don't have storage so we just use their TensorImpls. + // TODO: Find way to expose alias info for opaque tensors. + return reinterpret_cast(ten.unsafeGetTensorImpl()); + } else { + return reinterpret_cast(ten.storage().unsafeGetStorageImpl()); + } + } + size_t operator()(const IValue& val) const { + if (val.isTensor()) { + return hashTensor(val.toTensor()); + } + // If it is not a Tensor, then two mutable IValues alias each other only + // if they are the same pointer. + return val.payload.u.as_int; + } + }; + + struct CompAliasedIValues { + bool operator()(const IValue& lhs, const IValue& rhs) const { + return lhs.isAliasOf(rhs); + } + }; + + using HashAliasedIValues = + std::unordered_set; + using HashAliasedIValueMap = + std::unordered_map; + + struct HashIdentityIValue { + size_t operator()(const IValue& val) const { + return val.payload.u.as_int; + } + }; + + struct CompIdentityIValues { + bool operator()(const IValue& lhs, const IValue& rhs) const { + return lhs.is(rhs); + } + }; + + using HashIdentityIValues = + std::unordered_set; + using HashIdentityIValueMap = + std::unordered_map; + + // Chechs if this and rhs has a subvalues in common. + // [t1,t2] and [t2, t3] returns true. + bool overlaps(const IValue& rhs) const; + + // Inserts all subvalues of this in subValues. + void getSubValues(HashAliasedIValues& subValues) const; + + // Apply visitor to every subvalue. + // TODO: There are several places that recurse over IValue. This is fragile. + // This visitor should be used to recurse over ivalues. 
+ void visit(const std::function& visitor) const; + IValue deepcopy(std::optional device = std::nullopt) const; + IValue deepcopy( + HashIdentityIValueMap& memo, + std::optional device = std::nullopt) const; + + private: + static c10::intrusive_ptr_target* null_to_undefined_tensor( + c10::intrusive_ptr_target* p) { + return p ? p + : static_cast( + c10::UndefinedTensorImpl::singleton()); + } + + static bool ptrEqual(const IValue& lhs, const IValue& rhs); + // NOTE: IValue tags are intentionally private. In the future we may encode + // this value different (e.g. using NaN boxing), and this would make it more + // costly to determine the tag for all types vs just determining if something + // is a particular type. Instead we want clients to use the `isX` methods when + // possible. If for performance reasons you really, absolutely, must have a jump + // table, then we can revisit this. + enum class Tag : uint32_t { +#define DEFINE_TAG(x) x, + TORCH_FORALL_TAGS(DEFINE_TAG) +#undef DEFINE_TAG + }; + +#define COUNT_TAG(x) 1 + + static constexpr auto kNumTags = TORCH_FORALL_TAGS(COUNT_TAG) 0; +#undef COUNT_TAG + + template < + class T, + class NullType = c10::detail::intrusive_target_default_null_type> + c10::intrusive_ptr moveToIntrusivePtr(); + template < + typename T, + class NullType = c10::detail::intrusive_target_default_null_type> + c10::intrusive_ptr toIntrusivePtr() const; + + void destroy() { + // We carefully construct this call to both 1) avoid UB by using + // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable + // the compiler to generate the same code for each case. It is + // surprisingly difficult to get this right. + if (isTensor() || isIntrusivePtr()) { + c10::intrusive_ptr_target* p = isTensor() + ? payload.as_tensor.unsafeGetTensorImpl() + : payload.u.as_intrusive_ptr; + c10::intrusive_ptr:: + reclaim(p); + // No need to make this destructor call! 
+ // payload.as_tensor.~Tensor(); + } + } + + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept { + if (rhs.isTensor()) { + new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // rhs.payload.as_tensor.~Tensor(); + } else { + payload.u = rhs.payload.u; + } + tag = rhs.tag; + rhs.clearToNone(); + } + + void clearToNone() noexcept { + payload.u.as_int = 0; + tag = Tag::None; + } + + private: + // This is the source of truth for isIntrusivePtr; edit results here + // as needed and isIntrusivePtr will pick them up. 
+ // NOLINTBEGIN(bugprone-branch-clone) + static constexpr bool isIntrusivePtrConstexpr(Tag tag) { + switch (tag) { + case Tag::None: + return false; + case Tag::Tensor: + return false; + case Tag::Storage: + return true; + case Tag::Generator: + return true; + case Tag::Double: + return false; + case Tag::ComplexDouble: + return true; + case Tag::Int: + return false; + case Tag::SymInt: + return true; + case Tag::SymFloat: + return true; + case Tag::SymBool: + return true; + case Tag::Bool: + return false; + case Tag::Tuple: + return true; + case Tag::String: + return true; + case Tag::Blob: + return true; + case Tag::GenericList: + return true; + case Tag::GenericDict: + return true; + case Tag::Future: + return true; + case Tag::Await: + return true; + case Tag::Device: + return false; + case Tag::Stream: + return true; + case Tag::Object: + return true; + case Tag::PyObject: + return true; + case Tag::Uninitialized: + return false; + case Tag::Capsule: + return true; + case Tag::RRef: + return true; + case Tag::Quantizer: + return true; + case Tag::Enum: + return true; + } + return false; + } + // NOLINTEND(bugprone-branch-clone) + + public: + // Don't edit this just to add results for new tags; edit + // isIntrusivePtrConstexpr above. + bool isIntrusivePtr() const { + // Implementation NOTE: the switch in isIntrusivePtrConstexpr + // above is the previous production implementation of this + // function. We observed that, at least on x86_64, the generated + // instruction sequence was a similar bit vector test to what we + // have manually implemented below, except that there was an extra + // "bounds check" branch confirming, essentially, that `tag < + // kNumTags` and providing a consistent result in that case. We + // don't care about the result if tag is out of bounds, so we'd + // like to eliminate that comparison and branch; manually + // implementing this function as a bit test is the simplest way I + // could find to accomplish that elimination. 
+ static constexpr uint32_t kTruthTableBitVector = +#define TRUTH_TABLE_ENTRY(tag) \ + (uint32_t(isIntrusivePtrConstexpr(Tag::tag)) << uint32_t(Tag::tag)) | + TORCH_FORALL_TAGS(TRUTH_TABLE_ENTRY) +#undef TRUTH_TABLE_ENTRY + 0; + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + static_cast(tag) < kNumTags, + "unexpected tag ", + static_cast(tag)); + return kTruthTableBitVector & (1 << (uint32_t(tag) % 32)); + } + + // Storage and Generator were treated specially when + // is_intrusive_ptr was stored as explicit state. This getter + // preserves the old behavior for use with WeakIValue for now. + bool isIntrusivePtrLegacyBehavior() const { + if (tag == Tag::Storage || tag == Tag::Generator) { + return payload.u.as_intrusive_ptr != + c10::UndefinedTensorImpl::singleton(); + } else { + return isIntrusivePtr(); + } + } + + union Payload { + // [TriviallyCopyablePayload] + // We use a nested union here so that we can make the copy easy + // and efficient in the non-tensor (i.e., trivially copyable) + // case. Specifically, we do not have to do a switch-on-tag to + // figure out which union member to assign; we can just use + // TriviallyCopyablePayload::operator=. + union TriviallyCopyablePayload { + TriviallyCopyablePayload() : as_int(0) {} + int64_t as_int; + double as_double; + bool as_bool; + // Invariant: never nullptr; null state is represented as + // c10::UndefinedTensorImpl::singleton() for consistency of + // representation with Tensor. 
+ c10::intrusive_ptr_target* as_intrusive_ptr; + struct { + c10::DeviceType type; + DeviceIndex index; + } as_device; + } u; + static_assert(std::is_trivially_copyable_v); + at::Tensor as_tensor; + Payload() : u() {} + Payload(const Payload&) = delete; + Payload(Payload&&) = delete; + Payload& operator=(const Payload&) = delete; + Payload& operator=(Payload&&) = delete; + // NOLINTNEXTLINE(modernize-use-equals-default) + ~Payload() {} + }; + + IValue(const Payload& p, Tag t) : tag(t) { + if (isTensor()) { + new (&payload.as_tensor) at::Tensor(p.as_tensor); + } else { + payload.u = p.u; + } + } + + template + struct TagType {}; + + friend MaybeOwnedTraits; + + Payload payload; + Tag tag{IValue::Tag::None}; + friend struct WeakIValue; +}; + +struct TORCH_API WeakIValue final { + WeakIValue() = default; + + WeakIValue(const WeakIValue& rhs) + : payload(rhs.payload), + tag(rhs.tag), + is_intrusive_ptr(rhs.is_intrusive_ptr) { + if (is_intrusive_ptr && + payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + } + } + WeakIValue(const IValue& rhs) + : tag(rhs.tag), is_intrusive_ptr(rhs.isIntrusivePtrLegacyBehavior()) { + if (rhs.isTensor()) { + payload.as_intrusive_ptr = rhs.unsafeToTensorImpl(); + is_intrusive_ptr = true; + } else { + payload = rhs.payload.u; + } + if (is_intrusive_ptr) { + if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + } + } + } + WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() { + swap(rhs); + } + ~WeakIValue() { + if (is_intrusive_ptr && + payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr); + } + } + WeakIValue& operator=(WeakIValue&& rhs) & noexcept { + WeakIValue(std::move(rhs)).swap(*this); // this also sets rhs to None + return *this; + } + WeakIValue& operator=(WeakIValue const& rhs) & { 
+ WeakIValue(rhs).swap(*this); + return *this; + } + void swap(WeakIValue& rhs) noexcept { + std::swap(payload, rhs.payload); + std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr); + std::swap(tag, rhs.tag); + } + + bool isSameIdentity(const WeakIValue& rhs) const { + return payload.as_int == rhs.payload.as_int && tag == rhs.tag && + is_intrusive_ptr == rhs.is_intrusive_ptr; + } + + IValue lock() const { + if (!is_intrusive_ptr) { + IValue::Payload newPayload; + newPayload.u = payload; + return IValue(newPayload, tag); + } + if (IValue::Tag::Tensor == tag) { + auto temp = + c10::weak_intrusive_ptr:: + reclaim(static_cast(payload.as_intrusive_ptr)); + c10::intrusive_ptr ip( + temp.lock()); + temp.release(); + if (!ip) { + return IValue(); + } else { + return IValue(at::Tensor(std::move(ip))); + } + } else { + auto temp = c10::weak_intrusive_ptr::reclaim( + payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? nullptr + : payload.as_intrusive_ptr); + IValue::Payload pl; + pl.u.as_intrusive_ptr = temp.lock().release(); + temp.release(); + if (!pl.u.as_intrusive_ptr) { + return IValue(); + } else { + return IValue(pl, tag); + } + } + } + + size_t use_count() const noexcept { + if (!is_intrusive_ptr) { + return 1; + } + auto temp = c10::weak_intrusive_ptr< + c10::intrusive_ptr_target, + c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr); + size_t result = temp.use_count(); + temp.release(); + return result; + } + + size_t weak_use_count() const noexcept { + if (!is_intrusive_ptr) { + return 1; + } + auto temp = c10::weak_intrusive_ptr< + c10::intrusive_ptr_target, + c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr); + size_t result = temp.weak_use_count(); + temp.release(); + return result; + } + size_t hash() const { + return payload.as_int; + } + + private: + using Payload = IValue::Payload::TriviallyCopyablePayload; + Payload payload; + IValue::Tag tag{IValue::Tag::None}; + bool is_intrusive_ptr{false}; +}; + +// An owning pointer 
to a type. When the type is class type, it requires a pair +// of shared_ptrs to the class type and its owning CU, so that the class type is +// guaranteed to stay alive as long as we hold this object. +struct TORCH_API StrongTypePtr { + StrongTypePtr(std::shared_ptr cu, TypePtr type); + + std::shared_ptr cu_; + TypePtr type_; +}; + +// [Constant Object Weak CompilationUnit Reference] +// A non owning pointer to a type. When a class get inserted as a constant +// into a graph, if we used a strong pointer we would have a circular reference +// from Object -> CompilationUnit and CompilationUnit -> Graph (which owns the +// Constant Object) +struct TORCH_API WeakTypePtr { + WeakTypePtr(std::weak_ptr cu, TypePtr type); + + std::weak_ptr cu_; + TypePtr type_; +}; + +// internal build errors with std::variant :/ +struct WeakOrStrongCompilationUnit { + explicit WeakOrStrongCompilationUnit( + std::shared_ptr shared_cu) + : strong_ptr_(std::move(shared_cu)), weak_ptr_(std::nullopt) {} + + explicit WeakOrStrongCompilationUnit( + std::weak_ptr weak_cu) + : strong_ptr_(std::nullopt), weak_ptr_(std::move(weak_cu)) {} + + std::shared_ptr getStrongRefOrThrow() const { + TORCH_INTERNAL_ASSERT(strong_ptr_.has_value()); + return *strong_ptr_; + } + + std::weak_ptr getWeakRefOrThrow() const { + TORCH_INTERNAL_ASSERT(weak_ptr_.has_value()); + return *weak_ptr_; + } + + bool holdingStrongRef() const { + return strong_ptr_.has_value(); + } + + bool holdingEmptyStrongRef() const { + return strong_ptr_ == nullptr; + } + + std::optional> strong_ptr_; + std::optional> weak_ptr_; +}; + +// An Object will hold a non-owning Compilation Unit reference if it is a +// Constant in the graph and a Owning reference otherwise +struct TORCH_API WeakOrStrongTypePtr { + explicit WeakOrStrongTypePtr(WeakTypePtr weak) + : cu_(WeakOrStrongCompilationUnit(std::move(weak.cu_))), + type_(std::move(weak.type_)) {} + explicit WeakOrStrongTypePtr(StrongTypePtr strong) + : 
cu_(WeakOrStrongCompilationUnit(std::move(strong.cu_))), + type_(std::move(strong.type_)) {} + explicit WeakOrStrongTypePtr(WeakOrStrongCompilationUnit cu, TypePtr type) + : cu_(std::move(cu)), type_(std::move(type)) {} + WeakTypePtr asWeakTypePtr() const; + + WeakOrStrongCompilationUnit cu_; + TypePtr type_; + + bool holds_strong_ref() const { + return cu_.holdingStrongRef(); + } + + bool holds_empty_strong_ref() const { + return cu_.holdingEmptyStrongRef(); + } +}; + +} // namespace c10 + +#include // IWYU pragma: keep diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ivalue_inl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ivalue_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..1251c4c0c210dd139b72491763122176184ce67e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ivalue_inl.h @@ -0,0 +1,2569 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +struct Function; +struct CompilationUnit; +} // namespace jit +TORCH_API bool isCustomClass(const c10::IValue& v); +} // namespace torch +namespace c10 { +struct IValue; +struct ClassType; +struct TupleType; +struct EnumType; +struct InferredType; + +// For custom class __init__ registration, we need to pass in a function +// that looks like this: [](IValue x, args...) + +// However, make_boxed_from_unboxed_functor.h automatically sets the input types +// of the function by introspecting the types of the functor (which is IValue in +// this case). However, we need the type it binds to be Foo. + +// Instead, we pass in a lambda [](ivalue_holder x, args...) from +// which getTypePtr can recover the original class pointer. 
+ +template +struct tagged_capsule { + IValue ivalue; +}; + +template +c10::intrusive_ptr IValue::moveToIntrusivePtr() { + auto t = c10::intrusive_ptr::reclaim( + payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? NullType::singleton() + : static_cast(payload.u.as_intrusive_ptr)); + clearToNone(); + return t; +} +template +c10::intrusive_ptr IValue::toIntrusivePtr() const { + if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { + return c10::intrusive_ptr(); + } + c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); + return c10::intrusive_ptr::reclaim( + static_cast(payload.u.as_intrusive_ptr)); +} + +template +intrusive_ptr static_intrusive_pointer_cast(intrusive_ptr r) { + return intrusive_ptr::reclaim(static_cast(r.release())); +} + +template +intrusive_ptr dynamic_intrusive_pointer_cast(intrusive_ptr r) { + return intrusive_ptr::reclaim(dynamic_cast(r.release())); +} + +inline c10::intrusive_ptr IValue::toFuture() && { + AT_ASSERT(isFuture(), "Expected Future but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toFuture() const& { + AT_ASSERT(isFuture(), "Expected Future but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toAwait() && { + AT_ASSERT(isAwait(), "Expected Await but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toAwait() const& { + AT_ASSERT(isAwait(), "Expected Await but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toRRef() && { + AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toRRef() const& { + AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toQuantizer() && { + AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr 
IValue::toQuantizer() const& { + AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toString() && { + AT_ASSERT(isString(), "Expected String but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toString() const& { + AT_ASSERT(isString(), "Expected String but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toObject() && { + AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toObject() const& { + AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue:: + toPyObjectHolder() && { + TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toPyObjectHolder() + const& { + TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toEnumHolder() && { + TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toEnumHolder() const& { + TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind()); + return toIntrusivePtr(); +} +inline c10::complex IValue::toComplexDouble() const { + TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind()); + auto ptr = toIntrusivePtr(); + return (*ptr).val; +} +inline at::Tensor IValue::toTensor() && { + if (C10_UNLIKELY(!isTensor())) { + reportToTensorTypeError(); + } + auto result = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. 
The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // payload.as_tensor.~Tensor(); + clearToNone(); + return result; +} +inline at::Tensor& IValue::toTensor() & { + if (C10_UNLIKELY(!isTensor())) { + reportToTensorTypeError(); + } + return payload.as_tensor; +} +inline const at::Tensor& IValue::toTensor() const& { + if (C10_UNLIKELY(!isTensor())) { + reportToTensorTypeError(); + } + return payload.as_tensor; +} +inline c10::Storage IValue::toStorage() && { + AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); + return c10::Storage( + moveToIntrusivePtr()); +} +inline c10::Storage IValue::toStorage() const& { + AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); + return c10::Storage(toIntrusivePtr()); +} +inline c10::Stream IValue::toStream() && { + AT_ASSERT(isStream(), "Expected Stream but got ", tagKind()); + auto ptr = toIntrusivePtr(); + return c10::Stream::unpack3((*ptr).val.stream_id, + (*ptr).val.device_index, + (*ptr).val.device_type); +} +inline c10::Stream IValue::toStream() const& { + AT_ASSERT(isStream(), "Expected Stream but got ", tagKind()); + auto ptr = toIntrusivePtr(); + return c10::Stream::unpack3((*ptr).val.stream_id, + (*ptr).val.device_index, + (*ptr).val.device_type); +} +inline c10::intrusive_ptr IValue::toBlob() && { + AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toBlob() const& { + AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); + return toIntrusivePtr(); + ; +} +inline c10::intrusive_ptr IValue::toCapsule() && { + TORCH_INTERNAL_ASSERT(isCapsule()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toCapsule() const& { + 
TORCH_INTERNAL_ASSERT(isCapsule()); + return toIntrusivePtr(); +} +inline at::Generator IValue::toGenerator() && { + AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind()); + return at::Generator(moveToIntrusivePtr()); +} +inline at::Generator IValue::toGenerator() const& { + AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind()); + return at::Generator(toIntrusivePtr()); +} +inline c10::SymInt IValue::toSymInt() && { + AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind()); + if (isSymInt()) { + return c10::SymInt(moveToIntrusivePtr()); + } else { + return c10::SymInt(payload.u.as_int); + } +} +inline c10::SymInt IValue::toSymInt() const& { + AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind()); + if (isSymInt()) { + return c10::SymInt(toIntrusivePtr()); + } else { + return c10::SymInt(payload.u.as_int); + } +} +inline c10::SymFloat IValue::toSymFloat() && { + AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind()); + if (isSymFloat()) { + return c10::SymFloat(moveToIntrusivePtr()); + } else { + return c10::SymFloat(payload.u.as_double); + } +} +inline c10::SymFloat IValue::toSymFloat() const& { + AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind()); + if (isSymFloat()) { + return c10::SymFloat(toIntrusivePtr()); + } else { + return c10::SymFloat(payload.u.as_double); + } +} +inline c10::SymBool IValue::toSymBool() && { + AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind()); + if (isSymBool()) { + return c10::SymBool(moveToIntrusivePtr()); + } else { + return c10::SymBool(payload.u.as_bool); + } +} + +inline c10::SymBool IValue::toSymBool() const& { + AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind()); + if (isSymBool()) { + return c10::SymBool(toIntrusivePtr()); + } else { + return c10::SymBool(payload.u.as_bool); + } +} + +namespace ivalue { + +void 
TORCH_API +checkCustomClassType(const ClassType* expected_type, const Type* actual_type); + +template +using Shared = c10::intrusive_ptr; + +// string +struct TORCH_API ConstantString final : c10::intrusive_ptr_target { + private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::string str_; + + public: + ConstantString(std::string str) : str_(std::move(str)) {} + ConstantString(std::string_view str) : str_(std::string(str)) {} + static c10::intrusive_ptr create(std::string str_); + static c10::intrusive_ptr create(std::string_view str_); + static c10::intrusive_ptr create(const char* str_); + + const std::string& string() const { + return str_; + } + std::string_view string_view() const { + return str_; + } + + operator const std::string&() const { + return string(); + } + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const ConstantString& v); +}; + +struct Future; + +struct TORCH_API TupleElements { + private: + size_t inlineSize_; + // We represent TupleElements this way to save doing a heap + // allocation in the common (at least for unpickling) case where we + // have only 3 elements. We have our own union instead of + // c10::SmallVector because c10::SmallVector always + // stores the begin/end/capacity pointers, which would be a waste of + // space in our use case. + union { + std::vector elementsVector_; + // Don't want to declare a std::array because the convenient + // iteration and size members are a footgun in this case -- the + // actual size of the array may be smaller than 3! 
+ // NOLINTNEXTLINE(*c-arrays*) + IValue elementsInline_[3]; + }; + + void destroyInline() { + for (const auto ii : c10::irange(inlineSize_)) { + elementsInline_[ii].~IValue(); + } + } + public: + + using iterator = IValue*; + using const_iterator = const IValue*; + + TupleElements() : inlineSize_(0) { + new (&elementsVector_) std::vector(); + } + + explicit TupleElements(std::vector elements) + : inlineSize_(0), elementsVector_(std::move(elements)) {} + + explicit TupleElements(c10::ArrayRef elements) + : inlineSize_(elements.size() <= 3 ? elements.size() : 0) { + switch (inlineSize_) { + case 3: + new (&elementsInline_[2]) IValue(elements[2]); + [[fallthrough]]; + case 2: + new (&elementsInline_[1]) IValue(elements[1]); + [[fallthrough]]; + case 1: + new (&elementsInline_[0]) IValue(elements[0]); + break; + case 0: + new (&elementsVector_) std::vector(elements.begin(), elements.end()); + break; + } + } + + explicit TupleElements(IValue&& e1) + : inlineSize_(1) { + new (&elementsInline_[0]) IValue(std::move(e1)); + } + + explicit TupleElements(IValue&& e1, IValue&& e2) + : inlineSize_(2) { + new (&elementsInline_[0]) IValue(std::move(e1)); + new (&elementsInline_[1]) IValue(std::move(e2)); + } + + explicit TupleElements(IValue&& e1, IValue&& e2, IValue&& e3) + : inlineSize_(3) { + new (&elementsInline_[0]) IValue(std::move(e1)); + new (&elementsInline_[1]) IValue(std::move(e2)); + new (&elementsInline_[2]) IValue(std::move(e3)); + } + + ~TupleElements() { + if (inlineSize_) { + destroyInline(); + } else { + elementsVector_.~vector(); + } + } + + // It would be nice to make this noncopyable to prevent people from + // writing code like `auto output = + // forward(...).toTupleRef().elements()` (which does refcount bumps on + // each element, unlike the more efficient but verbose + // ``` + // auto outputIntrusivePtr = forward(...).toTuple(); + // const auto& output = outputIntrusivePtr->elements(); + // ``` + // ), but there is simply an overwhelming amount of code 
that does + // it the inefficient way. + // See also operator std::vector below. + TupleElements(const TupleElements& rhs) + : inlineSize_(rhs.inlineSize_) { + if (rhs.inlineSize_) { + for (const auto ii : c10::irange(inlineSize_)) { + new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); + } + } else { + new (&elementsVector_) std::vector(rhs.elementsVector_); + } + } + + TupleElements& operator=(const TupleElements& rhs) { + if (inlineSize_) { + if (rhs.inlineSize_) { + for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) { + elementsInline_[ii] = rhs.elementsInline_[ii]; + } + if (rhs.inlineSize_ > inlineSize_) { + for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) { + new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); + } + } else { + for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) { + elementsInline_[ii].~IValue(); + } + } + } else { + destroyInline(); + new (&elementsVector_) std::vector(rhs.elementsVector_); + } + } else { + if (rhs.inlineSize_) { + elementsVector_.~vector(); + for (const auto ii : c10::irange(rhs.inlineSize_)) { + new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); + } + } else { + elementsVector_ = rhs.elementsVector_; + } + } + inlineSize_ = rhs.inlineSize_; + return *this; + } + + TupleElements(TupleElements&& rhs) noexcept + : inlineSize_(rhs.inlineSize_) { + if (inlineSize_) { + for (const auto ii : c10::irange(inlineSize_)) { + new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); + } + } else { + new (&elementsVector_) std::vector(std::move(rhs.elementsVector_)); + } + } + + TupleElements& operator=(TupleElements&& rhs) noexcept { + if (inlineSize_) { + if (rhs.inlineSize_) { + for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) { + elementsInline_[ii] = std::move(rhs.elementsInline_[ii]); + } + if (rhs.inlineSize_ > inlineSize_) { + for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) { + new 
(&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); + } + } else { + for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) { + elementsInline_[ii].~IValue(); + } + } + } else { + destroyInline(); + new (&elementsVector_) std::vector(std::move(rhs.elementsVector_)); + } + } else { + if (rhs.inlineSize_) { + elementsVector_.~vector(); + for (const auto ii : c10::irange(rhs.inlineSize_)) { + new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); + } + } else { + elementsVector_ = std::move(rhs.elementsVector_); + } + } + inlineSize_ = rhs.inlineSize_; + return *this; + } + + [[nodiscard]] c10::ArrayRef asArrayRef() const { + if (inlineSize_) { + return c10::ArrayRef(elementsInline_, inlineSize_); + } else { + return elementsVector_; + } + } + + // Mimic implicit conversion from std::vector to ArrayRef. + operator c10::ArrayRef() const { + return asArrayRef(); + } + + static size_t hash(const TupleElements& v) { + return c10::hash>()(v.asArrayRef()); + } + + void setContents(std::vector&& contents) { + if (inlineSize_) { + destroyInline(); + new (&elementsVector_) std::vector(std::move(contents)); + inlineSize_ = 0; + } else { + elementsVector_ = std::move(contents); + } + } + + [[nodiscard]] bool empty() const { + return inlineSize_ ? false : elementsVector_.empty(); + } + + [[nodiscard]] size_t size() const { + return inlineSize_ ? 
inlineSize_ : elementsVector_.size(); + } + + [[nodiscard]] IValue& operator[](size_t idx) { + if (inlineSize_) { + return elementsInline_[idx]; + } else { + return elementsVector_[idx]; + } + } + + [[nodiscard]] const IValue& operator[](size_t idx) const { + if (inlineSize_) { + return elementsInline_[idx]; + } else { + return elementsVector_[idx]; + } + } + + [[nodiscard]] IValue& at(size_t idx) { + if (inlineSize_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); + TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_); + return elementsInline_[idx]; + } else { + return elementsVector_.at(idx); + } + } + + [[nodiscard]] const IValue& at(size_t idx) const { + if (inlineSize_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); + TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_); + return elementsInline_[idx]; + } else { + TORCH_CHECK(idx < elementsVector_.size(), "TupleElements: invalid index Index = ", idx, "; Length = ", elementsVector_.size()); + return elementsVector_.at(idx); + } + } + + [[nodiscard]] iterator begin() { + if (inlineSize_) { + return elementsInline_; + } else { + return elementsVector_.data(); + } + } + + [[nodiscard]] iterator end() { + if (inlineSize_) { + return elementsInline_ + inlineSize_; + } else { + return elementsVector_.data() + elementsVector_.size(); + } + } + + [[nodiscard]] const_iterator begin() const { + if (inlineSize_) { + return elementsInline_; + } else { + return elementsVector_.data(); + } + } + + [[nodiscard]] const_iterator end() const { + if (inlineSize_) { + return elementsInline_ + inlineSize_; + } else { + return elementsVector_.data() + elementsVector_.size(); + } + } + + [[nodiscard]] const_iterator cbegin() const { + return begin(); + } + + [[nodiscard]] const_iterator cend() const { + return end(); + } + + [[nodiscard]] std::vector vec() const& { + return asArrayRef().vec(); + } + + [[nodiscard]] 
IValue& back() { + return *(end() - 1); + } + + [[nodiscard]] const IValue& back() const { + return *(end() - 1); + } + + [[nodiscard]] std::vector vec() && { + std::vector result; + result.reserve(size()); + for (auto&& iv : *this) { + result.push_back(std::move(iv)); + } + return result; + } + + // More compatibility shims for the overwhelming amount of code that + // likes to copy tuple elements into a vector; see comment above the + // copy constructor. + operator std::vector() const & { + return vec(); + } + + operator std::vector() && { + return vec(); + } +}; + +template +struct TupleTypeFactory {}; + +template <> +struct TORCH_API TupleTypeFactory { + static TupleTypePtr create(std::vector types) { + return TupleType::create(std::move(types)); + } + static TupleTypePtr fallback(const Type& type); +}; + +template <> +struct TORCH_API TupleTypeFactory { + static DynamicTypePtr create(const std::vector& elemTypes); + static DynamicTypePtr fallback(const Type&); +}; + +struct TORCH_API Tuple : c10::intrusive_ptr_target { + private: + TupleElements elements_; + mutable c10::TypePtr type_; // lazily computed for unnamed tuples + + public: + // named tuples have additional type information, so we + // directly create them tagged + static c10::intrusive_ptr createNamed( + std::vector elements_, + c10::TypePtr type_) { + return c10::make_intrusive(std::move(elements_), std::move(type_)); + } + + static c10::intrusive_ptr createNamed( + TupleElements elements_, + std::shared_ptr type_) { + return c10::make_intrusive(std::move(elements_), std::move(type_)); + } + + static c10::intrusive_ptr createNamed( + std::initializer_list elements_, + std::shared_ptr type_) { + return createNamed(TupleElements(c10::ArrayRef(elements_)), std::move(type_)); + } + + // MSVC apparently can't disambiguate the other two overloads of + // create when passed an initializer_list without this. 
+ static c10::intrusive_ptr create(std::initializer_list elements_) { + return create(c10::ArrayRef(elements_)); + } + + static c10::intrusive_ptr create(std::vector elements_) { + return c10::make_intrusive(std::move(elements_)); + } + + static c10::intrusive_ptr create(TupleElements elements_) { + return c10::make_intrusive(std::move(elements_)); + } + + static c10::intrusive_ptr create(c10::ArrayRef elements_) { + return create(TupleElements(elements_)); + } + + static c10::intrusive_ptr create(IValue e1) { + return c10::make_intrusive(std::move(e1)); + } + + static c10::intrusive_ptr create(IValue e1, IValue e2) { + return c10::make_intrusive(std::move(e1), std::move(e2)); + } + + static c10::intrusive_ptr create(IValue e1, IValue e2, IValue e3) { + return c10::make_intrusive(std::move(e1), std::move(e2), std::move(e3)); + } + + private: + // Workaround inability to use `>` operator in template argument list. + template + static constexpr bool hasMoreThanThreeArgs() { + return sizeof...(Args) > 3; + } + + public: + template + static c10::intrusive_ptr create(Args&&... elements_) { + switch (sizeof...(Args)) { + case 1: + case 2: + case 3: + return create(IValue(std::forward(elements_))...); + default: + return create( + std::vector{IValue(std::forward(elements_))...}); + } + } + + // Again, it would be nice to make this noncopyable, but there's a + // lot of extant code that copies Tuples. 
+ // Tuple(const Tuple& rhs) = delete; + + const TupleElements& elements() const& { + return elements_; + } + + TupleElements elements() && { + return std::move(elements_); + } + + void setElements(std::vector&& elements) { + elements_.setContents(std::move(elements)); + } + + void setElements(TupleElements&& elements) { + elements_ = std::move(elements); + } + + void unsafeSetElement(size_t idx, const IValue& element) { + elements_[idx] = element; + } + + void unsafeSetElement(size_t idx, IValue&& element) { + elements_[idx] = std::move(element); + } + + size_t size() const { + return elements_.size(); + } + + template + std::shared_ptr type() const { + if (!type_) { + type_ = TupleTypeFactory::create(fmap(elements(), [&](const IValue& v) { + return v.type(); + })); + } + if (auto t = type_->cast()) { + return t; + } + return TupleTypeFactory::fallback(*type_); + } + + static size_t hash(const Tuple& t) { + return c10::get_hash(t.elements()); + } + + TORCH_API friend bool operator==( + const ivalue::Tuple& lhs, + const ivalue::Tuple& rhs); + + private: + // NOTE: If we try to avoid the overloads without + // `std::shared_ptr type` by defaulting it to nullptr, we + // end up having to call (part of) the shared_ptr destructor for + // `type` even though we should know statically it won't do + // anything. 
+ explicit Tuple(std::vector elements) + : elements_(std::move(elements)){} + + explicit Tuple(std::vector elements, c10::TypePtr type) + : elements_(std::move(elements)), type_(std::move(type)) {} + + explicit Tuple(TupleElements&& elements) + : elements_(std::move(elements)) {} + + explicit Tuple(TupleElements&& elements, std::shared_ptr type) + : elements_(std::move(elements)), type_(std::move(type)) {} + + explicit Tuple(IValue&& e1) + : elements_(std::move(e1)) {} + + explicit Tuple(IValue&& e1, std::shared_ptr type) + : elements_(std::move(e1)), type_(std::move(type)) {} + + explicit Tuple(IValue&& e1, IValue&& e2) + : elements_(std::move(e1), std::move(e2)) {} + + explicit Tuple(IValue&& e1, IValue&& e2, std::shared_ptr type) + : elements_(std::move(e1), std::move(e2)), type_(std::move(type)) {} + + explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3) + : elements_(std::move(e1), std::move(e2), std::move(e3)) {} + + explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3, std::shared_ptr type) + : elements_(std::move(e1), std::move(e2), std::move(e3)), type_(std::move(type)) {} + + friend class c10::intrusive_ptr; +}; + +struct Object; +struct PyObjectHolder; +struct EnumHolder; +} // namespace ivalue + +// Future +struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { + private: + // Keep this private in order to force users to go through make_intrusive and + // thus prevent creating a Future that's not held by an intrusive_ptr. 
+ explicit Future(TypePtr type, std::vector devices={}) + : type_(std::move(type)), + impl_(getTypeOfDevices(devices)), + devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {} + + friend c10::intrusive_ptr; + + struct FutureCallback { + std::function callback; + bool uses_future; // whether the Future& passed in is actually used + + template + FutureCallback(T callback, bool uses_future) + : callback(std::move(callback)), uses_future(uses_future) {} + }; + + public: + Future(const Future&) = delete; + Future(Future&&) = delete; + Future& operator=(const Future&) = delete; + Future& operator=(Future&&) = delete; + + // Destructor + // Explicitly destroy events under device guard, otherwise it can lead to + // extra context being created on device 0. Reason: python garbage collector + // calls this destructor, but python GC does not have a device context, so a + // "default" one (usually on device 0) could be created when we go down the + // line of event destroy. + ~Future() override { + while (!events_.empty()) { + c10::OptionalDeviceGuard deviceGuard(events_.back().device()); + events_.pop_back(); + } + } + + struct TORCH_API FutureError final : public std::exception { + explicit FutureError(std::string&& error_msg_) + : error_msg(std::move(error_msg_)) {} + + FutureError() = default; + + const char* what() const noexcept override { + return error_msg.c_str(); + } + + std::string error_msg; + }; + + /** + * Wait on the future until it completes. + */ + void wait() { + std::unique_lock lock(mutex_); + finished_cv_.wait(lock, [&]() -> bool { return completed_; }); + synchronizeWithCurrentStreams(); + } + + /** + * Wait on the future until it completes and throw an + * exception if an error exists. + */ + void waitAndThrow() { + wait(); + + if (eptr_) { + std::rethrow_exception(eptr_); + } + } + + /** + * Explicitly mark the future as completed with the output value. Optionally, + * the storages for all tensors in IValue can be passed as well. 
The DataPtrs + * of these storages are used to synchronize CUDA streams. If storages isn't + * given we will attempt to extract it from the value, if we need to (this + * happens if a non-empty set of devices was given to the constructor). Thus + * one only needs to provide storages when 1) they cannot be extracted through + * IValue::getSubValues() or through pickling in case of Python object; or + * when 2) customized storage extraction is more efficient. + */ + using WeakStorage = c10::weak_intrusive_ptr; + void markCompleted( + IValue value, + std::optional> storages = std::nullopt) { + // Start by performing all steps that can throw, before setting any field. + // Do this before even acquiring the mutex, because extractStorages might + // acquire the GIL, which could lead to a lock inversion with our mutex. + // See https://github.com/pytorch/pytorch/issues/58239. + std::vector actualStorages; + std::vector usedDevices; + try { + // FIXME We should always extract DataPtrs, in order to catch the case of + // users using CUDA values but forgetting to set devices, which currently + // leads to a silent synchronization/correctness issue. However, as this + // might worsen perf in CPU-only cases, we should only do so after careful + // benchmarks. + if (impl_.type() != c10::kCPU) { + actualStorages = + storages.has_value() ? std::move(*storages) : extractStorages(value); + usedDevices = getDevicesOfStorages(impl_, actualStorages); + ensureIsSubsetOfDevices(usedDevices, devices_); + } + } catch (const std::exception&) { + setError(std::current_exception()); + return; + } + + std::unique_lock lock(mutex_); + TORCH_CHECK( + !completed(), + "Attempting to mark a completed Future as complete again. Note that " + "a Future can only be marked completed once."); + + // Only set value_ and completed_ flag once all checks and preparation steps + // have returned successfully to allow for proper error propagation. 
+ value_ = std::move(value); + completed_ = true; + + currentDevice_ = impl_.getDevice(); + storages_ = std::move(actualStorages); + for (const c10::Device& device : usedDevices) { + c10::Event event(impl_.type()); + event.record(impl_.getStream(device)); + events_.push_back(std::move(event)); + } + + std::vector cbs; + cbs.swap(callbacks_); + lock.unlock(); + + finished_cv_.notify_all(); + for (const auto& callback : cbs) { + invokeCallback(callback.callback, callback.uses_future); + } + } + + void markCompleted() { + markCompleted(IValue{}); + } + + void setError(std::exception_ptr eptr) { + std::unique_lock lock(mutex_); + setErrorInternal(std::move(eptr), lock); + } + + void setErrorIfNeeded(std::exception_ptr eptr) { + std::unique_lock lock(mutex_); + if (completed_) { + // This should be rare and shouldn't cause log spew. Its important to + // log errors and thats why we have this log here. + std::string msg = c10::str( + "Skipping setting following error on the Future since " + "it is already marked completed (this is not necessarily " + "an error):\n", + tryRetrieveErrorMessageInternal(std::move(eptr))); + if (eptr_) { + msg += c10::str( + ", \nOriginal exception:\n", + tryRetrieveErrorMessageInternal(eptr_)); + } + LOG(INFO) << msg; + return; + } else { + setErrorInternal(std::move(eptr), lock); + } + } + + // Get the result of the current future. + IValue value() { + std::unique_lock lock(mutex_); + AT_ASSERT(completed()); + if (eptr_) { + std::rethrow_exception(eptr_); + } + return value_; + } + + // This accessor should only be used if we know that the future is + // completed() with no error. 
+ const IValue& constValue() const { + std::unique_lock lock(mutex_); + AT_ASSERT(completed()); + TORCH_INTERNAL_ASSERT( + !eptr_, + "value() accessor should only be used when future is not completed with ", + "an error, but future had the following error: ", + tryRetrieveErrorMessageInternal(eptr_) + ); + return value_; + } + + // This accessor should only be used if we know that the future is + // completed() with no error. + const std::vector& storages() const { + std::unique_lock lock(mutex_); + AT_ASSERT(completed()); + AT_ASSERT(!eptr_); + return storages_; + } + + /** + * Add a callback to the future. + * The callbacks will be executed once the future completes. + * If the future has already completed, + * this function will execute the callback immediately. + */ + template + void addCallback(T callback, bool uses_future = true) { + static_assert( + std::is_invocable_r_v, + "The callback must have signature void(Future&)"); + + std::unique_lock lock(mutex_); + if (completed()) { + lock.unlock(); + invokeCallback(callback, uses_future); + return; + } + callbacks_.emplace_back(std::move(callback), uses_future); + } + + /** + * Add a callback to the future, and return another Future to hold the return + * value of the callback. This is necessary when the callback provider needs + * to know for sure when the callback has finished. 
+ */ + template + c10::intrusive_ptr then(T callback, TypePtr type) { + using IValueWithStorages = std::tuple>; + static_assert( + std::disjunction_v< + std::is_invocable_r, + std::is_invocable_r>, + "The callback must have signature IValue(Future&) or " + "std::tuple>(Future&)"); + + auto childFut = createInstance(::std::move(type)); + addCallback([childFut, + cb = std::move(callback)](Future& parentFut) { + try { + if constexpr (::std::is_convertible_v, IValueWithStorages>) { + auto [ivalue, storages] = cb(parentFut); + childFut->markCompleted(::std::move(ivalue), ::std::move(storages)); + } else { + childFut->markCompleted(cb(parentFut)); + } + } catch (std::exception&) { + childFut->setError(std::current_exception()); + } + }); + return childFut; + } + + template + c10::intrusive_ptr thenAsync(T callback, TypePtr type) { + static_assert( + std::is_invocable_r_v, T, Future&>, + "The callback must have signature c10::intrusive_ptr(Future&)"); + + auto childFut = createInstance(std::move(type)); + addCallback( + [childFut, cb = std::move(callback)](Future& parentFut) mutable { + c10::intrusive_ptr intermediateFut; + try { + intermediateFut = cb(parentFut); + } catch (std::exception&) { + childFut->setError(std::current_exception()); + return; + } + intermediateFut->addCallback( + [childFut = std::move(childFut)](Future& intermediateFut) { + if (intermediateFut.hasError()) { + childFut->setError(intermediateFut.exception_ptr()); + } else { + childFut->markCompleted( + intermediateFut.value(), intermediateFut.storages()); + } + }); + }); + return childFut; + } + + // Tries to retrieve the error message from std::exception_ptr. 
+ std::string tryRetrieveErrorMessage() const { + TORCH_CHECK(hasError(), "No error present on the future."); + std::unique_lock lock(mutex_); + return tryRetrieveErrorMessageInternal(eptr_); + } + + // Check if the current future has completed + bool completed() const { + return completed_; + } + + bool hasValue() const { + std::unique_lock lock(mutex_); + return completed_ && !eptr_; + } + + bool hasError() const { + std::unique_lock lock(mutex_); + return eptr_ ? true : false; + } + + std::exception_ptr exception_ptr() const { + std::unique_lock lock(mutex_); + return eptr_; + } + + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const Future& v); + + const TypePtr& elementType() const { + return type_; + } + + const std::vector& devices() const { + return devices_; + } + + // This method should be used when one intends to manually create a child + // future, for example when implementing a customized version of then(). + c10::intrusive_ptr createInstance(at::TypePtr type) { + return c10::make_intrusive(std::move(type), devices_); + } + + private: + + // This method should always be used when invoking a callback (regardless of + // how/when that happens) as it will ensure that the proper "environment" is + // set up before running the callback, as in, it will set up the CUDA streams, + // synchronize them with the value, and so on (if needed). + template + void invokeCallback(T& callback, bool uses_future) { + static_assert( + std::is_invocable_r_v, + "The callback must have signature void(Future&)"); + + // The synchronization performed below shouldn't be needed when the future + // is not used by the callback. 
+ if (uses_future) { + c10::OptionalDeviceGuard deviceGuard(currentDevice_); + + std::vector streams; + streams.reserve(devices_.size()); + for (const c10::Device& device : devices_) { + streams.push_back(impl_.getStreamFromGlobalPool(device)); + } + c10::MultiStreamGuard streamGuard(streams); + synchronizeWithCurrentStreams(); + callback(*this); + } else { + callback(*this); + } + } + + // This method should be called before this future's value is used, as it + // ensures that the CUDA streams that are "current" at the callsite properly + // synchronize with the value. + void synchronizeWithCurrentStreams() { + for (c10::Event& event : events_) { + event.block(impl_.getStream(event.device())); + } + + for (const WeakStorage& weak_storage : storages_) { + c10::intrusive_ptr storage = weak_storage.lock(); + if (!storage) { + continue; + } + if (!storage->device().is_cpu()) { + impl_.recordDataPtrOnStream( + storage->data_ptr(), impl_.getStream(storage->device())); + } + } + } + + void setErrorInternal( + std::exception_ptr eptr, + std::unique_lock& lock) { + TORCH_CHECK( + !eptr_, + "Error already set on this Future: ", + tryRetrieveErrorMessageInternal(eptr_), + ", trying to set error: ", + tryRetrieveErrorMessageInternal(eptr)); + TORCH_INTERNAL_ASSERT(!completed(), "Future is already marked completed"); + completed_ = true; + eptr_ = std::move(eptr); + + std::vector cbs; + cbs.swap(callbacks_); + lock.unlock(); + + finished_cv_.notify_all(); + for (const auto& callback : cbs) { + invokeCallback(callback.callback, callback.uses_future); + } + } + + // Tries to retrieve the error message from std::exception_ptr. + std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const { + try { + std::rethrow_exception(std::move(eptr)); + } catch (const std::exception& e) { + return e.what(); + } catch (...) { + return "Unknown Exception Type"; + } + } + + // Defined in ivalue.cpp. 
+ static std::vector extractStorages( + const at::IValue& value); + + static std::vector getDevicesOfStorages( + const c10::impl::VirtualGuardImpl& impl, + const std::vector& storages) { + c10::DeviceIndex deviceCount = impl.deviceCount(); + std::vector isDeviceUsed(deviceCount, false); + for (const WeakStorage& weak_storage : storages) { + c10::intrusive_ptr storage = weak_storage.lock(); + if (!storage) { + continue; + } + c10::Device device = storage->device(); + if (!device.is_cpu()) { + TORCH_CHECK_VALUE( + device.type() == impl.type(), + "Expected all data ptrs to be on a device of type ", + impl.type(), + ", got one on device ", + device); + isDeviceUsed[device.index()] = true; + } + } + std::vector devices; + for (c10::DeviceIndex idx = 0; idx < deviceCount; idx++) { + if (isDeviceUsed[idx]) { + devices.emplace_back(impl.type(), idx); + } + } + return devices; + } + + static std::string formatSetOfDevices( + const std::vector& devices) { + if (devices.empty()) { + return "(none)"; + } + std::ostringstream oss; + oss << devices[0]; + for (const auto idx : c10::irange(1, devices.size())) { + if (idx == devices.size() - 1) { + oss << " and "; + } else { + oss << ", "; + } + oss << devices[idx]; + } + return oss.str(); + } + + static c10::DeviceType getTypeOfDevices( + const std::vector& devices) { + if (devices.empty()) { + return c10::kCPU; + } + c10::DeviceType deviceType = devices[0].type(); + for (const auto idx : c10::irange(1, devices.size())) { + TORCH_CHECK_VALUE( + devices[idx].type() == deviceType, + "Expected all devices to be of the same type, but got a mismatch between ", + devices[0], + " and ", + devices[idx]); + } + return deviceType; + } + + // We need devices to be sorted in order to use ensureIsSubsetOfDevices. 
+ static std::vector sortAndDeduplicateDevices( + const c10::impl::VirtualGuardImpl& /*impl*/, + std::vector devices) { + std::sort( + devices.begin(), devices.end(), + [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); }); + // Deduplicate by compacting. + size_t targetIdx = 0; + for (const auto sourceIdx : c10::irange(devices.size())) { + TORCH_CHECK_VALUE( + devices[sourceIdx].has_index(), + "Expected devices to have indices, got ", devices[sourceIdx]); + if (targetIdx > 0 && devices[targetIdx - 1].index() == devices[sourceIdx].index()) { + // It's a duplicate, skip it. + continue; + } + if (sourceIdx != targetIdx) { + devices[targetIdx] = devices[sourceIdx]; + } + targetIdx++; + } + // If there were duplicates there's now a gap at the end: trim it. Resizing + // requires the item type to be default-constructible (which c10::Device is + // not) because in principle it could be required to create new items. Since + // we know we'll shrink the vector, we provide a custom dummy value instead. + devices.resize(targetIdx, c10::Device(c10::kCPU)); + return devices; + } + + static void ensureIsSubsetOfDevices( + const std::vector& subset, + const std::vector& superset) { + // We assume the devices in both vectors have the same consistent type, and + // their indices are unique and sorted. 
+ std::vector excessDevices; + std::set_difference( + subset.begin(), + subset.end(), + superset.begin(), + superset.end(), + std::back_inserter(excessDevices), + [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); }); + TORCH_CHECK_VALUE( + excessDevices.empty(), + "The result contained tensors residing on device(s) ", + formatSetOfDevices(excessDevices), + " which are not among the expected device(s) ", + formatSetOfDevices(superset)); + } + + mutable std::mutex mutex_; + std::atomic_bool completed_ = {false}; // is this future complete + std::condition_variable finished_cv_; + + IValue value_; // when finished the value + TypePtr type_; + std::vector callbacks_; + std::exception_ptr eptr_; + + // An upcast pointer to a virtual class which allows us to manipulate events, + // streams, ... in a generic way, without an explicit dependency on CUDA. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const c10::impl::VirtualGuardImpl impl_; + + // The device that was current when markCompleted was called, which we'll + // restore when invoking callbacks. It's optional because we'll only store it + // if the future completes successfully. + std::optional currentDevice_; + + // The events that correspond to the completion of the async I/O kernels. They + // are recorded on the appropriate streams when the future is marked completed + // and can then be queried/waited/blocked on. There is one event for each + // distinct device on which the value's tensors reside. + std::vector events_; + + // A cached version of the storages extracted from the value when the future + // is first marked completed. + std::vector storages_; + + // The bounding set of devices that this future, and any of its children, is + // allowed to use. This is a superset of the set of devices used by the events + // above. 
We need this to know what streams (for which devices) to set as + // current when invoking a callback, thus allowing the callback to use devices + // that the parent future didn't use. This field is set to the value provided + // in the constructor and will be "inherited" by all child futures. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::vector devices_; +}; + +struct C10_EXPORT ivalue::Await final : c10::intrusive_ptr_target { + private: + explicit Await(TypePtr elType, std::function fn) + : elType_(std::move(elType)), type_(AwaitType::create(elType_)), fn_(std::move(fn)) {} + + explicit Await(TypePtr elType) : elType_(std::move(elType)), type_(AwaitType::create(elType_)) { } + + friend c10::intrusive_ptr; + + public: + Await(const Await&) = delete; + Await(Await&&) = delete; + Await& operator=(const Await&) = delete; + Await& operator=(Await&&) = delete; + ~Await() override = default; + + IValue wait() { + if (!completed_) { + TORCH_CHECK(fn_, "Incompleted Await: fn can't be None"); + value_ = fn_(); + completed_ = true; + args_ = {}; + } + return value_; + } + + IValue value() { + TORCH_CHECK(completed_, "Await must be completed"); + return value_; + } + + void setFn(std::function fn) { + fn_ = std::move(fn); + } + + bool completed() { + return completed_; + } + + void markCompleted(IValue value) { + value_ = std::move(value); + completed_ = true; + } + + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const Await& v); + + const TypePtr& elementType() const { + return elType_; + } + + const TypePtr& type() const { + return type_; + } + + void setArgs(std::vector args) { + args_ = std::move(args); + } + + std::vector& args() { + return args_; + } + + private: + TypePtr elType_; + TypePtr type_; + std::vector args_; + std::function fn_; + IValue value_; + bool completed_{}; +}; + +// Input is a list of Futures with the same target type. +// Output is a Future to the List of completed Futures. 
+TORCH_API intrusive_ptr collectAll( + const c10::List>& srcs); +// Input is a List of Futures with the same target type. +// Output is a Future that will be updated with a seen value. +TORCH_API intrusive_ptr collectAny( + const c10::List>& srcs); + +// User-defined object. +struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { + public: + // In general, class types hold a shared_ptr to its owning CompilationUnit, + // so that its type and methods do not get deallocated while the class exists. + // However, the CompilationUnit holds ownership of the type's graphs, so + // inserting a constant object into a Graph would create a reference cycle if + // that constant object held a shared_ptr to its CU. For these objects we + // instatiate them with non-owning references to its CU + Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) { + slots_.resize(numSlots); + } + + Object(StrongTypePtr type, size_t numSlots) + : type_(WeakOrStrongTypePtr(std::move(type))) { + slots_.resize(numSlots); + } + + static c10::intrusive_ptr create( + WeakOrStrongTypePtr type, + size_t numSlots) { + return c10::make_intrusive(std::move(type), numSlots); + } + + static c10::intrusive_ptr create( + StrongTypePtr type, + size_t numSlots) { + return c10::make_intrusive(std::move(type), numSlots); + } + + static c10::intrusive_ptr create(ClassTypePtr classType, size_t numSlots); + + /** + * Slot API. + * + * Attributes are stored as a simple vector so that lookups are fast at + * runtime. A "slot" is just an index into that vector, which can be computed + * statically if you have access to the class type. Use this API if you are + * writing compiler stuff. + */ + void setSlot(size_t slot, IValue v) { + if (slot >= slots_.size()) { + // for module types, it is possible that the members of the class have + // expanded after the object was created. 
In this case, we expand + // the slots to the right size + resizeObject(slot); + } + slots_[slot] = std::move(v); + } + + const IValue& getSlot(size_t slot) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(slot < slots_.size()); + // NOTE: This lookup is fairly hot, so we use unchecked access to the + // vector. Errors should still be detectable with ASan. + return slots_[slot]; + } + + void unsafeRemoveSlot(size_t slot) { + TORCH_CHECK(slot < slots_.size()); + slots_.erase(slots_.begin() + static_cast(slot)); + } + + /** + * Attribute API. + * + * Wrappers around the slot stuff so that users can access attributes + * directly. Use this API if you are a user. + * + * Note: Unlike in Python, TorchScript must make a distinction between + * attributes (which are IValues) and methods (which are Methods). If you + * want a method, use `obj.type()->getMethod()` + */ + IValue getAttr(const std::string& name) const; + void setAttr(const std::string& name, IValue v); + // Remove attribute by name, caller is responsible for + // the safety of this operation + // We didn't remove the attribute in the type because the type + // might be shared by multiple objects. 
+ // Therefore after removing attribute, the object is in an inconsistent + // state where it has more attribute types in its Type than + // the attribute slots it has, user needs to make sure the object + // has consistent by removing the attribute in type as well + void unsafeRemoveAttr(const std::string& name); + + std::string name() const; + + const std::vector& slots() const { + return slots_; + } + std::shared_ptr type() const; + + std::shared_ptr compilation_unit() { + if (type_.holds_strong_ref()) { + return type_.cu_.getStrongRefOrThrow(); + } else { + auto weak_ptr = type_.cu_.getWeakRefOrThrow(); + return std::shared_ptr(weak_ptr); + } + } + + c10::intrusive_ptr copy_to_weak_compilation_ref() const; + + void unsafe_make_weak_compilation_ref() { + type_ = WeakOrStrongTypePtr(type_.asWeakTypePtr()); + } + + c10::intrusive_ptr copy() const; + + c10::intrusive_ptr deepcopy( + std::optional device = std::nullopt) const; + + c10::intrusive_ptr deepcopy( + IValue::HashIdentityIValueMap& memo, + std::optional device = std::nullopt) const; + + bool is_weak_compilation_ref() const { + return !type_.holds_strong_ref(); + } + + bool is_empty_strong_compilation_ref() const { + return type_.holds_empty_strong_ref(); + } + + private: + void resizeObject(size_t slot); + WeakOrStrongTypePtr type_; + std::vector slots_; +}; + +// virtual ivalue PyObjectHolder that hold a py::object, we make this virtual +// because the py::object and refcounting logic should happen in libtorch_python +// see concrete implementation in python_ivalue.h +struct ivalue::PyObjectHolder : c10::intrusive_ptr_target { + public: + virtual PyObject* getPyObject() = 0; + virtual c10::InferredType tryToInferType() = 0; + virtual IValue toIValue(const TypePtr& type, std::optional N = std::nullopt) = 0; + virtual std::string toStr() = 0; + virtual std::vector extractTensors() = 0; + + ~PyObjectHolder() override = default; +}; + +struct ivalue::EnumHolder : c10::intrusive_ptr_target { + public: + 
EnumHolder(std::shared_ptr type, std::string name, IValue value) + : type_(std::move(type)), + name_(std::move(name)), + value_(std::move(value)) {} + + bool is(const ivalue::EnumHolder& rhs) { + return *this == rhs; + } + + friend bool operator==( + const ivalue::EnumHolder& lhs, + const ivalue::EnumHolder& rhs); + + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const ivalue::EnumHolder& v); + + TORCH_API const std::string& qualifiedClassName() const; + + const std::string& unqualifiedClassName() const; + + const std::string& name() const { + return name_; + } + + const IValue& value() const { + return value_; + } + + std::shared_ptr type() const { + return type_; + } + + private: + std::shared_ptr type_; + std::string name_; + IValue value_; +}; + +#undef TORCH_FORALL_TAGS + +namespace detail { + +struct _guarded_unsigned_long_unique_dummy final { + _guarded_unsigned_long_unique_dummy(int64_t){} +}; +using _guarded_unsigned_long = std::conditional_t< + std::is_same_v || + std::is_same_v, + _guarded_unsigned_long_unique_dummy, + unsigned long>; + +} // namespace detail + +inline ivalue::Object& IValue::toObjectRef() const { + AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference"); + return *static_cast(payload.u.as_intrusive_ptr); +} + +// note: when adding a DEFINE_TO case here you should also add a +// toX method to IValue. These named methods are much more discoverable +// than the to templated function. 
+ +#define DEFINE_TO(T, method_name) \ + template <> \ + inline T IValue::to()&& { \ + return static_cast(std::move(*this).method_name()); \ + } \ + template <> \ + inline c10::detail::ivalue_to_const_ref_overload_return::type IValue::to() const& { \ + typedef c10::detail::ivalue_to_const_ref_overload_return::type return_type; \ + return static_cast(this->method_name()); \ + } + +DEFINE_TO(at::Tensor, toTensor) +DEFINE_TO(at::Storage, toStorage) +DEFINE_TO(c10::Stream, toStream) +DEFINE_TO(float, toDouble) +DEFINE_TO(double, toDouble) +DEFINE_TO(c10::complex, toComplexDouble) +DEFINE_TO(unsigned char, toInt) +DEFINE_TO(signed char, toInt) +DEFINE_TO(unsigned short, toInt) +DEFINE_TO(short, toInt) +DEFINE_TO(int, toInt) +DEFINE_TO(uint32_t, toInt) +DEFINE_TO(uint64_t, toInt) +DEFINE_TO(detail::_guarded_unsigned_long, toInt) +DEFINE_TO(int64_t, toInt) +DEFINE_TO(bool, toBool) +DEFINE_TO(c10::intrusive_ptr, toBlob) +DEFINE_TO(c10::intrusive_ptr, toString) +DEFINE_TO(c10::intrusive_ptr, toObject) +DEFINE_TO(at::Scalar, toScalar) +DEFINE_TO(c10::List, toIntList) +DEFINE_TO(c10::List, toSymIntList) +DEFINE_TO(c10::List, toDoubleList) +DEFINE_TO(c10::List>, toComplexDoubleList) +DEFINE_TO(c10::List, toBoolList) +DEFINE_TO(c10::List, toTensorList) +DEFINE_TO(c10::impl::GenericList, toList) +DEFINE_TO(c10::impl::GenericDict, toGenericDict) +DEFINE_TO(c10::intrusive_ptr, toTuple) +DEFINE_TO(std::string, toStringRef) +DEFINE_TO(std::string_view, toStringView) +DEFINE_TO(c10::intrusive_ptr, toFuture) +DEFINE_TO(c10::intrusive_ptr, toAwait) +DEFINE_TO(c10::intrusive_ptr, toRRef) +DEFINE_TO(c10::intrusive_ptr, toQuantizer) +DEFINE_TO(IValue, toIValue) +DEFINE_TO(c10::Device, toDevice) +DEFINE_TO(at::ScalarType, toScalarType) +DEFINE_TO(at::Layout, toLayout) +DEFINE_TO(at::MemoryFormat, toMemoryFormat) +DEFINE_TO(at::QScheme, toQScheme) +DEFINE_TO(at::Dimname, toDimname) +DEFINE_TO(at::Generator, toGenerator) +DEFINE_TO(c10::SymInt, toSymInt) +DEFINE_TO(c10::SymFloat, toSymFloat) 
+DEFINE_TO(c10::SymBool, toSymBool) + +template +struct _fake_type {}; + +// generic_to converts an IValue from a generic list or generic dict +// to a concrete list/dict type likelike List, Dict<...> or std::optional. +// Note that in the case of lists, this only works for IValue-based lists, +// i.e. not for int64_t, double, ... +// generic_to is an implementation detail of IValue::to and not +// supposed to be called directly. +// The _fake_type parameter allows us to overload +// based on the return type. +template +// TODO this is deprecated but we don't throw a warning because a lot of ops in +// native_functions.yaml still return std::vector. +// C10_DEPRECATED_MESSAGE("IValues based on std::vector are potentially slow +// and deprecated. Please use torch::List instead.") +std::vector generic_to(IValue ivalue, _fake_type>) { + // We need to do a deep copy of the vector because there might be other + // references to this same IValue that also use the list. We can't just + // move the elements out. 
+ auto list = std::move(ivalue).template to>(); + std::vector result; + result.reserve(list.size()); + for (Elem v : list) { + result.push_back(std::move(v)); + } + return result; +} + +template +c10::intrusive_ptr IValue::toCustomClass() && { + static_assert( + std::is_base_of_v == true, + "toCustomClass requires that template parameter T must inherit " + "from torch::CustomClassHolder"); + auto obj = toObject(); + TORCH_CHECK( + obj->slots().size() == 1, + "Tried to cast IValue to custom class but it did " + "not contain a custom class!"); + const auto* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); + auto userObj = + c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); + return userObj; +} + +template +c10::intrusive_ptr IValue::toCustomClass() const& { + static_assert( + std::is_base_of_v == true, + "toCustomClass requires that template parameter T must inherit " + "from torch::CustomClassHolder"); + auto obj = toObject(); + TORCH_CHECK( + obj->slots().size() == 1, + "Tried to cast IValue to custom class but it did " + "not contain a custom class!"); + const auto* expected_type = c10::getCustomClassType>().get(); + ivalue::checkCustomClassType(expected_type, type().get()); + auto userObj = + c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); + return userObj; +} + +template +T generic_to(IValue ivalue, _fake_type) { + using ElemType = typename std::remove_pointer::type::element_type; + return std::move(ivalue).template toCustomClass(); +} + +template +tagged_capsule generic_to(IValue ivalue, _fake_type>) { + return tagged_capsule{std::move(ivalue)}; +} + +template +c10::List generic_to(IValue ivalue, _fake_type>) { + return impl::toTypedList(std::move(ivalue).toList()); +} + +template +static T createVectorLikeFromList(const c10::detail::ListImpl* impl) { + T result; + result.reserve(impl->list.size()); + for (const auto & i : impl->list) { + result.push_back(i.to()); + } 
+ return result; +} + +template +static std::vector createVectorFromList(const c10::detail::ListImpl* impl) { + return createVectorLikeFromList>(impl); +} + +template +std::vector createVectorFromList(const c10::List& impl) { + std::vector result; + result.reserve(impl.size()); + for (size_t i = 0, N = impl.size(); i < N; ++i) { + result.push_back(impl[i]); + } + return result; +} + +template +OptionalArray generic_to(IValue ivalue, _fake_type>) { + if (ivalue.isNone()) { + return {}; + } + return createVectorFromList( + std::move(ivalue).template to>() + ); +} + +namespace detail { +template +std::array generic_to_array( + IValue ivalue, + _fake_type>, + std::index_sequence) { + // We need to do a deep copy of the array because there might be other + // references to this same IValue that also use the list. We can't just + // move the elements out. + auto list = std::move(ivalue).template to>(); + TORCH_CHECK( + list.size() == sizeof...(I), + "Tried to convert a List with ", + list.size(), + " elements to a fixed-size array of size ", + sizeof...(I)); + return {list[I]...}; +} +} // namespace detail + +template +std::array generic_to( + IValue ivalue, + _fake_type> ft) { + return detail::generic_to_array(ivalue, ft, std::make_index_sequence()); +} + +template +c10::Dict generic_to( + IValue ivalue, + _fake_type>) { + return impl::toTypedDict(std::move(ivalue).toGenericDict()); +} + +template +C10_DEPRECATED_MESSAGE( + "IValues based on std::unordered_map are slow and deprecated. 
Please use c10::Dict instead.") +std::unordered_map generic_to( + IValue ivalue, + _fake_type>) { + std::unordered_map specialized_dict; + + for (const auto& item : std::move(ivalue).toGenericDict()) { + specialized_dict[item.key().template to()] = item.value().template to(); + } + + return specialized_dict; +} + +template +std::optional generic_to(IValue ivalue, _fake_type>) { + if (ivalue.isNone()) { + return std::nullopt; + } + return std::move(ivalue).template to(); +} + +namespace detail { +template +Tuple generic_to_tuple_impl( + const ivalue::TupleElements& t, + std::index_sequence) { + return std::make_tuple( + t[INDEX].to::type>()...); +} +} // namespace detail + +template < + typename... Args, + typename Indices = std::make_index_sequence, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t> = nullptr> +std::tuple generic_to(const IValue& ivalue, _fake_type>) { + const auto& vals = ivalue.toTupleRef().elements(); + TORCH_CHECK(vals.size() == sizeof...(Args)); + return detail::generic_to_tuple_impl>(vals, Indices{}); +} + +template +inline T IValue::to() && { + return generic_to(std::move(*this), _fake_type{}); +} + +template <> +inline std::optional IValue::to() && { + // In the default implementation, the IValue is destroyed with std::move. + // But if the unboxed type is std::optional we cannot destroy + // the IValue. 
+ return generic_to(*this, _fake_type>{}); +} + +template +inline typename c10::detail::ivalue_to_const_ref_overload_return::type IValue::to() const& { + return generic_to(*this, _fake_type{}); +} + +inline c10::List IValue::toIntList() && { + AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toIntList() const& { + AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline std::vector IValue::toIntVector() const { + AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toIntVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List IValue::toSymIntList() && { + AT_ASSERT( + isSymIntList() || isIntList(), + "Expected SymIntList or IntList but got ", + tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toSymIntList() const& { + AT_ASSERT( + isSymIntList() || isIntList(), + "Expected SymIntList or IntList but got ", + tagKind()); + return c10::List(toIntrusivePtr()); +} +inline std::vector IValue::toSymIntVector() const { + AT_ASSERT(isSymIntList() || isIntList(), "Expected SymIntList or IntList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toSymIntVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); +} +inline at::DimVector IValue::toDimVector() const { + AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toDimVector on null intrusive_ptr IValue"); + return createVectorLikeFromList( + 
static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List IValue::toDoubleList() && { + AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toDoubleList() const& { + AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline std::vector IValue::toDoubleVector() const { + AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toDoubleVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List> IValue::toComplexDoubleList() && { + AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind()); + return c10::List>(moveToIntrusivePtr()); +} +inline c10::List> IValue::toComplexDoubleList() const& { + AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind()); + return c10::List>(toIntrusivePtr()); +} +inline std::vector> IValue::toComplexDoubleVector() const { + AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toComplexDoubleVector on null intrusive_ptr IValue"); + return createVectorFromList>( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List IValue::toBoolList() && { + AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toBoolList() const& { + AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline c10::List IValue::toTensorList() && { + AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline 
c10::List IValue::toTensorList() const& { + AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline std::vector IValue::toTensorVector() const { + AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toTensorVector on null intrusive_ptr IValue"); + return createVectorFromList( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List> IValue::toOptionalTensorList() && { + AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind()); + return c10::List>(moveToIntrusivePtr()); +} +inline c10::List> IValue::toOptionalTensorList() const& { + AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind()); + return c10::List>(toIntrusivePtr()); +} +inline std::vector> IValue::toOptionalTensorVector() const { + AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toOptionalTensorVector on null intrusive_ptr IValue"); + return createVectorFromList>( + static_cast(payload.u.as_intrusive_ptr)); +} +inline c10::List IValue::toList() && { + AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); + return c10::List(moveToIntrusivePtr()); +} +inline c10::List IValue::toList() const& { + AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); + return c10::List(toIntrusivePtr()); +} +inline c10::ArrayRef IValue::toListRef() const { + AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toListRef on null intrusive_ptr IValue"); + return static_cast(payload.u.as_intrusive_ptr) + ->list; +} +inline c10::Dict IValue::toGenericDict() && { + 
AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind()); + return c10::Dict(moveToIntrusivePtr()); +} +inline c10::Dict IValue::toGenericDict() const& { + AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind()); + return c10::Dict(toIntrusivePtr()); +} +inline c10::intrusive_ptr IValue::toTuple() && { + AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); + return moveToIntrusivePtr(); +} +inline c10::intrusive_ptr IValue::toTuple() const& { + AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); + return toIntrusivePtr(); +} +inline ivalue::Tuple& IValue::toTupleRef() const { + AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toTupleRef on null intrusive_ptr IValue"); + return *static_cast( + payload.u.as_intrusive_ptr); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Tuple) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} +template < + typename... Args, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t>> +inline IValue::IValue(const std::tuple& t) + : IValue(std::apply(c10::ivalue::Tuple::create, t)) { +} + +template < + typename... 
Args, + std::enable_if_t< + !std::disjunction_v< + std::is_lvalue_reference..., + std::negation>...>, + std::nullptr_t>> +inline IValue::IValue(std::tuple&& t) + : IValue(std::apply(c10::ivalue::Tuple::create, std::move(t))) { +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::String) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} +inline IValue::IValue(std::string v) + : IValue(ivalue::ConstantString::create(std::move(v))) {} + +inline IValue::IValue(c10::impl::GenericList v) + : tag(Tag::GenericList) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); +} + +template > +inline IValue::IValue(c10::List&& v) : IValue(impl::toList(std::move(v))) {} +template > +inline IValue::IValue(const c10::List& v) : IValue(impl::toList(v)) {} +template > +inline IValue::IValue(at::ArrayRef v) : IValue(c10::List()) { + auto list = to>(); + list.reserve(v.size()); + for (const auto& e : v) { + list.push_back(e); + } +} +template > +inline IValue::IValue(at::ArrayRef v) : IValue() { + auto vi = c10::asIntArrayRefSlowOpt(v); + if (vi.has_value()) { + // This list is entirely integers; ensure it is typed as + // an IntList so toIntList works + *this = IValue(*vi); + } else { + // This list has SymInts; type it as a SymInt + *this = IValue(impl::toList(c10::List())); + auto list = to>(); + list.reserve(v.size()); + for (const auto& e : v) { + list.push_back(e); + } + } +} +template > +inline IValue::IValue(at::OptionalArrayRef mb_v) : IValue() { + if (!mb_v.has_value()) return; + *this = IValue(*mb_v); +} +template > +inline IValue::IValue(const std::vector& v) : IValue() { + *this = IValue(at::ArrayRef(v)); +} +template > +inline IValue::IValue(std::vector&& v) : IValue() { + auto vi = c10::asIntArrayRefSlowOpt(v); + if (vi.has_value()) { + // This list is entirely integers; ensure it is typed as + // an IntList so toIntList works + *this = IValue(*vi); + } else { + // This list has SymInts; type it as a SymInt + *this 
= IValue(impl::toList(c10::List())); + auto list = to>(); + list.reserve(v.size()); + for (auto&& e : std::move(v)) { + list.push_back(std::move(e)); + } + } +} +template > +inline IValue::IValue(const std::vector& v) : IValue(c10::List()) { + auto list = to>(); + list.reserve(v.size()); + for (const auto& e : v) { + list.push_back(e); + } +} + +template > +inline IValue::IValue(std::vector&& v) : IValue(c10::List()) { + auto list = to>(); + list.reserve(v.size()); + if constexpr (std::is_same_v) { + for (auto e : v) { + list.push_back(e); + } + } else { + for (auto&& e : std::move(v)) { + list.push_back(std::move(e)); + } + } +} + +template > +inline IValue::IValue(c10::OptionalArrayRef v) : IValue() { + if (v.has_value()) { + *this = IValue(std::move(*v)); + } +} + +template +inline IValue::IValue(std::array v) : IValue(c10::List()) { + auto list = to>(); + list.reserve(v.size()); + for (auto& e : v) { + list.push_back(std::move(e)); + } +} + +template > +inline IValue::IValue(c10::IListRef v) : IValue() { + constexpr bool boxed_type_constructs_ivalue = + std::is_constructible_v::boxed_type>; + // First, we try to use the boxed value. + // If we fail (either it's not in the boxed state, or its boxed type + // can not construct an IValue), we fallback to copying the list. 
+ if (boxed_type_constructs_ivalue && v.isBoxed()) { + *this = IValue(impl::toList(v.toBoxed())); + } else { + c10::List list; + list.reserve(v.size()); + for (const auto& t : v) { + list.push_back(t); + } + *this = IValue(impl::toList(std::move(list))); + } +} + +inline IValue::IValue(c10::impl::GenericDict v) + : tag(Tag::GenericDict) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); +} +template +inline IValue::IValue(c10::Dict v) + : IValue(impl::toGenericDict(std::move(v))) {} + +template +inline IValue::IValue(std::unordered_map v) + : IValue(Dict()) { + auto dict = to>(); + dict.reserve(v.size()); + for (auto& e : v) { + dict.insert(std::move(e.first), std::move(e.second)); + } +} + +template > +inline IValue::IValue(std::optional v) : IValue() { + if (v.has_value()) { + *this = IValue(std::move(*v)); + } +} + +inline IValue::IValue(std::nullopt_t) : IValue() {} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Object) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::PyObject) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Enum) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue IValue::make_capsule( + intrusive_ptr blob) { + IValue iv; + iv.tag = Tag::Capsule; + iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); + return iv; +} + +template < + typename T, + std::enable_if_t, int>> +IValue::IValue(c10::intrusive_ptr custom_class) : tag(Tag::Object) { + auto classType = []() { + try { + return c10::getCustomClassType>(); + } catch (const c10::Error&) { + throw c10::Error( + "Trying to instantiate a class that isn't a registered custom class: " + + std::string(c10::util::get_fully_qualified_type_name())); + } + }(); + auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* 
numSlots */1); + ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release()); + +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Future) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Await) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::RRef) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +inline IValue::IValue(c10::intrusive_ptr v) + : tag(Tag::Quantizer) { + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); +} + +template +inline IValue::IValue(c10::complex c) + : tag(Tag::ComplexDouble) { + auto v = c10::make_intrusive(c); + payload.u.as_intrusive_ptr = v.release(); +} + +inline const std::string& IValue::toStringRef() const { + AT_ASSERT(isString(), "Expected String but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toStringRef on null intrusive_ptr IValue"); + return static_cast( + payload.u.as_intrusive_ptr) + ->string(); +} +inline std::optional> IValue:: + toOptionalStringRef() const { + if (isNone()) { + return std::nullopt; + } + AT_ASSERT(isString(), "Expected std::optional but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toOptionalStringRef on null intrusive_ptr IValue"); + return std::reference_wrapper( + static_cast(payload.u.as_intrusive_ptr) + ->string()); +} + +inline std::string_view IValue::toStringView() const { + AT_ASSERT(isString(), "Expected String but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toStringView on null intrusive_ptr IValue"); + return 
static_cast( + payload.u.as_intrusive_ptr) + ->string_view(); +} + +inline PyObject* IValue::toPyObject() const { + return toPyObjectHolder()->getPyObject(); +} + +template +inline std::optional IValue::toOptional() { + if (this->isNone()) { + return std::nullopt; + } + return this->to(); +} + +template +inline std::optional IValue::toOptional() const { + if (this->isNone()) { + return std::nullopt; + } + return this->to(); +} + +inline bool IValue::isCustomClass() const { + return torch::isCustomClass(*this); +} + +inline bool IValue::isSameIdentity(const IValue& rhs) const { + // We choose to not use memcmp for payload check due to potential random + // padding characters on union type + + // Semantics: + // 1. Immutable primitive values of the same type (Int, Double, None, Bool, + // Str) return value equality + // 2. If it is a tensor type, we need to take undefined tensor into account + // 3. Undefined_tensor is None and vice versa should be true + // 4. If it is a reference type (i.e. isIntrusivePtr()), then is True when + // the pointed-to object is the same. + // 5. False for all other comparisons. 
+ if (this->isNone() && rhs.isNone()) { + return true; + } else if (this->isBool() && rhs.isBool()) { + // for bool type, do equality check + return this->toBool() == rhs.toBool(); + } else if (this->isTensor() && rhs.isTensor()) { + return this->payload.as_tensor.is_same(rhs.payload.as_tensor); + } else if (this->isTensor() && rhs.isNone()) { + // special case: undefined tensor and None are the same identity + return !this->payload.as_tensor.defined(); + } else if (this->isNone() && rhs.isTensor()) { + // special case: undefined tensor and None are the same identity + return !rhs.payload.as_tensor.defined(); + } else if (this->isInt() && rhs.isInt()) { + return this->toInt() == rhs.toInt(); + } else if (this->isDouble() && rhs.isDouble()) { + return this->toDouble() == rhs.toDouble(); + } else if (this->isString() && rhs.isString()) { + return this->toStringRef() == rhs.toStringRef(); + } else { + // for objects holding in IValue, do shallow compare on pointer address to + // testify the identity + return this->isIntrusivePtr() && rhs.isIntrusivePtr() && + this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; + } +} + +namespace ivalue { +namespace detail { + +template +IValue from_(T&& x, std::true_type) { + return IValue(std::forward(x)); +} +template +IValue from_(c10::intrusive_ptr x, std::false_type) { + return IValue(std::move(x)); +} +template +IValue from_(T&& /*x*/, std::false_type) { + static_assert( + guts::false_t::value, + "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)"); + return IValue(); +} +} // namespace detail + +template +IValue from(T&& x) { + return detail::from_( + std::forward(x), typename std::is_constructible::type{}); +} + +} // namespace ivalue + + +template <> +struct MaybeOwnedTraits { + using owned_type = IValue; + using borrow_type = IValue; + + static borrow_type createBorrow(const owned_type& from) { + if (!from.isPtrType()) { + return from; + } 
+ if (from.isTensor()) { + return IValue(MaybeOwnedTraits::createBorrow(from.toTensor())); + } else { + return IValue(from.payload, from.tag); + } + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.clearToNone(); + if (!rhs.isPtrType()) { + lhs = rhs; + } else if (rhs.isTensor()) { + lhs = IValue(MaybeOwnedTraits::createBorrow(rhs.toTensor())); + } else { + lhs = IValue(rhs.payload, rhs.tag); + } + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.clearToNone(); + } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type&) { + return true; + } +}; + +template <> +struct IValue::TagType { + static TORCH_API c10::TypePtr get(const IValue&); +}; + +template <> +struct IValue::TagType { + static TORCH_API c10::TypePtr get(const IValue&); +}; + +template +TypePtr IValue::type() const { + return IValue::TagType::get(*this); +} + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ivalue_to.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ivalue_to.h new file mode 100644 index 0000000000000000000000000000000000000000..52af58083596939106fe120abfbd360f0068db67 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/ivalue_to.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +namespace at { +class Tensor; +} // namespace at + +namespace c10 { +struct IValue; +namespace detail { +// Determine the return type of `IValue::to() const &`. It's a const +// reference when possible and a copy otherwise. It is in this +// separate header so that List can use it as well. 
+template +struct ivalue_to_const_ref_overload_return { + using type = T; +}; + +template<> +struct ivalue_to_const_ref_overload_return { + using type = const at::Tensor&; +}; + +template<> +struct ivalue_to_const_ref_overload_return { + using type = const std::string&; +}; + +template<> +struct ivalue_to_const_ref_overload_return { + using type = const IValue&; +}; + +} // namespace detail +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/jit_type.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/jit_type.h new file mode 100644 index 0000000000000000000000000000000000000000..c15e5f72af27c1b2e1ee7d9725e1b5cd60c340c6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/jit_type.h @@ -0,0 +1,2437 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + + +namespace torch::jit { +struct Function; +} // namespace torch::jit + + +namespace c10 { + +template +class Dict; +struct IValue; +struct FunctionSchema; +struct NamedType; +using OptNameList = std::optional>; + +void standardizeVectorForUnion(std::vector& reference, std::vector* to_fill); +void standardizeVectorForUnion(std::vector* to_flatten); + +inline bool is_contiguous_strides( + const IntArrayRef sizes, + const IntArrayRef strides) { + int n_dim = static_cast(sizes.size()); + if (n_dim == 0) { + return true; + } + + if (strides[n_dim - 1] != 1) { + return false; + } + + for (int i = n_dim - 2; i >= 0; i--) { + if (strides[i] != strides[i + 1] * sizes[i + 1]) { + return false; + } + } + return true; +} + +struct AnyType; +using AnyTypePtr = SingletonTypePtr; +// Any is the top of the type hierarchy, all other types are subtypes +// T <: Any, forall T +struct TORCH_API AnyType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() 
const override { + return "Any"; + } + static const TypeKind Kind = TypeKind::AnyType; + // global singleton + static AnyTypePtr get(); + + private: + AnyType() : Type(TypeKind::AnyType) {} +}; + +inline std::string toString(const Type& type) { + return type.str(); +} + +// Shim for compatibility with code that uses TypePtr. +inline std::string toString(const TypePtr& typePtr) { + return toString(*typePtr); +} + +inline bool operator!=(const Type& lhs, const Type& rhs) { + return !(lhs == rhs); +} + +// common base for all types that have a single sub element +// e.g. Future[T], Optional[T], List[T] +template +struct SingleElementType : public SharedType { + static const TypeKind Kind = K; + + const TypePtr& getElementType() const { + return elem; + } + + bool hasFreeVariables() const override { + return getElementType()->hasFreeVariables(); + } + + at::ArrayRef containedTypes() const override { + return elem; + } + + bool equals(const Type& rhs) const override { + if (auto rhs_ = rhs.cast()) { + return *getElementType() == *rhs_->getElementType(); + } + return false; + } + + protected: + SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) { + if (!this->elem) { + throw std::runtime_error(c10::str( + "Can not create ", typeKindToString(Kind), " with None type")); + } + } + + private: + TypePtr elem; +}; + +struct UnionType; +using UnionTypePtr = std::shared_ptr; +struct TORCH_API UnionType : public SharedType { + friend struct Type; + + static const TypeKind Kind = TypeKind::UnionType; + + bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override; + + std::string str() const override; + + static UnionTypePtr create(std::vector reference); + + bool equals(const Type& rhs) const override; + + bool isUnionType() const override { + return true; + } + + at::ArrayRef containedTypes() const override { + return types_; + } + + // For testing purposes only + at::ArrayRef getTypes() const { + return types_; + } + + TypePtr 
createWithContained(std::vector contained_types) const override { + return create(std::move(contained_types)); + } + + bool canHoldType(const Type& type) const; + + bool hasFreeVariables() const override { + return has_free_variables_; + } + + std::optional toOptional() const; + + std::optional subtractTypeSet(std::vector& to_subtract) const; + + protected: + explicit UnionType(std::vector types, TypeKind kind=TypeKind::UnionType); + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override; + std::string unionStr( + const TypePrinter& printer = nullptr, + bool is_annotation_str = false) const; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool has_free_variables_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector types_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool can_hold_none_; + +}; + +struct OptionalType; +using OptionalTypePtr = std::shared_ptr; +// This type represents an optional type. There is one `Optional` for +// each element type. 
`Optional[T]` can accept both `T` and +// `None`(`std::nullopt` in C++) +// Subtype hierarchy for Optional: +// - Optional[T] <: Optional[R] iff T <: R +// - T <: Optional[R] if T <: R +// - None <: Optional[T] for all T +// - Optional[T] == Union[T, None] for all T +struct TORCH_API OptionalType : public UnionType { + static OptionalTypePtr create(const TypePtr& contained); + + static const TypeKind Kind = TypeKind::OptionalType; + + friend struct Type; + + bool equals(const Type& rhs) const override; + + const TypePtr& getElementType() const { + return contained_; + } + + at::ArrayRef containedTypes() const override { + return contained_; + } + + std::string str() const override { + std::stringstream ss; + ss << getElementType()->str() << "?"; + return ss.str(); + } + + TypePtr createWithContained( + std::vector contained_types) const override { + AT_ASSERT(contained_types.size() == 1); + return create(contained_types[0]); + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + bool isUnionType() const override { + return true; + } + + // common cast Optional[Tensor] for undefined tensor type + static TypePtr ofTensor(); + // + // global singleton + static TypePtr get(TypePtr inner); + + private: + explicit OptionalType(const TypePtr& contained); + + TypePtr contained_; + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "Optional[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +template +inline std::optional merge_primitive( + const std::optional& a, + const std::optional& b) { + if (a.has_value() && b.has_value() && a.value() == b.value()) { + return a; + } + return std::optional{}; +} + +// If we see `a + b + c` and know that a, b, and c are the same size and have +// two dimensions (WxH), then we can generate a fused kernel for them. 
That +// fused kernel would likely have indexing math to handling both the W and H +// dimensions. However, if we knew the WxH dimensions were contiguous, we can +// pretend like we only have a single dimension, simplifying the indexing logic. +// This can be performed even if the dimensions are transposed, +// as long as a, b, and c are transposed in the same way. +// We'd like to have the compiler be able to do this dimensionality reduction, +// but simply knowing sizes is not enough. +// We can extend profiling to also record stride information. +// Rather than recording specific strides, +// we can simply order the strides from smallest to largest with +// `stride_indices` A contiguity marker on the smallest stride (c0) indicates +// the stride is precisely 1, otherwise a contiguity marker means that $stride_n +// = size_{n-1}*stride_{n-1}$ +struct TORCH_API Stride { + Stride() = default; + Stride( + const std::optional& stride_index, + std::optional contiguous, + const std::optional& stride) + : stride_index_(stride_index), contiguous_(contiguous), stride_(stride) {} + + bool operator==(const Stride& b) const { + return stride_index_ == b.stride_index_ && contiguous_ == b.contiguous_ && + stride_ == b.stride_; + } + + bool isComplete() const { + return stride_index_ && contiguous_ && stride_; + } + + std::optional stride_index_; + std::optional contiguous_; + std::optional stride_; +}; + +template <> +inline std::optional merge_primitive( + const std::optional& a, + const std::optional& b) { + std::optional left = a; + std::optional right = b; + if (!left.has_value()) { + left = {Stride()}; + } + if (!right.has_value()) { + right = {Stride()}; + } + + auto merged_index = + merge_primitive(left->stride_index_, right->stride_index_); + auto merged_cont = merge_primitive(left->contiguous_, right->contiguous_); + auto merged_stride = merge_primitive(left->stride_, right->stride_); + auto r = Stride(merged_index, merged_cont, merged_stride); + // normalize + if 
(!r.stride_index_.has_value() && !r.contiguous_.has_value() && + !r.stride_.has_value()) { + return std::optional{}; + } + + return r; +} + +struct TORCH_API ShapeSymbol { + // needed for use in `std::map` + ShapeSymbol() : value_(-1) {} + // is this symbol a fixed/static dimension + bool is_static() const { + return value_ >= 0; + } + bool operator==(const ShapeSymbol& b) const { + return value_ == b.value_; + } + bool operator<(const ShapeSymbol& b) const { + return value_ < b.value_; + } + + static ShapeSymbol fromStaticSize(int64_t val) { + return ShapeSymbol(val); + } + int64_t static_size() const { + TORCH_CHECK(is_static()); + return value_; + } + + int64_t value() const { + return value_; + } + + static ShapeSymbol newSymbol() { + return fromStaticSize(-static_cast(++num_symbols)); + } + friend TORCH_API std::ostream& operator<<( + std::ostream& os, + const ShapeSymbol& s); + + private: + ShapeSymbol(int64_t val) : value_(val) {} + int64_t value_; + static std::atomic num_symbols; +}; + +inline ShapeSymbol merge_primitive( + const ShapeSymbol& a, + const ShapeSymbol& b) { + if (a.is_static() && b.is_static() && a == b) { + return a; + } + return ShapeSymbol::newSymbol(); +} + +// Shape of a Tensor represented with ShapeSymbol's. Unranked, ranked unknown +// dims, partially known and fully known shapes are all supported. +struct TORCH_API SymbolicShape { + // Unranked shape constructor. + SymbolicShape() : dims_(std::nullopt) {} + + // Known rank but unknown dimentions. 
+ SymbolicShape(std::optional rank) : dims_(std::nullopt) { + if(!rank) { + return; + } + + std::vector shape_symbols; + shape_symbols.reserve(*rank); + for(size_t i = 0; i < *rank; ++i) { + shape_symbols.push_back(ShapeSymbol::newSymbol()); + } + dims_ = shape_symbols; + } + + // Mix of known and unknown ranks + SymbolicShape(const std::vector>& dims) { + std::vector shape_symbols; + shape_symbols.reserve(dims.size()); + for(std::optional dim: dims) { + if(!dim) { + shape_symbols.push_back(ShapeSymbol::newSymbol()); + } else { + shape_symbols.push_back(ShapeSymbol::fromStaticSize(*dim)); + } + } + dims_ = shape_symbols; + } + + void dump() const; + + SymbolicShape(std::vector dims) : dims_(std::move(dims)) {} + + SymbolicShape(c10::IntArrayRef dims) { + std::vector shape_symbols; + shape_symbols.reserve(dims.size()); + for(int64_t dim : dims) { + shape_symbols.push_back(ShapeSymbol::fromStaticSize(dim)); + } + dims_ = shape_symbols; + } + + ShapeSymbol operator[](size_t i) const { + if (!dims_) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims_).at(i); + } + + ShapeSymbol at(size_t i) const { + if (!dims_) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims_).at(i); + } + + // Returns rank or nullopt in case of unranked shape. + std::optional rank() const { + if(!dims_) { + return std::nullopt; + } + return dims_->size(); + } + + std::optional> sizes() const { + return dims_; + } + + std::optional> symbolicDims() const { + if (!dims_) { + return std::nullopt; + } + auto symbolic_dims = std::vector(); + for (const ShapeSymbol& s : *dims_) { + symbolic_dims.push_back(!s.is_static()); + } + return symbolic_dims; + } + + // Checks whether the shape is fully defined/complete, ie. rank and sizes + // of every dimension are known. 
+ bool isComplete() const { + if(!dims_) { + return false; + } + for(auto d : *dims_) { + if(!d.is_static()) { + return false; + } + } + return true; + } + + // Create new SymbolicShape that is result of merging self and another + // SymbolicShape. Only dimensions that are static and equal will be + // preserved. + // If either of two shapes are of unknown rank or they have unmatching rank, + // result will be unranked. + SymbolicShape merge(const SymbolicShape& other) const; + + friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) { + return lhs.dims_ == rhs.dims_; + } + + friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) { + return !(lhs == rhs); + } + + private: + std::optional> dims_; +}; + +namespace detail { +inline bool isComplete(const Stride& s) { + return s.isComplete(); +} + +template +inline bool isComplete(const T& /*t*/) { + return true; +} +} + +template +struct VaryingShape { + using ListOfOptionalElements = std::vector>; + VaryingShape(const std::vector& vec) + : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {} + + VaryingShape(c10::ArrayRef vec) + : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {} + + VaryingShape(std::optional size = std::nullopt) : dims_(std::nullopt) { + if (size) { + dims_ = ListOfOptionalElements(*size); + } + } + + VaryingShape(ListOfOptionalElements dims) : dims_(std::move(dims)) {} + + VaryingShape(size_t size) : VaryingShape(std::optional(size)) {} + + bool operator==(const VaryingShape& other) const { + return dims_ == other.dims_; + } + + const std::optional &operator[](size_t i) const { + if (!dims_) { + throw std::runtime_error("Rank isn't fixed"); + } + return (*dims_).at(i); + } + + std::optional size() const { + if (!dims_) { + return std::nullopt; + } + const auto& dims = dims_.value(); + return dims.size(); + } + + const std::optional& sizes() const { + return dims_; + } + + TORCH_API VaryingShape merge(const VaryingShape& other) const; 
+ + std::optional> concrete_sizes() const { + if (!dims_) { + return std::nullopt; + } + std::vector sizes; + sizes.reserve(dims_.value().size()); + for (auto d : *dims_) { + if (!d) { + return std::nullopt; + } + sizes.push_back(d.value()); + } + return sizes; + } + + bool isComplete() const { + if (!dims_) { + return false; + } + for (auto d : *dims_) { + if (!d || !detail::isComplete(*d)) { + return false; + } + } + return true; + } + + private: + std::optional dims_; +}; + +struct TensorType; +// TODO: investigate making this SingletonOrSharedTypePtr +using TensorTypePtr = std::shared_ptr; +// This type represents a single Tensor with a specific size +struct TORCH_API TensorType : public SharedType { + static TensorTypePtr create(const at::Tensor& t); + + // used by TensorType::create(size_t dim) which in turn used by + // shape_analysis.cpp + static TensorTypePtr create( + std::optional scalar_type, + std::optional device, + const VaryingShape& sizes, + const VaryingShape& strides, + std::optional requires_grad, + std::optional undefined = false, + bool tensor_contiguity = false); + + static TensorTypePtr create( + std::optional scalar_type, + std::optional device, + SymbolicShape sizes, + VaryingShape stride_, + std::optional requires_grad, + std::optional undefined = false); + + static TensorTypePtr create( + std::optional scalar_type, + std::optional device, + std::optional dim, + std::optional requires_grad); + + // overloaded create variadic template argument as it could not distinguish + // initializer list + static TensorTypePtr createContiguous( + at::ScalarType scalar_type, + at::Device device, + at::IntArrayRef sizes); + + static TypePtr fromNumberType(const Type& typ); + static TypePtr fromBoolType(); + + std::optional dim() const { + return sizes().size(); + } + + VaryingShape sizes() const; + + VaryingShape strides() const; + + const VaryingShape& stride_properties() const { + return strides_; + } + + const std::optional& device() const { + return 
device_; + } + const std::optional& scalarType() const { + return scalar_type_; + } + const std::optional& requiresGrad() const { + return requires_grad_; + } + bool requires_grad() const override { + return requires_grad_ ? *requires_grad_ : true; + } + + bool equals(const Type& rhs) const override; + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + std::string str() const override; + + std::string repr_str() const override { + if (isInferredType()) { + return str() + " (inferred)"; + } else { + return str(); + } + } + + std::optional numel() const { + size_t prod = 1; + const auto& shape = sizes(); + + for (size_t i = 0; i < shape.size(); i++) { + auto const &s = shape[i]; + if (!s.has_value()) { + return std::optional{}; + } + prod *= s.value(); + } + return prod; + } + + TensorTypePtr withRequiresGrad(std::optional s) { + auto copy = clone(); + copy->requires_grad_ = s; + return copy; + } + + TensorTypePtr withScalarType(std::optional st) { + auto copy = clone(); + copy->scalar_type_ = st; + return copy; + } + + TensorTypePtr withDim(std::optional d) { + auto copy = clone(); + // withDim is only used by the legacy executor + // that only cares about the rank, so create dummy symbols)) : + copy->sizes_ = SymbolicShape(d); + copy->strides_ = VaryingShape(d); + return copy; + } + + TensorTypePtr withStrides(VaryingShape sstrides) const { + auto cloned = clone(); + cloned->strides_ = std::move(sstrides); + return cloned; + } + + TensorTypePtr withSizesStrides( + at::IntArrayRef sizes, + at::IntArrayRef strides) const { + auto cloned = clone(); + auto ssizes = SymbolicShape(sizes); + cloned->sizes_ = ssizes; + cloned->strides_ = computeStrideProps(sizes, strides); + return cloned; + } + + TensorTypePtr withSymbolicShapes(SymbolicShape ssizes) const { + auto cloned = clone(); + cloned->sizes_ = std::move(ssizes); + return cloned; + } + + TensorTypePtr withSizes(at::IntArrayRef sizes) const { + return withSizesStrides( + sizes, 
contiguousStridesOf(sizes)); + } + + TensorTypePtr withDevice(const std::optional device) const { + auto copy = clone(); + copy->device_ = device; + return copy; + } + + TensorTypePtr dimensionedOnly() const { + auto copy = clone(); + copy->sizes_ = SymbolicShape(sizes().size()); + copy->strides_ = VaryingShape(sizes().size()); + return copy; + } + + TensorTypePtr contiguous() const { + auto cloned = clone(); + auto concrete_sizes = sizes().concrete_sizes(); + TORCH_INTERNAL_ASSERT(concrete_sizes.has_value()); + auto strides = computeStrideProps( + *concrete_sizes, + contiguousStridesOf(*concrete_sizes)); + cloned->strides_ = strides; + return cloned; + } + + const SymbolicShape& symbolic_sizes() const; + + TensorTypePtr merge(const TensorType& other, bool merge_sizes = true) const; + + bool matchTensor(const at::Tensor& t); + + // is all information about the type specified except for autograd? + // This replaces the notion of a 'CompleteTensorType' that used to exist + // in the type-hierarchy. Excluding require_grad and undefined allows + // this to match the old behavior. 
+ bool isComplete() const { + return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete(); + } + + bool isInferredType() const { + return is_inferred_; + } + + static TensorTypePtr getInferred() { + static auto valueInferred = TensorType::create( + /*scalar_type=*/{}, + /*device=*/{}, + /*sizes=*/SymbolicShape(), + /*stride=*/VaryingShape{}, + /*requires_grad=*/{}, + /*undefined=*/false); + valueInferred->is_inferred_ = true; + return valueInferred; + } + + // this property is used by GuardElimination + // please see `checkInputs` for more details + bool isSummarized() const { + return !(isComplete() && requiresGrad().has_value() && + undefined().has_value()); + } + + TensorTypePtr withUndefined() { + auto r = clone(); + r->undefined_ = true; + return r; + } + + TensorTypePtr withPossiblyUndefined() { + auto r = clone(); + r->undefined_ = std::nullopt; + return r; + } + + std::optional undefined() const { return undefined_; } + + static const TensorTypePtr& get(); + + static const TypeKind Kind = TypeKind::TensorType; + + static std::vector contiguousStridesOf( + at::IntArrayRef in_sizes, + at::MemoryFormat memory_format = MemoryFormat::Contiguous) { + auto contiguous_fn = [](const at::IntArrayRef& sizes, + const std::vector& dim_order) { + std::vector strides(sizes.size()); + if (sizes.empty()) // zero-dim case + return strides; + + strides[dim_order[0]] = 1; + for (size_t i = 1; i < dim_order.size(); i++) { + auto cur_dim = dim_order[i]; + auto pre_dim = dim_order[i - 1]; + strides[cur_dim] = strides[pre_dim] * sizes[pre_dim]; + } + return strides; + }; + + std::vector dim_order(in_sizes.size()); + if (memory_format == MemoryFormat::ChannelsLast) { + dim_order = {1, 3, 2, 0}; + } else if (memory_format == MemoryFormat::ChannelsLast3d) { + dim_order = {1, 4, 3, 2, 0}; + } else { + auto ndims = in_sizes.size(); + for (size_t i = 0; i < ndims; i++) { + dim_order[i] = static_cast(ndims - i - 1); // Reverse + } + } + return contiguous_fn(in_sizes, 
dim_order); + } + + private: + TensorType( + std::optional scalar_type, + std::optional device, + SymbolicShape sizes, + VaryingShape strides, + std::optional requires_grad, + std::optional undefined = false); + + TensorTypePtr clone() const { + return TensorTypePtr(new TensorType( + scalar_type_, device_, sizes_, strides_, requires_grad_, undefined_)); + } + + static VaryingShape computeStrideProps( + at::IntArrayRef sizes, + at::IntArrayRef strides, + bool tensor_contiguity = false); + + std::optional scalar_type_; + std::optional device_; + SymbolicShape sizes_; + VaryingShape strides_; + std::optional requires_grad_; + // we exploit the fact certain tensors must be zero in the autograd to + // optimize gradient computation. Such zero tensors are currently implemented + // with `UndefinedTensorImpl.` They can be handled only by special operators + // (e.g. `AutogradAdd`) and their `Tensor::defined()` property returns false. + // Normally, `undefined_` is set to false, unless a type was created + // with `withUndefined` + // This will also mean that `undefined` tensors will fail + // `subtypeOf(TensorType::get())` check + // undefined_ may become `std::nullopt` if the tensor was observed to be both + // defined and undefined. However, no tensor type starts out with + // `undefined_` set to `std::nullopt` + std::optional undefined_; + // Represents whether or not this type was inferred. + bool is_inferred_ = false; +}; + +struct ListType; +using ListTypePtr = std::shared_ptr; +struct TORCH_API ListType + : public SingleElementType { + // It's not exactly a singleton, but there should be exactly one instance of + // List[T] for every T + friend struct Type; + template + static ListTypePtr create(T&&... 
all) { + return ListTypePtr( + new ListType(std::forward(all)...)); // NOLINT(modernize-make-shared) + } + + std::string str() const override { + std::stringstream ss; + ss << getElementType()->str() << "[]"; + return ss.str(); + } + TypePtr createWithContained( + std::vector contained_types) const override { + return create(std::move(contained_types.at(0))); + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + // global singleton + // Given an inner type T and an identifier, + // this function wil return the global singleton type pointer + // the type List. + // The extra "identifier" argument is needed beccause we have multiple container types + // that all re-use this function (List, array, etc.) + static TypePtr get(const std::string& identifier, TypePtr inner); + + // common cast List[Tensor] + static ListTypePtr ofTensors(); + static ListTypePtr ofOptionalTensors(); + static ListTypePtr ofInts(); + static ListTypePtr ofSymInts(); + static ListTypePtr ofFloats(); + static ListTypePtr ofComplexDoubles(); + static ListTypePtr ofBools(); + static ListTypePtr ofStrings(); + static ListTypePtr ofNumbers(); + + private: + ListType(TypePtr elem) : SingleElementType(std::move(elem)) {} + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "List[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +struct DictType; +using DictTypePtr = std::shared_ptr; +struct TORCH_API DictType : public SharedType { + friend struct Type; + static const TypeKind Kind = TypeKind::DictType; + + static DictTypePtr create(TypePtr key, TypePtr value) { + auto kind = key->kind(); + if (auto dyn = key->castRaw()) { + kind = dyn->dynamicKind(); + } + switch (kind) { + case TypeKind::AnyType: + case TypeKind::IntType: + case TypeKind::BoolType: + case TypeKind::FloatType: + case TypeKind::ComplexType: + case TypeKind::StringType: + case TypeKind::TensorType: + 
case TypeKind::DeviceObjType: + return DictTypePtr(new DictType(std::move(key), std::move(value))); + default: + TORCH_CHECK(false, + "Cannot create dict for key type '", + key->str(), + "', only int, float, complex, Tensor, device and string keys are supported"); + } + } + + // aligned with the format in FunctionSchema + std::string str() const override { + std::stringstream ss; + ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str() + << ")"; + return ss.str(); + } + + TypePtr createWithContained( + std::vector contained_types) const override { + if (contained_types.size() != 2) { + throw std::runtime_error("Expected 2 contained types"); + } + return create(std::move(contained_types.at(0)), std::move(contained_types.at(1))); + } + + const TypePtr& getKeyType() const { + return types.at(0); + } + + const TypePtr& getValueType() const { + return types.at(1); + } + + bool hasFreeVariables() const override { + return has_free_variables; + } + + at::ArrayRef containedTypes() const override { + return types; + } + + bool equals(const Type& rhs) const override { + if (auto* dict_rhs = rhs.castRaw()) { + return *getKeyType() == *(dict_rhs->getKeyType()) && + *getValueType() == *(dict_rhs->getValueType()); + } + return false; + } + + // global singleton + // Given an inner type T and an identifier, + // this function will return the global singleton type pointer + // the type List. 
+ // The extra "identifier" argument is needed because we have multiple container types + // that all re-use this function (Dict and unordered_map) + static TypePtr get(const std::string& identifier, TypePtr key, TypePtr val); + + private: + DictType(TypePtr key, TypePtr value) + : SharedType(TypeKind::DictType), + has_free_variables( + key->hasFreeVariables() || value->hasFreeVariables()) { + types.reserve(2); + types.push_back(std::move(key)); + types.push_back(std::move(value)); + } + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override; + + std::vector types; + bool has_free_variables; +}; + +struct FutureType; +using FutureTypePtr = std::shared_ptr; + +struct TORCH_API FutureType + : public SingleElementType { + friend struct Type; + template + static FutureTypePtr create(TypePtr elem) { + return FutureTypePtr( + new FutureType(std::move(elem))); // NOLINT(modernize-make-shared) + } + + std::string str() const override { + std::stringstream ss; + ss << "Future(" << getElementType()->str() << ")"; + return ss.str(); + } + TypePtr createWithContained( + std::vector contained_types) const override { + return create(std::move(contained_types.at(0))); + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + if (Type::isSubtypeOfExt(rhs, why_not)) { + return true; + } + if (auto rhs_ = rhs.castRaw()) { + return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not); + } + return false; + } + + private: + FutureType(TypePtr elem) : SingleElementType(std::move(elem)) {} + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "Future[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +struct AwaitType; +using AwaitTypePtr = std::shared_ptr; + +struct TORCH_API AwaitType + : public SingleElementType { + friend struct Type; + template + static AwaitTypePtr create(TypePtr elem) { + return 
AwaitTypePtr( + new AwaitType(std::move(elem))); // NOLINT(modernize-make-shared) + } + + std::string str() const override { + std::stringstream ss; + ss << "Await(" << getElementType()->str() << ")"; + return ss.str(); + } + TypePtr createWithContained( + std::vector contained_types) const override { + return create(std::move(contained_types.at(0))); + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + if (Type::isSubtypeOfExt(rhs, why_not)) { + return true; + } + if (auto rhs_ = rhs.castRaw()) { + return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not); + } + return false; + } + + private: + AwaitType(TypePtr elem) : SingleElementType(std::move(elem)) {} + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "Await[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +struct RRefType; +using RRefTypePtr = std::shared_ptr; + +struct TORCH_API RRefType + : public SingleElementType { + friend struct Type; + template + static RRefTypePtr create(TypePtr elem) { + return RRefTypePtr( + new RRefType(std::move(elem))); // NOLINT(modernize-make-shared) + } + + std::string str() const override { + std::stringstream ss; + ss << "RRef(" << getElementType()->str() << ")"; + return ss.str(); + } + TypePtr createWithContained( + std::vector contained_types) const override { + return create(std::move(contained_types.at(0))); + } + + private: + RRefType(TypePtr elem) : SingleElementType(std::move(elem)) {} + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override { + std::stringstream ss; + ss << "RRef[" << getElementType()->annotation_str(printer) << "]"; + return ss.str(); + } +}; + +// Any should never appear in a named type like a class, namedtuple or +// interface. 
If it does, then dynamic type information will be lost in the +// Pickler, leading to hard-to-track-down bugs that will only occur +// after saving or loading a model. This is because we rely on the +// static types in named types to reconstruct type tags of loaded +// values. Lifting this restriction requires solving the serialization +// problem first. +TORCH_API void checkNoAny( + const Type& base, + const char* what, + const std::string& attrname, + const TypePtr& attrtype); + +struct TupleType; +using TupleTypePtr = std::shared_ptr; +using NameList = std::vector; +// This type represents a Tuple +struct TORCH_API TupleType : public NamedType { + + static TupleTypePtr createNamed(const std::optional& name, + const std::vector& field_names, + const std::vector& field_types, + std::vector& field_defaults); + + static TupleTypePtr createNamed(const std::optional& name, + const std::vector& field_names, + const std::vector& field_types); + + static TupleTypePtr createNamed(const std::optional& name, + const std::vector& field_names, + const std::vector& field_types); + + static TupleTypePtr create( + std::vector types) { + return TupleTypePtr(new TupleType( + std::move(types), + std::nullopt, + nullptr)); // NOLINT(modernize-make-shared) + } + static TupleTypePtr create() { + return create({}); + } + + at::ArrayRef elements() const { + return elements_; + } + + bool equals(const Type& rhs) const override; + bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override; + + std::string str() const override; + bool hasFreeVariables() const override { + return has_free_variables_; + } + at::ArrayRef containedTypes() const override { + return elements_; + } + TypePtr createWithContained( + std::vector contained_types) const override { + return std::shared_ptr( + new TupleType(std::move(contained_types), name(), schema())); + } + const std::shared_ptr& schema() const { + return schema_; + } + std::optional> names() const; + + static const TypeKind Kind = 
TypeKind::TupleType; + + private: + template + static TupleTypePtr createWithSpec( + const std::optional& name, + const std::vector& field_names, + const std::vector& field_types, + std::vector& field_defaults); + + TupleType( + std::vector elements_, + std::optional name, + std::shared_ptr schema); + + bool compare( + const Type& rhs, + const std::function& fn) const { + if (rhs.kind() != kind()) { + return false; + } + + const auto& l_elements = elements(); + const auto& r_elements = rhs.castRaw()->elements(); + if (l_elements.size() != r_elements.size()) + return false; + for (size_t i = 0; i < l_elements.size(); ++i) { + if (!fn(*l_elements[i], *r_elements[i])) + return false; + } + return true; + } + + std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override; + + std::vector elements_; + bool has_free_variables_; + std::shared_ptr schema_; +}; + +// the common supertype of all Enums, only used in operator registraion. +// EnumType <: AnyEnumType for all Enums +struct AnyEnumType; +using AnyEnumTypePtr = SingletonTypePtr; +struct TORCH_API AnyEnumType final : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "AnyEnumType"; + } + static const TypeKind Kind = TypeKind::AnyEnumType; + // global singleton + static AnyEnumTypePtr get(); +private: + AnyEnumType() + : Type(TypeKind::AnyEnumType) {} +}; + +struct NumberType; +using NumberTypePtr = SingletonTypePtr; +// This type represents a Python number +// Subtype hierarchy for Number Types (NumberType as the base type): +// IntType <: NumberType +// FloatType <: NumberType +// ComplexType <:NumberType +// +// WARNING: if you add a new subtype of NumberType that is not +// represented by a global singleton, you need to change NumberTypePtr +// to a SingletonOrSharedTypePtr and deal with NumberType needing to +// both inherit and not inherit from SharedType! 
+struct TORCH_API NumberType : public Type { + bool equals(const Type& rhs) const override; + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + std::string str() const override { + return "Scalar"; // match what PythonArgParser says for clarity + } + static const TypeKind Kind = TypeKind::NumberType; + // global singleton + static NumberTypePtr get(); + + protected: + NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {} + + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "number"; // technically not a valid python type, but + // we need to use it when parsing back in annotations + // for implicit conversions + } +}; + +struct FloatType; +using FloatTypePtr = SingletonTypePtr; +// This type represents a Python float number +struct TORCH_API FloatType : public NumberType { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "float"; + } + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + // NOLINTNEXTLINE(bugprone-parent-virtual-call) + return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); + } + static const TypeKind Kind = TypeKind::FloatType; + // global singleton + static FloatTypePtr get(); + + private: + FloatType() : NumberType(TypeKind::FloatType) {} + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "float"; + } +}; + +struct ComplexType; +using ComplexTypePtr = SingletonTypePtr; +// This type represents a Python float number +struct TORCH_API ComplexType : public NumberType { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "complex"; + } + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + // NOLINTNEXTLINE(bugprone-parent-virtual-call) + return 
rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); + } + static const TypeKind Kind = TypeKind::ComplexType; + // global singleton + static ComplexTypePtr get(); + + private: + ComplexType() : NumberType(TypeKind::ComplexType) {} + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "complex"; + } +}; + +// We need to introduce `SymIntType` to represent the `SymInt` type +// used in function schemas e.g. `aten::narrow_copy(... SymInt length) +// `SymInt` will be used to enable tracing arithmetic operations on +// dimension values. Please see [SymInt.h] for more information +struct SymIntType; +using SymIntTypePtr = SingletonTypePtr; +struct TORCH_API SymIntType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "SymInt"; + } + std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override { + return "int"; + } + static const TypeKind Kind = TypeKind::SymIntType; + // global singleton + static SymIntTypePtr get(); + + private: + SymIntType() : Type(TypeKind::SymIntType) {} +}; + +struct SymFloatType; +using SymFloatTypePtr = SingletonTypePtr; +struct TORCH_API SymFloatType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "SymFloat"; + } + std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override { + return "float"; + } + static const TypeKind Kind = TypeKind::SymFloatType; + // global singleton + static SymFloatTypePtr get(); + + private: + SymFloatType() : Type(TypeKind::SymFloatType) {} +}; + +struct SymBoolType; +using SymBoolTypePtr = SingletonTypePtr; +struct TORCH_API SymBoolType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const 
override { + return "SymBool"; + } + std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override { + return "bool"; + } + static const TypeKind Kind = TypeKind::SymBoolType; + // global singleton + static SymBoolTypePtr get(); + + private: + SymBoolType() : Type(TypeKind::SymBoolType) {} +}; + +struct IntType; +using IntTypePtr = SingletonTypePtr; +// This type represents a Python int number +struct TORCH_API IntType : public NumberType { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "int"; + } + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { + // NOLINTNEXTLINE(bugprone-parent-virtual-call) + return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); + } + static const TypeKind Kind = TypeKind::IntType; + // global singleton + static IntTypePtr get(); + + private: + IntType() : NumberType(TypeKind::IntType) {} + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "int"; + } +}; + +struct BoolType; +using BoolTypePtr = SingletonTypePtr; +// This node represents a Python bool value +struct TORCH_API BoolType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "bool"; + } + static const TypeKind Kind = TypeKind::BoolType; + // global singleton + static BoolTypePtr get(); + + private: + BoolType() : Type(TypeKind::BoolType) {} +}; + +struct StringType; +using StringTypePtr = SingletonTypePtr; +// This type represents a Python string +struct TORCH_API StringType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + // we only use "str" (not "string") in both FunctionSchema and script + return annotation_str(); + } + std::string annotation_str_impl( + 
[[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "str"; + } + static const TypeKind Kind = TypeKind::StringType; + // global singleton + static StringTypePtr get(); + + private: + StringType() : Type(TypeKind::StringType) {} +}; + +struct StorageType; +using StorageTypePtr = SingletonTypePtr; +struct TORCH_API StorageType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return annotation_str(); + } + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + return "Storage"; + } + static const TypeKind Kind = TypeKind::StorageType; + // global singleton + static StorageTypePtr get(); + + private: + StorageType() : Type(TypeKind::StorageType) {} +}; + +struct FunctionType; +using FunctionTypePtr = std::shared_ptr; +struct TORCH_API FunctionType : public NamedType { + static FunctionTypePtr create(torch::jit::Function* function) { + return FunctionTypePtr( + new FunctionType(function)); // NOLINT(modernize-make-shared) + } + bool equals(const Type& rhs) const override { + if (auto func_type = rhs.cast()) { + return func_type->function_ == function_; + } + + return false; + } + std::string str() const override { + return "Function"; + } + torch::jit::Function* function() const { + return function_; + } + static const TypeKind Kind = TypeKind::FunctionType; + + private: + FunctionType(torch::jit::Function* function); + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return name()->qualifiedName(); + } + torch::jit::Function* function_; +}; + +struct NoneType; +using NoneTypePtr = SingletonTypePtr; +// This type represents a Python None +struct TORCH_API NoneType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const 
override { + return "NoneType"; + } + bool isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const override; + + static const TypeKind Kind = TypeKind::NoneType; + // global singleton + static NoneTypePtr get(); + + private: + NoneType() : Type(TypeKind::NoneType) {} +}; + +struct GeneratorType; +using GeneratorTypePtr = SingletonTypePtr; +// This type represents a Generator +struct TORCH_API GeneratorType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Generator"; + } + static const TypeKind Kind = TypeKind::GeneratorType; + // global singleton + static GeneratorTypePtr get(); + + private: + GeneratorType() : Type(TypeKind::GeneratorType) {} +}; + +struct QuantizerType; +using QuantizerTypePtr = SingletonTypePtr; +// This type represents a Quantizer +struct TORCH_API QuantizerType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Quantizer"; + } + static const TypeKind Kind = TypeKind::QuantizerType; + // global singleton + static QuantizerTypePtr get(); + + private: + QuantizerType() : Type(TypeKind::QuantizerType) {} +}; + +struct QSchemeType; +using QSchemeTypePtr = SingletonTypePtr; +// This type represents a QScheme +struct TORCH_API QSchemeType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "QScheme"; + } + static const TypeKind Kind = TypeKind::QSchemeType; + // global singleton + static QSchemeTypePtr get(); + + private: + QSchemeType() : Type(TypeKind::QSchemeType) {} +}; + +struct DeviceObjType; +using DeviceObjTypePtr = SingletonTypePtr; +// This type represents a Device +struct TORCH_API DeviceObjType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Device"; + } + 
static const TypeKind Kind = TypeKind::DeviceObjType; + // global singleton + static DeviceObjTypePtr get(); + + private: + DeviceObjType() : Type(TypeKind::DeviceObjType) {} +}; + +struct StreamObjType; +using StreamObjTypePtr = SingletonTypePtr; +// This type represents a Generator +struct TORCH_API StreamObjType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Stream"; + } + static const TypeKind Kind = TypeKind::StreamObjType; + // global singleton + static StreamObjTypePtr get(); + +private: + StreamObjType() : Type(TypeKind::StreamObjType) {} +}; + +struct VarType; +using VarTypePtr = std::shared_ptr; +// This type represents a type variable, used in FunctionSchema +struct VarType : public SharedType { + static VarTypePtr create(std::string name_) { + return VarTypePtr(new VarType(std::move(name_))); + } + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return name(); + } + const std::string& name() const { + return name_; + } + bool hasFreeVariables() const override { + return true; + } + static const TypeKind Kind = TypeKind::VarType; + + private: + VarType(std::string name_) + : SharedType(TypeKind::VarType), name_(std::move(name_)) {} + std::string name_; +}; + +struct CapsuleType; +using CapsuleTypePtr = SingletonTypePtr; +// This type represents a Python Capsule. 
+// It does not appear in the IR and is only used during runtime +struct TORCH_API CapsuleType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Capsule"; + } + static const TypeKind Kind = TypeKind::CapsuleType; + // global singleton + static CapsuleTypePtr get(); +private: + CapsuleType() + : Type(TypeKind::CapsuleType) {} +}; + +struct PyObjectType; +using PyObjectTypePtr = SingletonTypePtr; +// This type represents a PyObject Type +struct TORCH_API PyObjectType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "PyObject"; + } + static const TypeKind Kind = TypeKind::PyObjectType; + // global singleton + static PyObjectTypePtr get(); +private: + PyObjectType() + : Type(TypeKind::PyObjectType) {} +}; + +enum class TypeVerbosity { + None, + Type, + TypeAndStride, + Full, + Symbolic, + Default = Full, +}; + +TORCH_API TypeVerbosity type_verbosity(); + +TORCH_API std::ostream& operator<<(std::ostream& out, const Type& t); +template +TORCH_API std::ostream& operator<<( + std::ostream& out, + const VaryingShape& t); +TORCH_API std::ostream& operator<<(std::ostream& os, const SymbolicShape& s); +TORCH_API std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s); +TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s); +// what is the type, ignoring extra size/shape information? +// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) + +// `unshapedType` is used to remove Tensor subtypes. We treat all Tensor +// subtypes as simply "Tensor"; we also create a new version of any +// container types in which internal Tensors have undergone the same +// operation. This is used for type comparisons between two Tensor types +// (`unshapedType` means that we don't falsely return `false` for e.g. +// Tensors of different dimensions). 
It's also used in the alias +// analysis pass. +// Be careful with calls because this can be very slow. If calling this +// on a graph, use `EraseShapeInformation` in shape_analysis.h +inline TypePtr unshapedType(const TypePtr& type) { + if (type->isSubtypeOf(*TensorType::get())) { + return TensorType::get(); + } + at::ArrayRef contained = type->containedTypes(); + if (contained.empty()) { + return type; + } + return type->withContained(fmap(type->containedTypes(), unshapedType)); +} + +inline TypePtr TensorType::fromNumberType(const Type& typ) { + if (typ.isSubtypeOf(*IntType::get())) { + return TensorType::createContiguous(at::kLong, at::kCPU, {}); + } else if (typ.isSubtypeOf(*FloatType::get())) { + return TensorType::createContiguous(at::kDouble, at::kCPU, {}); + } else if (typ.isSubtypeOf(*BoolType::get())) { + return TensorType::createContiguous(at::kBool, at::kCPU, {}); + } else if (typ.kind() == NumberType::Kind) { + return TensorType::create(std::nullopt, at::kCPU, {}, std::nullopt); + } + TORCH_CHECK(false, "Unknown number type: ", typ.str()); +} +inline TypePtr TensorType::fromBoolType() { + return TensorType::createContiguous(at::kBool, at::kCPU, {}); +} + +inline std::optional tryScalarTypeFromJitType(const Type& type) { + if (type == *FloatType::get()) { + return at::typeMetaToScalarType(c10::get_default_dtype()); + } else if (type == *IntType::get()) { + return at::ScalarType::Long; + } else if (type == *BoolType::get()) { + return at::ScalarType::Bool; + } + return std::nullopt; +} + +inline at::ScalarType scalarTypeFromJitType(const Type& type) { + auto result = tryScalarTypeFromJitType(type); + TORCH_CHECK( + result, + "Add new condition, expected Float, Complex, Int, or Bool but got", + type.str()); + return *result; +} + +// Attempt to find the correct supertype of the two types `t1` and `t2`. +// If no supertype is found, then nullopt will be returned if +// `default_to_union` is false, and `Union[t1, t2]` will be returned +// if it is true. 
If `t1 == t2`, or `t1` is a type refinement of `t2`, +// then `t2` will be returned (and vice versa). +// +// Two different tensortypes will return dynamic. +// +// Currently we chose not to support returning a NumberType for +// two types from the set of {FloatType, IntType, ComplexType}, because +// there is a lack of operator support for NumberType. +// +// If `type_hint` is an `InterfaceType`, then we can use that as a +// potential supertype for `ClassType`s in the list. Otherwise, we have +// no way to find and use some common interface type +TORCH_API std::optional unifyTypes( + const TypePtr& t1, + const TypePtr& t2, + bool default_to_union = false, + const TypePtr& type_hint = nullptr); + +TORCH_API std::optional unifyTypeList( + at::ArrayRef elements, + std::ostream& why_not, + bool default_to_union = false, + const TypePtr& type_hint = nullptr); + +namespace detail { +template +struct getTypePtr_ final { + static decltype(auto) call() { + return ([]() { + try { + return getCustomClassType(); + } catch(const c10::Error&) { + TORCH_CHECK( + false, + "Type ", + c10::util::get_fully_qualified_type_name(), + " could not be converted to any of the known types." 
+ ); + } + }()); + } +}; + +template +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return getTypePtr_::call(); + } +}; + +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return AnyType::get(); + } +}; + +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return TensorType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StorageType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StreamObjType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return FloatType::get(); + } +}; +template <> +struct getTypePtr_> final { + static decltype(auto) call() { + return ComplexType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; + +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; + +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return SymIntType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; + +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return SymFloatType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return FloatType::get(); + } +}; + +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return SymBoolType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return BoolType::get(); + } +}; + +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return DeviceObjType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return BoolType::get(); + } +}; +template <> +struct getTypePtr_ 
final { + static decltype(auto) call() { + return NumberType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return QSchemeType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return TypeFactory::create( + TypeFactory::get()); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StringType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StringType::get(); + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return StringType::get(); + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per vector" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = ListType::get("vector", inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per ArrayRef" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = ListType::get("ArrayRef", inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_ final { + static const auto& call() { + static auto type = ListType::create(getMaybeFakeTypePtr_::call()); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per List" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. 
+ static auto type = ListType::get("List", inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + static auto type = ListType::get("List", inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per array" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + // (Concatenating the length onto the end of the string because we want a unique + // type_ptr created for every std::array type). + static auto type = ListType::get(std::string("array") + std::to_string(N), inner_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_key_type = getMaybeFakeTypePtr_::call(); + static auto inner_val_type = getMaybeFakeTypePtr_::call(); + // The "per unordered_map" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = DictType::get("unordered_map", inner_key_type, inner_val_type); + return type; + } +}; +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_key_type = getMaybeFakeTypePtr_::call(); + static auto inner_val_type = getMaybeFakeTypePtr_::call(); + // The "per Dict" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. 
+ static auto type = DictType::get("Dict", inner_key_type, inner_val_type); + return type; + } +}; + +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per std::optional" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = OptionalType::get(inner_type); + return type; + } +}; + + +template<> +struct getTypePtr_ final { + static const auto& call() { + static auto inner_type = getMaybeFakeTypePtr_::call(); + // The "per std::optional" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto type = OptionalType::get(inner_type); + return type; + } +}; + +template +struct getMaybeFakeTypePtr_ final { + static const auto& call() { + // The "per std::optional" static singleton needs to live in a .cpp file, + // otherwise we'll end up with one singleton instance per shared library. + static auto inner_type = getMaybeFakeTypePtr_::call(); + static auto type = OptionalType::get(inner_type); + return type; + } +}; + +template +struct getMaybeFakeTypePtr_, fake> final { + static const auto& call() { + static auto type = ([]() { + std::vector contained_types = { + (getMaybeFakeTypePtr_::call())... 
+ }; + return TupleType::create(std::move(contained_types)); + })(); + return type; + } +}; +template <> +struct getTypePtr_ final { + static decltype(auto) call() { + return NoneType::get(); + } +}; +} // namespace detail +template +inline decltype(auto) getTypePtr() { + // TODO: static_assert that a templated function exists, and throw a friendly + // error message if not + return detail::getMaybeFakeTypePtr_::call(); +} + +template +inline TypePtr getTypePtrCopy() { + // TODO: static_assert that a templated function exists, and throw a friendly + // error message if not + return getTypePtr(); +} + +template +inline decltype(auto) getFakeTypePtr() { + return detail::getMaybeFakeTypePtr_::call(); +} + +template +inline TypePtr getFakeTypePtrCopy() { + return getFakeTypePtr(); +} + +using TypeEnv = std::unordered_map; +struct MatchTypeReturn { + MatchTypeReturn(std::string reason) : reason_(std::move(reason)) {} + static MatchTypeReturn Success() { + return MatchTypeReturn(); + } + bool success() const { + return !reason_.has_value(); + } + const std::string& reason() const { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return reason_.value(); + } + + private: + MatchTypeReturn() + : reason_(std::nullopt) {} + std::optional reason_; // is there is no match, this contains the reason +}; + +// attempt to match the type variables in formal to actual, adding them to type_env. +// If no match is possible this returns a MatchTypeReturn with r.success() == false +// and a r.reason() that describes why it could not match. +// note: It is possible to successfully match a formal, but for type variables +// in the formal to still not be defined. In particular, None matches Optional[T] +// but does not define the value of T. +TORCH_API MatchTypeReturn +matchTypeVariables(const TypePtr& formal, const TypePtr& actual, TypeEnv& type_env); + +// replace type variables appearing in `type` with the values in +// `type_env`. 
Returns nullptr if a variable used in `type` +// does not appear in `type_env` +TORCH_API TypePtr tryEvalTypeVariables(const TypePtr& type, TypeEnv& type_env); + +TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type); + +struct InterfaceType; +using InterfaceTypePtr = std::shared_ptr; + +// Interfaces are a list of abstract methods that a class might meet. +// If a class provides those methods, it implicitly meets the interface. + +// Subtype relations for Interface with ClassType: +// lhs (ClassType or InterfaceType) is a subtype of rhs if: +// 1. lhs methods are a superset of rhs methods +// 2. if rhs is module interface, the lhs must be module interface or module itself +struct TORCH_API InterfaceType : public NamedType { + static InterfaceTypePtr create( + QualifiedName qualifiedName, bool is_module=false); + + bool equals(const Type& rhs) const override { + if (auto user_rhs = rhs.castRaw()) { + return isSubTypeImpl(*this, *user_rhs, nullptr) && + isSubTypeImpl(*user_rhs, *this, nullptr); + } + return false; + } + + std::string str() const override { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return std::string("InterfaceType<") + name()->name() + ">"; + } + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + // try to find a method of this interface, + // returns nullptr if not found. 
+ const FunctionSchema* getMethod(const std::string& name) const; + void addMethod(FunctionSchema schema); + const std::vector& methods() const { + return *methods_; + } + + bool is_module() const override{ + return is_module_; + } + static const TypeKind Kind = TypeKind::InterfaceType; + ~InterfaceType() override = default; + private: + InterfaceType(QualifiedName name, bool is_module); + static bool isSubTypeImpl( + const InterfaceType& lhs, + const InterfaceType& rhs, + std::ostream* why_not); + + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return name()->qualifiedName(); + } + + // shared_ptr so that this header does not have to depend on + // FunctionSchema.h + std::shared_ptr> methods_; + // flag to distinguish if it's an interface type from a module or not + bool is_module_; +}; + +template +struct EnumerationType : public Type { +static const TypeKind Kind = K; + +bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); +} + +protected: +EnumerationType() : Type(Kind) {} +}; + +// WARNING: These enumeration types below DO NOT actually get parsed out +// from the logical schema strings, instead they are mapped as ints. 
To +// observe these types, use real_type() instead of type() on Argument + +struct ScalarTypeType; +using ScalarTypeTypePtr = SingletonTypePtr; +struct TORCH_API ScalarTypeType : public EnumerationType { +std::string str() const override { +return "ScalarType"; +} +static const TypeKind Kind = TypeKind::ScalarTypeType; +// global singleton +static ScalarTypeTypePtr get(); + +private: +ScalarTypeType() {} +}; + +struct MemoryFormatType; +using MemoryFormatTypePtr = SingletonTypePtr; +struct TORCH_API MemoryFormatType : public EnumerationType { +std::string str() const override { +return "MemoryFormat"; +} +static const TypeKind Kind = TypeKind::MemoryFormatType; +// global singleton +static MemoryFormatTypePtr get(); + +private: +MemoryFormatType() {} +}; + +struct LayoutType; +using LayoutTypePtr = SingletonTypePtr; +struct TORCH_API LayoutType : public EnumerationType { +std::string str() const override { +return "Layout"; +} +static const TypeKind Kind = TypeKind::LayoutType; +// global singleton +static LayoutTypePtr get(); + +private: +LayoutType() {} +}; + +namespace detail { +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return ScalarTypeType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return LayoutType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return MemoryFormatType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; +} // namespace detail + +// the common supertype of all lists, +// List[T] <: AnyList for all T +struct AnyListType; +using AnyListTypePtr = SingletonTypePtr; +struct TORCH_API 
AnyListType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "list"; + } + static const TypeKind Kind = TypeKind::AnyListType; + // global singleton + static AnyListTypePtr get(); +private: + AnyListType() + : Type(TypeKind::AnyListType) {} +}; + +// the common supertype of all tuples, +// Tuple[T...] <: AnyTuple for all T +struct AnyTupleType; +using AnyTupleTypePtr = SingletonTypePtr; +struct TORCH_API AnyTupleType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + + std::string str() const override { + return "tuple"; + } + static const TypeKind Kind = TypeKind::AnyTupleType; + + // global singleton + static AnyTupleTypePtr get(); +private: + AnyTupleType() + : Type(TypeKind::AnyTupleType) {} +}; + +// the common supertype of all classes, +// ClassType <: AnyClassType for all classes +struct AnyClassType; +using AnyClassTypePtr = SingletonTypePtr; +struct TORCH_API AnyClassType : public Type { + bool equals(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "AnyClassType"; + } + static const TypeKind Kind = TypeKind::AnyClassType; + // global singleton + static AnyClassTypePtr get(); +private: + AnyClassType() + : Type(TypeKind::AnyClassType) {} +}; + +template<> +inline typename detail::CastReturnType::type Type::cast() { + if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || + kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { + return std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return nullptr; +} + +template<> +inline typename detail::CastConstReturnType::type Type::cast() const { + if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || + kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { + return 
std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return nullptr; +} + +template<> +inline const NamedType* Type::castRaw() const { + if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || + kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { + return static_cast(this); + } + return nullptr; +} + +// Used as a return type when inferring the IValue type of a Python object. +struct InferredType { + /* implicit */ InferredType(TypePtr type) : type_(std::move(type)) {} + /* implicit */ InferredType(std::string reason) + : type_(nullptr), reason_(std::move(reason)) {} + TypePtr type() const { + TORCH_INTERNAL_ASSERT( + type_, + "Tried to get the type from an InferredType but the type is null. ", + "Reason: ", + reason_); + return type_; + } + bool success() const { + return type_ != nullptr; + } + const std::string& reason() const { + TORCH_INTERNAL_ASSERT(!type_); + return reason_; + } + +private: + TypePtr type_; + std::string reason_; +}; + +TORCH_API bool containsAnyType(const TypePtr& type); + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/jit_type_base.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/jit_type_base.h new file mode 100644 index 0000000000000000000000000000000000000000..de440787ee686ff274721fb74d1192b0a014cc17 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/jit_type_base.h @@ -0,0 +1,721 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +#define C10_FORALL_TYPES(_) \ + _(AnyType) \ + _(EnumType) \ + _(AnyEnumType) \ + _(TensorType) \ + _(StorageType) \ + _(TupleType) \ + _(ListType) \ + _(DictType) \ + _(NumberType) \ + _(FloatType) \ + _(ComplexType) \ + _(FutureType) \ + _(AwaitType) \ + _(RRefType) \ + _(IntType) \ + _(NoneType) \ + _(StringType) \ + _(GeneratorType) \ + 
_(QuantizerType) \ + _(BoolType) \ + _(OptionalType) \ + _(VarType) \ + _(DeviceObjType) \ + _(StreamObjType) \ + _(FunctionType) \ + _(ClassType) \ + _(PyObjectType) \ + _(CapsuleType) \ + _(InterfaceType) \ + _(QSchemeType) \ + _(ScalarTypeType) \ + _(LayoutType) \ + _(MemoryFormatType) \ + _(AnyListType) \ + _(AnyTupleType) \ + _(AnyClassType) \ + _(SymIntType) \ + _(SymFloatType) \ + _(SymBoolType) \ + _(UnionType) \ + _(DynamicType) + +enum class TypeKind { +#define DEFINE_TYPE(T) T, + C10_FORALL_TYPES(DEFINE_TYPE) +#undef DEFINE_TYPE +}; + +TORCH_API const char* typeKindToString(TypeKind kind); + +struct Type; +struct SharedType; + +// Use this to customize how a Type is printed using `annotation_str()`. If +// std::nullopt is returned, `annotation_str()` falls through to its default +// implementation. +using TypePrinter = std::function(const Type&)>; + +namespace detail { +template +struct IsSingletonType : public std::integral_constant {}; +} // namespace detail +#define TORCH_DECLARE_SINGLETON(Type) \ + struct Type; \ + namespace detail { \ + template <> struct IsSingletonType : public std::integral_constant {}; \ + } + +TORCH_DECLARE_SINGLETON(AnyType) +TORCH_DECLARE_SINGLETON(AnyEnumType) +TORCH_DECLARE_SINGLETON(NumberType) +TORCH_DECLARE_SINGLETON(FloatType) +TORCH_DECLARE_SINGLETON(ComplexType) +TORCH_DECLARE_SINGLETON(IntType) +TORCH_DECLARE_SINGLETON(BoolType) +TORCH_DECLARE_SINGLETON(StringType) +TORCH_DECLARE_SINGLETON(StorageType) +TORCH_DECLARE_SINGLETON(NoneType) +TORCH_DECLARE_SINGLETON(GeneratorType) +TORCH_DECLARE_SINGLETON(QuantizerType) +TORCH_DECLARE_SINGLETON(QSchemeType) +TORCH_DECLARE_SINGLETON(DeviceObjType) +TORCH_DECLARE_SINGLETON(StreamObjType) +TORCH_DECLARE_SINGLETON(CapsuleType) +TORCH_DECLARE_SINGLETON(PyObjectType) +TORCH_DECLARE_SINGLETON(ScalarTypeType) +TORCH_DECLARE_SINGLETON(LayoutType) +TORCH_DECLARE_SINGLETON(MemoryFormatType) +TORCH_DECLARE_SINGLETON(AnyListType) +TORCH_DECLARE_SINGLETON(AnyTupleType) 
+TORCH_DECLARE_SINGLETON(AnyClassType) + +namespace detail { +template +struct CastReturnType { + using type = std::shared_ptr; +}; + +template +struct CastReturnType::value>> { + using type = SingletonTypePtr; +}; + +template +struct CastConstReturnType { + using type = std::shared_ptr; +}; + +template +struct CastConstReturnType::value>> { + using type = SingletonTypePtr; +}; + +template +struct as_shared_type { + using type = SharedType*; +}; + +template +struct as_shared_type { + using type = const SharedType *; +}; +} // namespace detail + +struct TORCH_API Type { + friend TORCH_API bool operator==(const Type& lhs, const Type& rhs); + private: + TypeKind kind_; + + protected: + Type(TypeKind kind) : kind_(kind) {} + + Type(const Type&) = default; + Type& operator=(const Type&) = default; + Type(Type&&) noexcept = default; + Type& operator=(Type&&) noexcept = default; + + virtual std::string annotation_str_impl(const TypePrinter& /*printer*/) const { + return str(); + } + // a == b + virtual bool equals(const Type& rhs) const = 0; + // a == b <=> b == a + virtual bool symmetric() const { + return true; + } + + public: + template + class SingletonOrSharedTypePtr { + public: + using element_type = typename std::shared_ptr::element_type; + + SingletonOrSharedTypePtr() = default; + + /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr x) + : repr_(std::move(x)) {} + + template , bool> = true> + /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr x) + : repr_(std::move(x)) {} + + /* implicit */ SingletonOrSharedTypePtr(std::nullptr_t) + : repr_(nullptr) {} + + /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr p) + : repr_(p) {} + + template , bool> = true> + /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr p) + : repr_(SingletonTypePtr(p.get())) {} + + + // We need to support construction from T* for pybind. The problem + // is that it's not clear if we are supposed to be taking shared + // ownership or not. 
+ // + // Case 1: if T is known statically to derive from SharedType, we should use + // shared_from_this() and take shared_ownership. + // + // Case 2: if T is exactly Type, we need to do a dynamic_cast to + // check if it's a SharedType and do the right thing. + // + // Case 3: Otherwise, T is not a SharedType. (debug-check this + // assumption!) Use a singleton pointer. + + template , bool> = true> + /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast::type>(p)->shared_from_this()) {} + + template , bool> = true> + /* implicit */ SingletonOrSharedTypePtr(T* p) { + if (auto* shared_p = dynamic_cast::type>(p)) { + repr_ = Repr(shared_p->shared_from_this()); + } else { + repr_ = Repr(p); + } + } + + template && !std::is_base_of_v, bool> = true> + /* implicit */ SingletonOrSharedTypePtr(T* p) + : repr_(p) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast::type>(p) == nullptr); + } + + SingletonOrSharedTypePtr(const SingletonOrSharedTypePtr&) = default; + SingletonOrSharedTypePtr(SingletonOrSharedTypePtr&&) noexcept = default; + SingletonOrSharedTypePtr& operator=(const SingletonOrSharedTypePtr&) = default; + SingletonOrSharedTypePtr& operator=(SingletonOrSharedTypePtr&&) noexcept = default; + ~SingletonOrSharedTypePtr() = default; + + T* get() const { + return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast(repr_.rawRepr().first); + } + + operator bool() const { + return repr_.isNonNull(); + } + + bool operator==(std::nullptr_t) const { + return !repr_.isNonNull(); + } + + bool operator!=(std::nullptr_t) const { + return repr_.isNonNull(); + } + + template , void>, bool> = true> + U& operator*() const { + return *get(); + } + + T* operator->() const { + return get(); + } + + private: + // NOTE: SharedPtrWrapper exists to work around a baffling bug in + // nvcc; see comment in destroy() below. 
+ struct SharedPtrWrapper { + SharedPtrWrapper(std::shared_ptr &&x) + : repr_(std::move(x)) {} + std::shared_ptr repr_; + }; + union Repr { + Repr() : Repr(nullptr) {} + + explicit Repr(std::shared_ptr x) + : shared_(std::move(x)) {} + + explicit Repr(std::nullptr_t) + : singletonRepr_(nullptr) {} + + explicit Repr(SingletonTypePtr p) + : singletonRepr_(p.get()) {} + + ~Repr() { + destroy(); + } + + // NOTE: the only non-UB way to access our null state is through + // rawRepr(), because our copy operation doesn't preserve which + // union member is active for null pointers. + Repr(const Repr& rhs) { + if (rhs.isSharedAndNonNull()) { + new (&shared_) SharedPtrWrapper(rhs.shared_); + } else { + singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); + singletonRepr_.unused_ = nullptr; + } + } + + Repr(Repr&& rhs) noexcept { + if (rhs.isSharedAndNonNull()) { + new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); + } else { + singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr); + singletonRepr_.unused_ = nullptr; + } + } + + Repr& operator=(const Repr& rhs) { + if (&rhs == this) { + return *this; + } + if (rhs.isSharedAndNonNull()) { + if (isSharedAndNonNull()) { + shared_ = rhs.shared_; + } else { + new (&shared_) SharedPtrWrapper(rhs.shared_); + } + } else { + if (isSharedAndNonNull()) { + destroy(); + } + singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); + singletonRepr_.unused_ = nullptr; + } + return *this; + } + + Repr& operator=(Repr&& rhs) noexcept { + if (&rhs == this) { + return *this; + } + if (rhs.isSharedAndNonNull()) { + if (isSharedAndNonNull()) { + shared_ = std::move(rhs.shared_); + } else { + new (&shared_) SharedPtrWrapper(std::move(rhs.shared_)); + } + } else { + if (isSharedAndNonNull()) { + 
destroy(); + } + singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr); + singletonRepr_.unused_ = nullptr; + } + return *this; + } + + SharedPtrWrapper shared_; + + struct SingletonRepr { + explicit SingletonRepr(T* s) : singleton_(s) {} + T* singleton_; + void* unused_ = nullptr; + } singletonRepr_; + struct RawRepr { + void* first; + void* nullIfSingleton_; + }; + + // It is UB to read the singleton part of Repr if it was + // constructed as a shared_ptr and vice versa, but memcpying out + // the representation is always OK, so here's an accessor to obey + // the letter of the law. + RawRepr rawRepr() const { + RawRepr repr{}; + memcpy(&repr, reinterpret_cast(this), sizeof(RawRepr)); + return repr; + } + + bool isNonNull() const { + auto repr = rawRepr(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr); + return repr.first != nullptr; + } + + bool isSharedAndNonNull() const { + return rawRepr().nullIfSingleton_ != nullptr; + } + + private: + void destroy() { + if (isSharedAndNonNull()) { + // Without SharedPtrWrapper, this line would read + // `shared_.~shared_ptr()` and nvcc would complain with + // "error: expected primary-expression before '>' token" + // referring to the "t" in "shared_ptr". SharedPtrWrapper + // exists to work around this compiler bug. + shared_.~SharedPtrWrapper(); + } + } + } repr_; + }; + + using TypePtr = SingletonOrSharedTypePtr; + using Ptr = TypePtr; + using ElementType = Type; + + // subtyping relation. By default, we return true for the case + // when the type is exactly equal or if this <: T where rhs = Optional[T] + + // if this returns false and the why_not stream is non-null, it contains + // additional details that describe why this is not a subtype of 'rhs'. + // This additional information should only contain details that are not + // obvious from the annotation_str() that describes the type. 
For instance it + // is clear that `int <: str` is false but not clear why `Foo <: InterfaceBar` + // might be false. + virtual bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const; + virtual bool is_module() const; + bool isSubtypeOf(const Type& rhs) const { + return isSubtypeOfExt(rhs, nullptr); + } + // Compatibility shims to accommodate existing code that passes shared_ptrs + // around. Ideally, we would just delete this, but it should be harmless. + template + std::enable_if_t, bool> + isSubtypeOf(const std::shared_ptr& rhs) const { + return isSubtypeOf(*rhs); + } + + template + std::enable_if_t, bool> + isSubtypeOf(const SingletonOrSharedTypePtr& rhs) const { + return isSubtypeOf(*rhs); + } + + template + std::enable_if_t, bool> + isSubtypeOf(SingletonTypePtr rhs) const { + return isSubtypeOf(*rhs); + } + + template + std::enable_if_t, bool> + isSubtypeOfExt(const SingletonOrSharedTypePtr& rhs, std::ostream* why_not) const { + return isSubtypeOfExt(*rhs, why_not); + } + + template + std::enable_if_t, bool> + isSubtypeOfExt(const std::shared_ptr& rhs, std::ostream* why_not) const { + return isSubtypeOfExt(*rhs, why_not); + } + + template + std::enable_if_t, bool> + isSubtypeOfExt(SingletonTypePtr rhs, std::ostream* why_not) const { + return isSubtypeOfExt(*rhs, why_not); + } + + // How this type will appear in FunctionSchema declarations + virtual std::string str() const = 0; + + // How this type will appear as if it were a type annotation in Python + // which is sometimes different than how it appears in declarations (e.g. + // int[] vs List[int]) + // + // Takes a custom printer that users can pass in to customize the output of + // this method. 
+ std::string annotation_str(const TypePrinter& printer) const { + if (printer) { + // the printer can return std::nullopt to fall through to the default impl + if (auto renamed = printer(*this)) { + return *renamed; + } + } + return annotation_str_impl(printer); + } + std::string annotation_str() const { + // Overload instead of define a default value for `printer` to help + // debuggers out. + return annotation_str(nullptr); + } + + // Returns a human readable string that includes additional information like + // "type is inferred rather than explicitly defined" to help construct more + // user-friendly messages. + virtual std::string repr_str() const { + return annotation_str(); + } + + TypeKind kind() const { + return kind_; + } + + virtual bool isUnionType() const { + return false; + } + + virtual bool requires_grad() const { + for (const auto& ct : containedTypes()) { + if (ct->requires_grad()) { + return true; + } + } + return false; + } + + // Dynamically cast this object to the subclass indicated by the + // template variable, returning nullptr if the cast is invalid. 
+ template ::value, bool> = true> + typename detail::CastReturnType::type cast() { + if (T::Kind == kind()) { + return std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return nullptr; + } + template ::value, bool> = true> + typename detail::CastReturnType::type cast() { + if (T::Kind == kind()) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get()); + return typename detail::CastReturnType::type(static_cast(this)); + } + return nullptr; + } + template ::value, bool> = true> + typename detail::CastConstReturnType::type cast() const { + if (T::Kind == kind()) { + return std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return nullptr; + } + template ::value, bool> = true> + typename detail::CastConstReturnType::type cast() const { + if (T::Kind == kind()) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get()); + return typename detail::CastConstReturnType::type(static_cast(this)); + } + return nullptr; + } + template + T* castRaw() { + if (T::Kind == kind()) { + return static_cast(this); + } + return nullptr; + } + template + const T* castRaw() const { + if (T::Kind == kind()) { + return static_cast(this); + } + return nullptr; + } + template + auto expect() { + auto r = cast(); + AT_ASSERT(r); + return r; + } + template + auto expect() const { + auto r = cast(); + AT_ASSERT(r); + return r; + } + template + T& expectRef() { + auto* r = castRaw(); + AT_ASSERT(r); + return *r; + } + template + const T& expectRef() const { + auto* r = castRaw(); + AT_ASSERT(r); + return *r; + } + virtual ~Type() = default; + virtual bool hasFreeVariables() const { + return false; + } + // list of types this type contains, e.g. 
for a List then element type of a + // list for a tuple, the types of the tuple elements + virtual at::ArrayRef containedTypes() const { + return {}; + } + virtual TypePtr containedType(size_t i) const { + return containedTypes().at(i); + } + virtual size_t containedTypeSize() const { + return containedTypes().size(); + } + // create a new version of this type, replacing its contained types with + // contained_types + TypePtr withContained(std::vector contained_types); + // per-type constructor, you only need to override this if the + // containedTypes() is not empty + virtual TypePtr createWithContained( + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::vector /*contained_types*/) const { + TORCH_CHECK(false, + "type with contained types did not overload createWithContained: ", + str()); + } + +}; + +template +using SingletonOrSharedTypePtr = Type::SingletonOrSharedTypePtr; + + +template +bool operator==(const SingletonOrSharedTypePtr& x, const SingletonOrSharedTypePtr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator==(const SingletonOrSharedTypePtr& x, const std::shared_ptr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator==(const std::shared_ptr& x, const SingletonOrSharedTypePtr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator==(const SingletonOrSharedTypePtr& x, const SingletonTypePtr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator==(const SingletonTypePtr& x, const SingletonOrSharedTypePtr& y) { + return (void*)x.get() == (void*)y.get(); +} + +template +bool operator!=(const SingletonOrSharedTypePtr& x, const SingletonOrSharedTypePtr& y) { + return !(x == y); +} + +template +bool operator!=(const SingletonOrSharedTypePtr& x, const std::shared_ptr& y) { + return !(x == y); +} + +template +bool operator!=(const std::shared_ptr& x, const SingletonOrSharedTypePtr& y) { + return !(x == y); +} + +template +bool operator!=(const 
SingletonOrSharedTypePtr& x, const SingletonTypePtr& y) { + return !(x == y); +} + +template +bool operator!=(const SingletonTypePtr& x, const SingletonOrSharedTypePtr& y) { + return !(x == y); +} + +using TypePtr = SingletonOrSharedTypePtr; +using ConstTypePtr = SingletonOrSharedTypePtr; + +// Explicitly enable MaybeOwned>, rather than allowing +// MaybeOwned to be used for any type right away. +template +struct MaybeOwnedTraits> + : public MaybeOwnedTraitsGenericImpl> {}; + +// Base class for Types that are guaranteed to be owned by std::shared_ptr. +struct TORCH_API SharedType : public Type, public std::enable_shared_from_this { + using Type::Type; +}; + +inline TypePtr Type::withContained(std::vector contained_types) { + auto current_contained = containedTypes(); + // Types with no contained_types don't need this call. Check before calling! + // + // (We can't support this efficiently because types without + // contained types may be singletons, in which case + // shared_from_this will crash; we would have to provide a virtual + // typeptr_from_this or isSingleton.) 
+ TORCH_INTERNAL_ASSERT(!current_contained.empty() && current_contained.size() == contained_types.size()); + if (current_contained.equals(contained_types)) { + return std::static_pointer_cast(static_cast(this)->shared_from_this()); + } + return createWithContained(std::move(contained_types)); +} + + +TORCH_API inline bool operator==(const Type& lhs, const Type& rhs) { + if (C10_UNLIKELY(!rhs.symmetric())) { + return rhs.equals(lhs); + } + return lhs.equals(rhs); +} + +struct NamedType; +using NamedTypePtr = std::shared_ptr; +using ConstNamedTypePtr = std::shared_ptr; + +struct TORCH_API NamedType : public SharedType { + NamedType(TypeKind tk, std::optional name) + : SharedType(tk), name_(std::move(name)) { + TORCH_INTERNAL_ASSERT( + tk == TypeKind::TupleType || tk == TypeKind::FunctionType || + tk == TypeKind::ClassType || tk == TypeKind::InterfaceType || + tk == TypeKind::EnumType, + "If you add a new kind of NamedType, ", + "please update the cast specialization and this assert"); + } + + // Fully qualified name of type + // Looks like: "foo.bar.Baz". + const std::optional& name() const { + return name_; + } + + private: + std::optional name_; +}; + +} // namespace c10 + +namespace std { +template +struct hash> { + size_t operator()(const c10::SingletonOrSharedTypePtr& x) const { + return std::hash()(x.get()); + } +}; +} // namespace std diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/adaption.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/adaption.h new file mode 100644 index 0000000000000000000000000000000000000000..89ed8a419a03e771a6f35ac7dbbc47e29c99b852 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/adaption.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include +#include + +/* + * [Note: hacky wrapper removal for optional tensor] + * + * The kernel implementation takes an optional tensor marked in the schema as + * Tensor? 
but the C++ function takes Tensor instead of the std::optional + * expected by the dispatcher. + * + * To remove the hacky wrapper, the C++ function is changed to take + * std::optional and unwrap the Tensor value at the beginning of + * the function, e.g.: + * > c10::MaybeOwned weight_maybe_owned = + * > at::borrow_from_optional_tensor(weight_opt); + * > const Tensor& weight = *weight_maybe_owned; + * + * We may want to make the kernel handle optional directly without + * going through the creation of a default-constructed Tensor in + * at::borrow_from_optional_tensor. + */ + +/* + * [Note: hacky wrapper removal for TensorOptions] + * + * The kernel implementation takes a TensorOptions argument but the dispatcher + * expects separate arguments for dtype, layout, device, pin_memory. + * + * To remove the hacky wrapper, the kernel implementation is changed to take + * the 4 arguments (dtype, layout, device, pin_memory), and assemble the + * TensorOptions value at the beginning of the function, e.g.: + * > TensorOptions options = TensorOptions().dtype(dtype).layout(layout) + * > .device(device).pinned_memory(pin_memory); + * + * We may want make the kernel handle these parameters directly without going + * through the creation of a TensorOptions value. 
+ */ + +namespace c10::impl { + +TORCH_API void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName); + +inline void check_and_update_common_device(std::optional& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) { + // TODO: Remove this once the following issue is addressed: + // https://github.com/pytorch/pytorch/issues/57380 + if (!tensor.defined()) { + return; + } + + if (!common_device.has_value()) { + common_device = tensor.device(); + return; + } + + if (C10_UNLIKELY(common_device != tensor.device())) { + common_device_check_failure(*common_device, tensor, methodName, argName); + } +} + +inline void check_and_update_common_device(std::optional& common_device, const std::optional& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) { + if (tensor.has_value()) { + check_and_update_common_device(common_device, tensor.value(), methodName, argName); + } +} + +inline void check_and_update_common_device(std::optional& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) { + for (const auto& tensor : tensors) { + check_and_update_common_device(common_device, tensor, methodName, argName); + } +} + +inline void check_and_update_common_device(std::optional& common_device, const List>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) { + for (const auto& tensor : tensors) { + check_and_update_common_device(common_device, tensor, methodName, argName); + } +} +} // namespace c10::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/infer_schema.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/infer_schema.h new file mode 100644 index 0000000000000000000000000000000000000000..a393e0290458c7bcf2bce653d19ebe1ccf8a38c6 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/infer_schema.h @@ -0,0 +1,157 @@ +#pragma once + +/** + * This file contains functionality to take a C++ function and infer its + * c10::FunctionSchema. + */ + +#include +#include + +namespace c10 { +namespace detail::infer_schema { + +/// The templated inference code creates `ArgumentDef` instead of `Argument`, +/// because that can be constructed at compile time and has a much smaller +/// binary size than having calls to `Argument` constructors in the template. +/// Creating `Argument` objects from `ArgumentDef` can then be done at +/// runtime in a non-templated way. +struct ArgumentDef final { + using GetTypeFn = TypePtr(); + GetTypeFn* getTypeFn; + GetTypeFn* getFakeTypeFn; + constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {} + explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {} +}; + +template +struct bool_t {}; +template<> struct bool_t : std::true_type {}; +template<> struct bool_t : std::false_type {}; + +/// Checks the static C++ types `Types` for correctness to catch common error cases. +template +constexpr int checkStaticTypes() { + // Give nice error messages for some of the common error cases. + // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT + static_assert(std::conjunction_v< + bool_t || std::is_same_v || std::is_same_v || std::is_same_v>... + >, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type"); + static_assert(std::conjunction_v< + bool_t>... 
+ >, "INVALID TYPE: float is not supported as an argument type, use double instead"); + return 0; +} + +template +constexpr std::array createArgumentVectorFromTypes(std::index_sequence) { + return ( + // Check types for common errors + checkStaticTypes(), + + // Create the return value + std::array{ + ArgumentDef(&getTypePtrCopy>, &getFakeTypePtrCopy>)...} + ); +} + +/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified +/// as template arguments. +template struct createArguments final {}; +template +struct createArguments> final { + static constexpr std::array call() { + return createArgumentVectorFromTypes( + std::make_index_sequence() + ); + } +}; + +/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified +/// as a tuple (i.e. in the way c10 kernels return values). +/// It can be a tuple if there's three output arguments with types A, B, C. +/// It can be an empty tuple<>, or void for kernels that don't return anything. +/// It can be a single type A (i.e. no tuple) for the case where a kernel just +/// returns one value. 
+template struct createReturns final {}; + +template +struct createReturns, void> final { + static constexpr std::array call() { + return createArgumentVectorFromTypes( + std::make_index_sequence() + ); + } +}; + +template +struct createReturns && !guts::is_instantiation_of::value>> final { + static constexpr std::array call() { + return createReturns>::call(); + } +}; + +template<> +struct createReturns final { + static constexpr std::array call() { + return createReturns>::call(); + } +}; + +template +struct createSingleReturn { + static constexpr std::array call() { + return createArgumentVectorFromTypes(std::make_index_sequence<1>()); + } +}; + +TORCH_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef arguments, c10::ArrayRef returns); +TORCH_API FunctionSchema make_function_schema(c10::ArrayRef arguments, c10::ArrayRef returns); + +/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a +/// function. Flattens std::tuple returns into multiple return types +template +FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() { + using ReturnType = typename FunctionTraits::return_type; + using ParameterTypes = typename FunctionTraits::parameter_types; + + // arguments and returns are computed into a std::array at compile time and embedded into the binary. + // The only code executed at runtime here is the one that creates a std::vector + // of the arguments/returns from the std::array. + constexpr auto arguments = createArguments::call(); + constexpr auto returns = createReturns::call(); + + return make_function_schema(arguments, returns); +} + +/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a +/// function. 
Preserves std::tuple returns as a Tuple return type +template +FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) { + using ReturnType = typename FunctionTraits::return_type; + using ParameterTypes = typename FunctionTraits::parameter_types; + + // arguments and returns are computed into a std::array at compile time and embedded into the binary. + // The only code executed at runtime here is the one that creates a std::vector + // of the arguments/returns from the std::array. + constexpr auto arguments = createArguments::call(); + constexpr auto returns = createSingleReturn::call(); + + return make_function_schema(std::move(name), std::move(overload_name), arguments, returns); +} + +} + +template +FunctionSchema inferFunctionSchemaFlattenedReturns() { + return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns>(); +} + +template +FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) { + return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn>(std::move(name), std::move(overload_name)); +} + +TORCH_API std::optional findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified); + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h new file mode 100644 index 0000000000000000000000000000000000000000..3e8e03f9fa4c2581dd11cac83a81766cb936cb49 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h @@ -0,0 +1,181 @@ +#pragma once + +// TODO: unify to C10_MOBILE. In theory this header could be used in OSS. +#ifdef TEMPLATE_SELECTIVE_BUILD +#include +#endif + +/** + * This header implements functionality to build PyTorch with only a certain + * set of operators (+ dependencies) included. 
+ * + * - Build with -DTORCH_OPERATOR_WHITELIST="aten::add;aten::sub" and only these + * two ops will be included in your build. The allowlist records operators + * only, no overloads; if you include aten::add, all overloads of aten::add + * will be included. + * + * Internally, this is done by removing the operator registration calls + * using compile time programming, and the linker will then prune all + * operator functions that weren't registered. + * See Note [Selective build] for more details + * + * WARNING: The allowlist mechanism doesn't work for all ways you could go about + * registering an operator. If the dispatch key / operator name is not + * sufficiently obvious at compile time, then the allowlisting mechanism + * will fail (and the operator will be included in the binary anyway). + */ + +#include +#include +#include + + +#if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) +#include +#endif + +namespace c10::impl { + +constexpr bool allowlist_contains(std::string_view allowlist, std::string_view item); // Forward Declare + +/** + * In selective build mode returns true/false depending on whether a build + * feature is available or not. + * + * In instrumenting mode (tracing mode), always returns true, and doesn't + * trigger any side effects. + */ +constexpr bool is_build_feature_available(const char* name) { +#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) + // Selective Build mode. +#if !defined(TORCH_BUILD_FEATURE_ALLOWLIST) + (void)name; + return true; +#else + return allowlist_contains( + C10_STRINGIZE(TORCH_BUILD_FEATURE_ALLOWLIST), + name); +#endif + +#else + // Instrumenting mode. + (void)name; + return true; +#endif +} + +[[noreturn]] void build_feature_required_feature_not_available(const char* feature); + +/** + * Use BUILD_FEATURE_REQUIRED macro in user-code. + * + * In selective build mode becomes a no-op if the build feature passed + * in is available. If not available, throws an exception (c10::Error). 
+ * The compiler is able to perform dead code elimination for code + * following this method if the build feature is not available. + * + * In instrumenting mode (tracing mode), registers (as a side effect) + * the presence of this specific build feature being triggered. + */ +#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) // selective build mode + +#if defined(TORCH_BUILD_FEATURE_ALLOWLIST) +#define BUILD_FEATURE_REQUIRED(NAME) \ + if (!c10::impl::is_build_feature_available(NAME)) { \ + ::c10::impl::build_feature_required_feature_not_available(NAME); \ + } +#else // Everything trivially selected +#define BUILD_FEATURE_REQUIRED(NAME) + +#endif + +#else // trace mode +#define BUILD_FEATURE_REQUIRED(NAME) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::BUILD_FEATURE, \ + std::string(NAME), \ + {}); +#endif + +// Use this macro, and not is_build_feature_available +#define BUILD_FEATURE_AVAILABLE(NAME) ::c10::impl::is_build_feature_available(NAME) + +// returns true iff allowlist contains item +// allowlist_contains("a;bc;d", "bc") == true +constexpr bool allowlist_contains(std::string_view allowlist, std::string_view item) { + //Choose a really big value for next so that if something goes wrong + //this code will blow up in a hopefully detectable way. + size_t next = std::numeric_limits::max(); + for (size_t cur = 0; cur <= allowlist.size(); cur = next) { + next = allowlist.find(';', cur); + if (next != std::string_view::npos) { + if (allowlist.substr(cur, next - cur) == item) { + return true; + } + next++; + } else { + if (allowlist.substr(cur).compare(item) == 0) { + return true; + } + break; + } + } + return false; +} + +// Returns true iff the given op name is on the allowlist +// and should be registered +constexpr bool op_allowlist_check(std::string_view op_name [[maybe_unused]]) { + assert(op_name.find("::") != std::string_view::npos); + // Use assert() instead of throw() due to a gcc bug. 
See: + // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function + // https://github.com/fmtlib/fmt/issues/682 + assert(op_name.find('(') == std::string_view::npos); +#if !defined(TORCH_OPERATOR_WHITELIST) + // If the TORCH_OPERATOR_WHITELIST parameter is not defined, + // all ops are to be registered + return true; +#else + return allowlist_contains( + C10_STRINGIZE(TORCH_OPERATOR_WHITELIST), + // This function is majorly used for mobile selective build with + // root operators, where the overload is included in the allowlist. + op_name); + // // Strip overload name (as allowlist doesn't contain overloads) + // // Another function based on this may be added when there's usage + // // on op names without overload. + // OperatorNameView::parse(op_name).name); +#endif +} + +// Returns true iff the given schema string is on the allowlist +// and should be registered +constexpr bool schema_allowlist_check(std::string_view schema) { +#if defined(TORCH_FORCE_SCHEMA_REGISTRATION) + return true; +#else + return op_allowlist_check(schema.substr(0, schema.find('('))); +#endif +} + +// Returns true iff the given custom class name is on the allowlist +// and should be registered +constexpr bool custom_class_allowlist_check(std::string_view custom_class_name [[maybe_unused]]) { +#if !defined(TORCH_CUSTOM_CLASS_ALLOWLIST) + // If the TORCH_CUSTOM_CLASS_ALLOWLIST parameter is not defined, + // all custom classes are to be registered + return true; +#else + return allowlist_contains( + C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST), + custom_class_name); +#endif +} + +// schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST. +// Add this API to pass arbitrary allowlist. 
+constexpr bool op_allowlist_contains_name_in_schema(std::string_view allowlist, std::string_view schema) { + return allowlist_contains(allowlist, schema.substr(0, schema.find('('))); +} + +} // namespace c10::impl diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_registration.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_registration.h new file mode 100644 index 0000000000000000000000000000000000000000..7a44cfa49b0781af282f818ef4b5b16bc690b99f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/op_registration/op_registration.h @@ -0,0 +1,596 @@ +#pragma once + +/** + * Include this file if you want to register operators. It includes all + * functionality needed to do so for you. + */ + +#include +#include +#include +#include +#include +#include +#include +#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD) +#include +#endif +#include + +namespace c10 { + +namespace detail { +// The first argument of the schema might be of type DispatchKeySet, in which case we remove it. +// We do this because every argument in a function schema is expected to be convertable +// to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of. +// See Note [Plumbing Keys Through The Dispatcher] +template +std::unique_ptr inferFunctionSchemaFromFunctor() { + using func_type = typename c10::remove_DispatchKeySet_arg_from_func::func_type; + return std::make_unique(inferFunctionSchemaFlattenedReturns()); +} +} + +/** + * An instance of this class handles the registration for one or more operators. + * Make sure you keep the RegisterOperators instance around since it will + * deregister the operator it's responsible for in its destructor. 
+ * + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU)); + */ +class TORCH_API RegisterOperators final { +public: + RegisterOperators() = default; + ~RegisterOperators() = default; + + RegisterOperators(const RegisterOperators&) = delete; + RegisterOperators& operator=(const RegisterOperators&) = delete; + RegisterOperators(RegisterOperators&&) noexcept = default; + RegisterOperators& operator=(RegisterOperators&&) noexcept = default; + + class TORCH_API Options final { + public: + Options(const Options&) = delete; + Options(Options&&) noexcept = delete; + Options& operator=(const Options&) = delete; + Options& operator=(Options&&) noexcept = delete; + + // internal-only for registering stack based kernels + template + Options&& kernel(DispatchKey dispatch_key) && { + return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction(), std::nullopt, nullptr); + } + + // internal-only for registering stack based catch-all kernels + template + Options&& catchAllKernel() && { + return std::move(*this).kernel(std::nullopt, KernelFunction::makeFromBoxedFunction(), std::nullopt, nullptr); + } + + // internal only for registering caffe2 ops + Options&& schema(FunctionSchema&& schema) { + TORCH_CHECK(!schemaOrName_.has_value(), "You can only specify the schema once per operator registration."); + schemaOrName_ = FunctionSchema(std::move(schema)); + return std::move(*this); + } + + /** + * Use this to specify the schema for an operator. You can also specify + * the operator name only to have the function signature part of the + * schema be inferred from the kernel function. 
+ * + * Example: + * + * > // Infer function signature from my_kernel_cpu + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU)); + * > + * > + * > // Explicitly specify full schema + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op(Tensor a) -> Tensor") + * > .kernel(DispatchKey::CPU)); + */ + Options&& schema(const std::string& schemaOrName) { + TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration."); + + #if !defined(EXPOSE_C2_OPS) && defined(CAFFE2_IS_XPLAT_BUILD) + throw std::logic_error("Tried to register operator " + schemaOrName + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build."); + #else + schemaOrName_ = torch::jit::parseSchemaOrName(schemaOrName); + #endif + + return std::move(*this); + } + + /** + * Use this to register an operator whose kernel is implemented as a functor. + * The kernel is only called for inputs matching the given dispatch key. + * You can register multiple kernels for different dispatch keys. + * + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU)); + * + * The functor constructor can take arguments to configure the kernel. + * The arguments are defined in the kernel registration. + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b) + * > : ... 
{...} + * > + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU, "some_configuration", 3, true)); + */ + template + // enable_if: only enable it if KernelFunctor is actually a functor + std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key, ConstructorParameters&&... constructorParameters) && { + static_assert(std::is_base_of_v, "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); + static_assert(std::is_constructible_v, "Wrong argument list for constructor of kernel functor. The arguments to kernel(arguments...) must match one of the constructors of Functor."); + + return std::move(*this).kernel( + dispatch_key, + KernelFunction::makeFromUnboxedFunctor(std::make_unique(std::forward(constructorParameters)...)), + impl::CppSignature::make(), + detail::inferFunctionSchemaFromFunctor() + ); + } + + /** + * Use this to register an operator whose kernel is implemented as a functor. + * The kernel is a catch-all kernel, meaning it's called independent from + * the input. Dispatch is disabled for this operator. + * + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .catchAllKernel()); + * + * The functor constructor can take arguments to configure the kernel. + * The arguments are defined in the kernel registration. + * Example: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > explicit my_kernel_cpu(std::string some_configuration, int a, bool b) + * > : ... 
{...} + * > + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .catchAllKernel("some_configuration", 3, true)); + */ + template + // enable_if: only enable it if KernelFunctor is actually a functor + std::enable_if_t::value, Options&&> catchAllKernel(ConstructorParameters&&... constructorParameters) && { + static_assert(std::is_base_of_v, "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); + static_assert(std::is_constructible_v, "Wrong argument list for constructor of kernel functor. The arguments to kernel(arguments...) must match one of the constructors of Functor."); + + return std::move(*this).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedFunctor(std::make_unique(std::forward(constructorParameters)...)), + impl::CppSignature::make(), + detail::inferFunctionSchemaFromFunctor() + ); + } + + /** + * Use this to register an operator whose kernel is implemented by a function. + * The kernel is only called for inputs matching the given dispatch key. + * You can register multiple kernels for different dispatch keys. + * + * Example: + * + * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU)); + */ + template + // enable_if: only enable it if FuncType is actually a function + std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key) && { + static_assert(!std::is_same_v, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) 
API or also implement the kernel function as defined by the public API."); + static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr"); + + return std::move(*this).kernel( + dispatch_key, + KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoFunctor + detail::inferFunctionSchemaFromFunctor>::type>() + ); + } + + /** + * Use this to register an operator whose kernel is implemented by a function. + * The kernel is a catch-all kernel, meaning it's called independent from + * the input. Dispatch is disabled for this operator. + * + * Example: + * + * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } + * > + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .catchAllKernel()); + */ + template + // enable_if: only enable it if FuncType is actually a function + std::enable_if_t::value, Options&&> catchAllKernel() && { + static_assert(!std::is_same_v, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); + static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr"); + + return std::move(*this).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoFunctor + detail::inferFunctionSchemaFromFunctor>::type>() + ); + } + + template + // enable_if: only enable it if FuncType is actually a function + std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key, FuncType* kernel_func) && { + static_assert(!std::is_same_v, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) 
API or also implement the kernel function as defined by the public API."); + TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr"); + + return std::move(*this).kernel( + dispatch_key, + KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoFunctor + detail::inferFunctionSchemaFromFunctor>>() + ); + } + + template + // enable_if: only enable it if FuncType is actually a function + std::enable_if_t::value, Options&&> catchAllKernel(FuncType* kernel_func) && { + static_assert(!std::is_same_v, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); + TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr"); + + return std::move(*this).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoFunctor + detail::inferFunctionSchemaFromFunctor>>() + ); + } + + /** + * Use this to register an operator whose kernel is implemented as a lambda. + * The kernel is only called for inputs matching the given dispatch key. + * You can register multiple kernels for different dispatch keys. + * + * The lambda must be stateless, i.e. not have a capture. If your kernel + * needs to store some configuration parameters, write the kernel as a + * functor instead. 
+ * + * Example: + * + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(DispatchKey::CPU, [] (Tensor a) -> Tensor {...})); + */ + template + // enable_if: only enable it if Lambda is a functor (note: lambdas are functors) + std::enable_if_t< + guts::is_functor>::value + && !std::is_same_v>::func_type, KernelFunction::BoxedKernelFunction>, + Options&&> kernel(DispatchKey dispatch_key, Lambda&& functor) && { + static_assert(!std::is_base_of_v>, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel() API instead."); + + // We don't support stateful lambdas (i.e. lambdas with a capture), because their + // behavior would be nonobvious. A functor kernel with cache gets a new instance of + // its cache each time the kernel is looked up from the dispatch table. + // A lambda with a capture would be global and share its capture between all kernel lookups. + // So, instead of making users having to think about it (including the thread-safety + // issues this causes), let's just forbid stateful lambdas altogether. + static_assert(guts::is_stateless_lambda>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel() instead."); + + return std::move(*this).kernel( + dispatch_key, + KernelFunction::makeFromUnboxedLambda(std::forward(functor)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + ); + } + + /** + * Use this to register an operator whose kernel is implemented as a lambda. + * The kernel is a catch-all kernel, meaning it's called independent from + * the input. Dispatch is disabled for this operator. + * + * The lambda must be stateless, i.e. not have a capture. 
If your kernel + * needs to store some configuration parameters, write the kernel as a + * functor instead. + * + * Example: + * + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .catchAllKernel([] (Tensor a) -> Tensor {...})); + */ + template + // enable_if: only enable it if Lambda is a functor (note: lambdas are functors) + std::enable_if_t< + guts::is_functor>::value + && !std::is_same_v>::func_type, KernelFunction::BoxedKernelFunction>, + Options&&> catchAllKernel(Lambda&& lambda) && { + static_assert(!std::is_base_of_v>, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel() API instead."); + + // We don't support stateful lambdas (i.e. lambdas with a capture), because their + // behavior would be nonobvious. + // A lambda with a capture would be global and share its capture between all kernel lookups. + // This would be a likely source for unexpected race conditions, so we forbid it. + // If a kernel really needs global state, they can just have regular global state + // in their .cpp file next to the kernel lambda. + static_assert(guts::is_stateless_lambda>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). 
If you need a cache, please use the functor based API kernel() instead."); + + return std::move(*this).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedLambda(std::forward(lambda)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + ); + } + + Options&& aliasAnalysis(AliasAnalysisKind aliasAnalysisKind) && { + TORCH_CHECK(!aliasAnalysisKind_.has_value(), "You can only call aliasAnalysis() once per operator registration."); + aliasAnalysisKind_ = aliasAnalysisKind; + return std::move(*this); + } + + private: + Options&& kernel(std::optional dispatch_key, KernelFunction&& func, std::optional cpp_signature, std::unique_ptr&& inferred_function_schema) && { + KernelRegistrationConfig config; + config.dispatch_key = dispatch_key; + config.func = std::move(func); + config.cpp_signature = cpp_signature; + config.inferred_function_schema = std::move(inferred_function_schema); + kernels.push_back(std::move(config)); + return std::move(*this); + } + + Options() + : schemaOrName_(std::nullopt) + , kernels() + , aliasAnalysisKind_(std::nullopt) + {} + + // KernelRegistrationConfig accumulates all information from the config + // parameters passed to a RegisterOperators::op() call into one object. + struct KernelRegistrationConfig final { + KernelRegistrationConfig() + : dispatch_key(std::nullopt) + , func() + , cpp_signature(std::nullopt) + , inferred_function_schema(nullptr) + {} + + std::optional dispatch_key; + KernelFunction func; + std::optional cpp_signature; + std::unique_ptr inferred_function_schema; + }; + + std::optional> schemaOrName_; + + std::vector kernels; + std::optional aliasAnalysisKind_; + friend class RegisterOperators; + friend class Library; + }; + + /** + * Call this to get an instance of registration options, which + * can be passed to a call to RegisterOperators::op() to specify + * these options for the operator registration. 
+ * See class doc comment for examples. + */ + static Options options() { + return {}; + } + + /** + * Call this to register an operator. See class doc comment for examples. + */ + RegisterOperators&& op(Options&& options) && { + checkSchemaAndRegisterOp_(std::move(options)); + return std::move(*this); + } + + // Regular mutator version of the && version above + RegisterOperators& op(Options&& options) & { + checkSchemaAndRegisterOp_(std::move(options)); + return *this; + } + + /** + * This is a shorthand for RegisterOperators::op(Options) where you can + * specify the operator schema outside of the options parameter. + * See class doc comment for examples. + */ + RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && { + return std::move(*this).op(std::move(options).schema(schemaOrName)); + } + + // internal only for registering caffe2 ops + RegisterOperators&& op(FunctionSchema schema, Options&& options) && { + return std::move(*this).op(std::move(options).schema(std::move(schema))); + } + + template + explicit RegisterOperators(const std::string& schemaOrName, FuncType&& func, Options&& options = RegisterOperators::options()) + : RegisterOperators() { + std::move(*this).op(schemaOrName, std::forward(func), std::move(options)); + } + + /** + * This API registers an operator based on a kernel function pointer. 
+ * + * Given a kernel + * + * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } + * + * This API looks like: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", &my_kernel_cpu); + * + * If your kernel is small and the overhead of calling it matters, + * then this API might be the wrong choice since the following API + * has a slightly lower overhead for calling into the kernel: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", c10::RegisterOperators::options() + * > .kernel()); + * + * Or, alternatively, write your kernel as a functor: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", c10::RegisterOperators::options() + * > .kernel()); + */ + template + // enable_if: only enable it if FuncType is actually a function, but not a stack based BoxedKernelFunction. + std::enable_if_t::value && !std::is_same_v, RegisterOperators&&> + op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && { + constexpr bool AllowLegacyTypes = true; + return std::move(*this).op(std::move(options).schema(schemaOrName).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedRuntimeFunction(func), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + )); + } + + /** + * This API registers an operator based on a kernel lambda. 
+ * + * This API looks like: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", [] (Tensor a, Tensor b) {...}); + * + * This is equivalent to: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", c10::RegisterOperators::options() + * > .catchAllKernel([] (Tensor a, Tensor b) {...})); + * + */ + template + // enable_if: only enable it if Lambda is actually a stateless lambda + std::enable_if_t::value && guts::is_stateless_lambda>::value, RegisterOperators&&> + op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && { + static_assert(!std::is_base_of_v, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead."); + + constexpr bool AllowLegacyTypes = true; + return std::move(*this).op(std::move(options).schema(schemaOrName).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedLambda(std::forward(lambda)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + )); + } + + template + C10_DEPRECATED_MESSAGE("Registering operator kernels with stateful lambdas (i.e. lambdas with a capture) has non-obvious behavior. This is deprecated. Please use a lambda without a capture or a functor class instead.") + // enable_if: only enable it if Lambda is actually a functor but not a stateless lambda + std::enable_if_t::value && !guts::is_stateless_lambda>::value, RegisterOperators&&> + op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && { + static_assert(!std::is_base_of_v, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. 
Please use the new RegisterOperators::options().kernel() based API instead."); + + constexpr bool AllowLegacyTypes = true; + return std::move(*this).op(std::move(options).schema(schemaOrName).kernel( + std::nullopt, + KernelFunction::makeFromUnboxedLambda(std::forward(lambda)), + impl::CppSignature::make(), + // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor + detail::inferFunctionSchemaFromFunctor>>() + )); + } + +private: + void checkSchemaAndRegisterOp_(Options&& config); + + static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options); + void checkNoDuplicateKernels_(const Options& options); + void registerOp_(Options&& options); + + std::vector registrars_; +}; + +} // namespace c10 + +namespace torch { + // Old-style API + using RegisterOperators = c10::RegisterOperators; +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/operator_name.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/operator_name.h new file mode 100644 index 0000000000000000000000000000000000000000..22e1f427b632659aa6a523b9bcb2723a104414aa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/operator_name.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// TODO: consider storing namespace separately too +struct OperatorName final { + std::string name; + std::string overload_name; + OperatorName(std::string name, std::string overload_name) + : name(std::move(name)), overload_name(std::move(overload_name)) {} + + // TODO: These two functions below are slow! Fix internal data structures so + // I don't have to manually reconstruct the namespaces! + + // Return the namespace of this OperatorName, if it exists. 
The + // returned string_view is only live as long as the OperatorName + // exists and name is not mutated + std::optional getNamespace() const { + auto pos = name.find("::"); + if (pos == std::string::npos) { + return std::nullopt; + } else { + return std::string_view(name.data(), pos); + } + } + + // Returns true if we successfully set the namespace + bool setNamespaceIfNotSet(const char* ns) { + if (!getNamespace().has_value()) { + const auto ns_len = strlen(ns); + const auto old_name_size = name.size(); + name.resize(ns_len + 2 + old_name_size); + // Shift current value of name to the end of the new space. + name.replace( + name.size() - old_name_size, old_name_size, name, 0, old_name_size); + name.replace(0, ns_len, ns, ns_len); + name[ns_len] = ':'; + name[ns_len + 1] = ':'; + return true; + } else { + return false; + } + } +}; + +// Non-owning view of an OperatorName. Unlike OperatorName, most of +// its functions are constexpr, so it can be used for compile time +// computations +struct OperatorNameView final { + std::string_view name; + std::string_view overload_name; + constexpr OperatorNameView( + std::string_view name, + std::string_view overload_name) + : name(name), overload_name(overload_name) {} + // Parses strings like "foo.overload" and also "foo" + constexpr static OperatorNameView parse(std::string_view full_name) { + auto i = full_name.find('.'); + if (i == std::string_view::npos) { + return OperatorNameView(full_name, std::string_view()); + } else { + return OperatorNameView(full_name.substr(0, i), full_name.substr(i + 1)); + } + } +}; + +inline bool operator==(const OperatorName& lhs, const OperatorName& rhs) { + return lhs.name == rhs.name && lhs.overload_name == rhs.overload_name; +} + +inline bool operator!=(const OperatorName& lhs, const OperatorName& rhs) { + return !operator==(lhs, rhs); +} + +TORCH_API std::string toString(const OperatorName& opName); +TORCH_API std::ostream& operator<<(std::ostream&, const OperatorName&); + +} // 
namespace c10 + +namespace std { +template <> +struct hash<::c10::OperatorName> { + size_t operator()(const ::c10::OperatorName& x) const { + return std::hash()(x.name) ^ + (~std::hash()(x.overload_name)); + } +}; +} // namespace std diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/qualified_name.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/qualified_name.h new file mode 100644 index 0000000000000000000000000000000000000000..22fd8b8b857d7c1b60fbd4c5033a7765e281cab6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/qualified_name.h @@ -0,0 +1,161 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +// Represents a name of the form "foo.bar.baz" +struct QualifiedName { + QualifiedName() = default; + + // `name` can be a dotted string, like "foo.bar.baz", or just a bare name. + /* implicit */ QualifiedName(const std::string& name) { + TORCH_CHECK(!name.empty()); + // split the string into its atoms. + size_t startSearchFrom = 0; + size_t pos = name.find(delimiter_, startSearchFrom); + + while (pos != std::string::npos) { + auto atom = name.substr(startSearchFrom, pos - startSearchFrom); + TORCH_INTERNAL_ASSERT( + !atom.empty(), "Invalid name for qualified name: '", name, "'"); + atoms_.push_back(std::move(atom)); + startSearchFrom = pos + 1; + pos = name.find(delimiter_, startSearchFrom); + } + + auto finalAtom = name.substr(startSearchFrom); + TORCH_INTERNAL_ASSERT( + !finalAtom.empty(), "Invalid name for qualified name: '", name, "'"); + atoms_.emplace_back(std::move(finalAtom)); + + cacheAccessors(); + } + + explicit QualifiedName(std::vector atoms) : atoms_(std::move(atoms)) { + for (const auto& atom : atoms_) { + TORCH_CHECK(!atom.empty(), "Atom cannot be empty"); + TORCH_CHECK( + atom.find(delimiter_) == std::string::npos, + "Delimiter not allowed in atom"); + } + + cacheAccessors(); + } + // Unnecessary copy. 
Ideally we'd use something like std::string_view. + /* implicit */ QualifiedName(const char* name) + : QualifiedName(std::string(name)) {} + + // `name` must be a bare name (no dots!) + explicit QualifiedName(const QualifiedName& prefix, std::string name) { + TORCH_INTERNAL_ASSERT(!name.empty()); + TORCH_INTERNAL_ASSERT(name.find(delimiter_) == std::string::npos); + atoms_.insert(atoms_.begin(), prefix.atoms_.begin(), prefix.atoms_.end()); + atoms_.push_back(std::move(name)); + + cacheAccessors(); + } + + // Is `this` a prefix of `other`? + // For example, "foo.bar" is a prefix of "foo.bar.baz" + bool isPrefixOf(const QualifiedName& other) const { + const auto& thisAtoms = atoms_; + const auto& otherAtoms = other.atoms_; + + if (thisAtoms.size() > otherAtoms.size()) { + // Can't be a prefix if it's bigger + return false; + } + for (const auto i : c10::irange(thisAtoms.size())) { + if (thisAtoms[i] != otherAtoms[i]) { + return false; + } + } + return true; + } + + // The fully qualified name, like "foo.bar.baz" + const std::string& qualifiedName() const { + return qualifiedName_; + } + + // The leading qualifier, like "foo.bar" + const std::string& prefix() const { + return prefix_; + } + + // The base name, like "baz" + const std::string& name() const { + return name_; + } + + const std::vector& atoms() const { + return atoms_; + } + + bool operator==(const QualifiedName& other) const { + return this->qualifiedName_ == other.qualifiedName_; + } + + bool operator!=(const QualifiedName& other) const { + return !(*this == other); + } + + private: + static constexpr char delimiter_ = '.'; + + // Helper for cacheAccessors() below. 
+ template + std::string join(char delimiter, const T& v) { + std::string out; + size_t reserve = 0; + for (const auto& e : v) { + reserve += e.size() + 1; + } + out.reserve(reserve); + for (const auto i : c10::irange(v.size())) { + if (i != 0) { + out.push_back(delimiter); + } + out.append(v[i]); + } + return out; + } + + void cacheAccessors() { + qualifiedName_ = join(delimiter_, atoms_); + if (atoms_.size() > 1) { + ArrayRef view(atoms_); + const auto prefixView = view.slice(0, view.size() - 1); + prefix_ = join(delimiter_, prefixView); + } + + if (!atoms_.empty()) { + name_ = atoms_.back(); + } + } + + // The actual list of names, like "{foo, bar, baz}" + std::vector atoms_; + + /* + * Cached accessors, derived from `atoms_`. + */ + std::string qualifiedName_; + std::string prefix_; + std::string name_; +}; +} // namespace c10 + +namespace std { +template <> +struct hash { + size_t operator()(const c10::QualifiedName& n) const noexcept { + return std::hash()(n.qualifiedName()); + } +}; +} // namespace std diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/rref_interface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/rref_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..70273f168d936136556f7556cbdf398d6d792f58 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/rref_interface.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +namespace c10 { + +struct Type; +using worker_id_t = int16_t; + +// This abstract class contains only user-facing APIs, and will be shared +// between jit and distributed to implement TorchScript support. +class C10_EXPORT RRefInterface : public c10::intrusive_ptr_target { + public: + RRefInterface() = default; + // RRef is made NOT copyable NOT movable to prevent messing up reference + // counting. 
+ RRefInterface(const RRefInterface& other) = delete; + RRefInterface(RRefInterface&& other) = delete; + RRefInterface& operator=(const RRefInterface& other) = delete; + RRefInterface& operator=(RRefInterface&& other) = delete; + + ~RRefInterface() override = default; + + // returns the worker id of the owner + virtual worker_id_t owner() const = 0; + + // returns the worker name of the owner + virtual std::string ownerName() const = 0; + + // Returns true if this is the ``OwnerRRef`` + virtual bool isOwner() const = 0; + + // Returns true if this is an ``OwnerRRef`` or if this ``UserRRef`` has been + // confirmed by its owner. + virtual bool confirmedByOwner() const = 0; + + virtual const TypePtr type() const = 0; +}; + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/stack.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/stack.h new file mode 100644 index 0000000000000000000000000000000000000000..ca2925f3cac20786353c28ad1fb2d7bbdb0101d1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/stack.h @@ -0,0 +1,204 @@ +#pragma once + +#include + +#include +#include +#include + +// TODO move this to c10 namespace + + +namespace torch::jit { + +using c10::IValue; +using Stack = std::vector; + +class Operation { + template + using accepts = std::is_constructible, F&&>; + + public: + template ::value, int> = 0> + C10_DEPRECATED_MESSAGE("Please use void(Stack&) to register operator instead.") + Operation(F&& raw): op_([raw = std::forward(raw)](Stack& stack) { + raw(&stack); + }) {} + + template ::value && + !std::is_same_v, Operation>, int> = 0> + Operation(F&& op): op_(std::forward(op)) {} + + Operation(std::nullptr_t) noexcept {} + + explicit operator bool() const noexcept { + return op_ ? 
true : false; + } + + void operator()(Stack& stack) { + op_(stack); + } + + template + T* target() noexcept { + return op_.target(); + } + + private: + std::function op_; +}; + +// An operation with N inputs and M outputs pops the last N inputs off +// the stack and pushes its M inputs onto the stack +// before: I0, I1, ... IN <- stack.back() +// after: O0, O1, ... OM +// operations are defined this way so that ownership of inputs can be +// transferred to the operation and it can incrementally drop ownership of +// tensors when they become unneeded. For large operations, like 'run an entire +// subgraph', this functionality is very important for minimizing gpu memory +// usage return value is the relative 'offset' to jump to for the next +// operation: +// pc += 1 + offset +// so a return value of 0 goes to the next instruction + +// treat the last N elements of the stack as a list, looking up +// element i +inline IValue& peek(Stack& stack, size_t i, size_t N) { + // NOLINTNEXTLINE(*-narrowing-conversions) + return *(stack.end() - N + i); +} +inline IValue& peek(Stack* stack, size_t i, size_t N) { + return peek(*stack, i, N); +} +inline const IValue& peek(const Stack& stack, size_t i, size_t N) { + // NOLINTNEXTLINE(*-narrowing-conversions) + return *(stack.end() - N + i); +} +inline const IValue& peek(const Stack* stack, size_t i, size_t N) { + return peek(*stack, i, N); +} +// treat the last N elements of the stack as a list, looking up the +// slice starting at index i and having length len +inline at::ArrayRef peekSlice( + const Stack& stack, + size_t i, + size_t len, + size_t N) { + return at::ArrayRef(stack).slice(stack.size() - N + i, len); +} +inline at::ArrayRef last(const Stack& stack, size_t N) { + return peekSlice(stack, 0, N, N); +} +inline at::ArrayRef last(const Stack* stack, size_t N) { + return last(*stack, N); +} +inline void drop(Stack& stack, size_t n) { + // NOLINTNEXTLINE(*-narrowing-conversions) + stack.erase(stack.end() - n, stack.end()); 
+} +inline void drop(Stack* stack, size_t n) { + drop(*stack, n); +} +inline IValue pop(Stack& stack) { + TORCH_CHECK(!stack.empty(), "pop() called on empty stack"); + auto r = std::move(stack.back()); + stack.pop_back(); + return r; +} +inline IValue pop(Stack* stack) { + return pop(*stack); +} +inline std::vector pop(Stack& stack, size_t n) { + std::vector result; + result.reserve(n); + for (const auto i : c10::irange(n)) { + result.push_back(std::move(peek(stack, i, n))); + } + drop(stack, n); + return result; +} + +// variadic pop: +// int64_t a; at::Tensor b; +// pop(stack, a, b); +// equivalent to: +// b = pop(stack).toTensor(); +// a = pop(stack).toInt(); +template +inline void pop(Stack& stack, Types&... args) { + size_t i = 0; + constexpr size_t N = sizeof...(args); + (void)std::initializer_list{ + (args = std::move(peek(stack, i++, N)).template to(), 0)...}; + drop(stack, N); +} +template +inline void pop(Stack* stack, Types&... args) { + pop(*stack, args...); +} +template +inline void push_one(Stack& stack, Type&& arg) { + stack.emplace_back(std::forward(arg)); +} + +inline void push_one(Stack& stack, c10::TensorOptions options) { + stack.emplace_back(c10::typeMetaToScalarType(options.dtype())); + stack.emplace_back(options.layout()); + stack.emplace_back(options.device()); + stack.emplace_back(options.pinned_memory()); +} + +template +inline void push(Stack& stack, Types&&... args) { + (void)std::initializer_list{(push_one(stack, std::forward(args)), 0)...}; +} +template +inline void push(Stack* stack, Types&&... args) { + return push(*stack, std::forward(args)...); +} +template +inline void push_list_elements(Stack& stack, const c10::List& elements) { + for (T elem : elements) { + stack.push_back(std::move(elem)); + } +} + +// The packer here is carefully written not to make any unnecessary +// copies. 
+ +// pack takes the return values of aten functions pushes them onto the stack +template +inline void pack(Stack& stack, T&& v) { + stack.emplace_back(std::forward(v)); +} +template +inline void pack(Stack* stack, T&& v) { + pack(*stack, std::forward(v)); +} + +template +struct TuplePacker { + // NB: *Not* a universal reference. + static void execute(Stack& stack, std::tuple&& t) { + // NB: The move here does not "destroy" the entire tuple, that is + // not what std::move does; only the particular tuple index + // processed here gets stolen. + pack(stack, std::get(std::move(t))); + TuplePacker::execute(stack, std::move(t)); + } +}; + +template +struct TuplePacker<0, Args...> { + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + static void execute(Stack& /*stack*/, std::tuple&& /*t*/){} +}; + +template +inline void pack(Stack& stack, std::tuple&& t) { + TuplePacker::execute(stack, std::move(t)); +} + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/symbol.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/symbol.h new file mode 100644 index 0000000000000000000000000000000000000000..f94cbf6d620ce062e64604bd8bcff366868e358b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/symbol.h @@ -0,0 +1,147 @@ +#pragma once +#include +#include +#include // For std::hash +#include + + +namespace c10 { + +// 'prim' symbols are synthetic operators that occur only in the IR +// and don't have corresponding implementations in ATen. + +// 'onnx' symbols correspond to ONNX operators. Their semantics +// are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md +// The particular version we are targeting is specified by '_onnx_opset_version' +// in torch.onnx.symbolic_helper +// +// In general, most ONNX operators won't get an entry here, because they +// are handled from the Python end. 
However, you may occasionally need +// to intern an ONNX symbol here so that you can conveniently write an +// optimization on ONNX operations. + +// 'attr' symbols are attribute keys. They are shared between both ONNX and ATen +// operators (you disambiguate their meaning by looking at the operator itself). +// In general, you only need to define attribute keys that are used by +// onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS. + +// Note [Symbol allocation] +// ~~~~~~~~~~~~~~~~~~~~~~~~ +// +// 1. Symbol namespace is split up into namespaces. +// +// 2. The intended access pattern for built-in symbols is onnx::MatMul +// in the c10 namespace (this is a Symbol). +// + +// Built-in constant definition strategy: +// - Enum is the most convenient way to generate a contiguous sequence +// of numbers for an identifier. +// - However, an enum gives you a fresh type. We want onnx::MatMul to +// be type Symbol, not some random enum type! +// - Therefore, after using enums to generate the sequence of integers, +// we then declare constexpr Symbols to get everything the actual Symbol +// type we want. Symbols must be constexpr to be valid to be "case"ed on. + +using unique_t = uint32_t; + +const std::string& domain_prefix(); + +// A Symbol is like an interned string, but with a little extra +// structure; it is namespaced via SymbolNamespace and the resulting +// intern pointers support efficient namespace testing. +struct TORCH_API Symbol { + explicit constexpr Symbol() : value(0) {} + explicit constexpr Symbol(unique_t uniq) + : value(uniq) {} + + // Get a Symbol for a qualified string like "attr::bar" + static Symbol fromQualString(const std::string & s); + + // Get a Symbol from a domain and an unqualified string like "org.pytorch.attr" and "bar" + static Symbol fromDomainAndUnqualString(const std::string & d, const std::string & s); + + // Constructors for our various namespaced strings. 
This will construct + // the appropriate namespaced string, e.g., "attr::foo" for the + // argument "foo", and then attempt to intern it. DO NOT USE THIS + // with a string literal; attr::foo should be available in that case + // (and if it's not, you should add it to the built-ins list above.) + static Symbol attr(const std::string & s); + static Symbol aten(const std::string & s); + static Symbol cuda(const std::string & s); + static Symbol onnx(const std::string & s); + static Symbol prim(const std::string & s); + static Symbol user(const std::string & s); + static Symbol caffe2(const std::string & s); + static Symbol dimname(const std::string & s); + // TODO: eliminate me + static Symbol scope(const std::string & s); + + bool is_attr() const; + bool is_aten() const; + bool is_cuda() const; + bool is_prim() const; + bool is_prims() const; + bool is_nvprims() const; + bool is_onnx() const; + bool is_user() const; + bool is_caffe2() const; + bool is_dimname() const; + + // So we can switch on this + constexpr operator unique_t() const { + return value; + } + + Symbol ns() const; + + // Give a string corresponding to the unqualified version of this name, e.g., + // "mm". Use this in a context where the intended namespace of the string is + // obvious; this is a *lossy* conversion. + const char * toUnqualString() const; + + // Give a string corresponding to the qualified version of this name, + // e.g., "aten::mm". This string format is made available to Python bindings + // (so we know how to parse it.) + const char * toQualString() const; + + // This describes a symbol in a case where humans read it. At the moment it's + // the same as toQualString. This has to be a const char* returned because + // a lot of printf style macros use it. + const char * toDisplayString() const; + + // Give a string corresponding to the domain name for the symbol, + // e.g., "org.pytorch.aten". 
+ std::string domainString() const; + +private: + + explicit Symbol(Symbol ns, const std::string & s); + unique_t value; +}; + +static inline bool operator==(Symbol lhs, Symbol rhs) { + return static_cast(lhs) == static_cast(rhs); +} + +inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); } +inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); } +inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); } +inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); } +inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); } +inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); } +inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); } +inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualString("_caffe2::" + s); } +inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); } + +} // namespace c10 + +// make symbol behave like an integer in hash tables +namespace std { +template <> +struct hash { + size_t operator()(c10::Symbol s) const { + return std::hash()(static_cast(s)); + } +}; +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/type_factory.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/type_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..c2af2a1cb8451a344a68a40fd8cde203183f4050 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/type_factory.h @@ -0,0 +1,108 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace c10 { + +template +struct TORCH_API TypeFactoryBase {}; + +template <> +struct TORCH_API TypeFactoryBase { + template + static c10::DynamicTypePtr create(TypePtr ty, 
Args&&... args) { + return std::make_shared( + c10::DynamicTypeTrait::tagValue(), + c10::DynamicType::Arguments(c10::ArrayRef( + {std::move(ty), std::forward(args)...}))); + } + template + static c10::DynamicTypePtr create(const std::vector& types) { + return std::make_shared( + c10::DynamicTypeTrait::tagValue(), + c10::DynamicType::Arguments(types)); + } + static c10::DynamicTypePtr createNamedTuple( + const std::string& name, + const std::vector& fields, + const std::vector& types) { + return std::make_shared( + c10::DynamicType::Tag::Tuple, + name, + c10::DynamicType::Arguments(fields, types)); + } + template + C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) { + return std::make_shared( + c10::DynamicTypeTrait::tagValue(), + name, + c10::DynamicType::Arguments{}); + } + template + C10_ERASE static decltype(auto) get() { + return DynamicTypeTrait::getBaseType(); + } + static const std::unordered_map& basePythonTypes(); +}; + +using DynamicTypeFactory = TypeFactoryBase; + +// Helper functions for constructing DynamicTypes inline. +template < + typename T, + std::enable_if_t::isBaseType, int> = 0> +C10_ERASE DynamicTypePtr dynT() { + return DynamicTypeFactory::get(); +} + +template < + typename T, + typename... Args, + std::enable_if_t::isBaseType, int> = 0> +C10_ERASE DynamicTypePtr dynT(Args&&... args) { + return DynamicTypeFactory::create(std::forward(args)...); +} + +template <> +struct TORCH_API TypeFactoryBase { + template + static c10::TypePtr create(TypePtr ty, Args&&... 
args) { + return T::create(std::move(ty), std::forward(args)...); + } + template + static c10::TypePtr create(std::vector types) { + return T::create(std::move(types)); + } + static c10::TypePtr createNamedTuple( + const std::string& name, + const std::vector& fields, + const std::vector& types); + template + C10_ERASE static c10::TypePtr createNamed(const std::string& name) { + return T::create(name); + } + static const std::unordered_map& basePythonTypes(); + template + C10_ERASE static c10::TypePtr get() { + return T::get(); + } +}; + +using DefaultTypeFactory = TypeFactoryBase; + +using PlatformType = +#ifdef C10_MOBILE + c10::DynamicType +#else + c10::Type +#endif + ; + +using TypeFactory = TypeFactoryBase; + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/type_ptr.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/type_ptr.h new file mode 100644 index 0000000000000000000000000000000000000000..0859e04c7d2d834155af005b1f574a5182eb51fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/type_ptr.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10 { + +// Compatibility wrapper around a raw pointer so that existing code +// written to deal with a shared_ptr can keep working. +template +class SingletonTypePtr { + public: + /* implicit */ SingletonTypePtr(T* p) : repr_(p) {} + + // We need this to satisfy Pybind11, but it shouldn't be hit. 
+ explicit SingletonTypePtr(std::shared_ptr) { TORCH_CHECK(false); } + + using element_type = typename std::shared_ptr::element_type; + + template , void>, bool> = true> + T& operator*() const { + return *repr_; + } + + T* get() const { + return repr_; + } + + T* operator->() const { + return repr_; + } + + operator bool() const { + return repr_ != nullptr; + } + + private: + T* repr_{nullptr}; +}; + +template +bool operator==(SingletonTypePtr lhs, SingletonTypePtr rhs) { + return (void*)lhs.get() == (void*)rhs.get(); +} + +template +bool operator!=(SingletonTypePtr lhs, SingletonTypePtr rhs) { + return !(lhs == rhs); +} + +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/core/typeid.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/typeid.h new file mode 100644 index 0000000000000000000000000000000000000000..5967c0a1659aadd9225d3f16f13275879c8bdbc9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/core/typeid.h @@ -0,0 +1 @@ +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h new file mode 100644 index 0000000000000000000000000000000000000000..388b3170d5b55a8c4bdd3af4ff982397fb323cb6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h new file mode 100644 index 0000000000000000000000000000000000000000..112121b297055fddc76a054456ada3e992034d92 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h @@ -0,0 +1,475 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include + +namespace at { +namespace detail { +// We prefer to convert through float for reduced-precision floating +// point types if we have a Vectorized specialization for float and we +// don't have one for the actual type in question. +template +struct should_prefer_converting_through_float + : std::bool_constant< + is_reduced_floating_point_v && + vec::is_vec_specialized_for_v && + !vec::is_vec_specialized_for_v> {}; + +template +constexpr auto should_prefer_converting_through_float_v = + should_prefer_converting_through_float::value; +} // namespace detail + +namespace vec { +// slow path +template +inline scalar_t vec_reduce_all( + const Op& vec_fun, + vec::Vectorized acc_vec, + int64_t size) { + using Vec = vec::Vectorized; + scalar_t acc_arr[Vec::size()]; + acc_vec.store(acc_arr); + for (const auto i : c10::irange(1, size)) { + std::array acc_arr_next = {0}; + acc_arr_next[0] = acc_arr[i]; + Vec acc_vec_next = Vec::loadu(acc_arr_next.data()); + acc_vec = vec_fun(acc_vec, acc_vec_next); + } + acc_vec.store(acc_arr); + return acc_arr[0]; +} + +template +struct VecReduceAllSIMD { + static inline scalar_t apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + return vec_reduce_all(vec_fun, acc_vec, Vectorized::size()); + } +}; + +#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \ + !defined(C10_MOBILE) +#if defined(CPU_CAPABILITY_AVX2) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 128-bit shuffle + Vec v1 = _mm256_permute2f128_ps(v, v, 0x1); + v = vec_fun(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0x4E); + v = vec_fun(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0xB1); + v = vec_fun(v, v1); + return _mm256_cvtss_f32(v); + } +}; +#endif // defined(CPU_CAPABILITY_AVX2) +#if defined(CPU_CAPABILITY_AVX512) +template +struct 
VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 256-bit shuffle + Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E); + v = vec_fun(v, v1); + // 128-bit shuffle + v1 = _mm512_shuffle_f32x4(v, v, 0xB1); + v = vec_fun(v, v1); + // 64-bit shuffle + v1 = _mm512_shuffle_ps(v, v, 0x4E); + v = vec_fun(v, v1); + // 32-bit shuffle + v1 = _mm512_shuffle_ps(v, v, 0xB1); + v = vec_fun(v, v1); + return _mm512_cvtss_f32(v); + } +}; +#endif // defined(CPU_CAPABILITY_AVX512) +#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && + // !defined(C10_MOBILE) + +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + + // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, + // a4+a8, a1+a5, a2+a6, -, -, -, -] + float32x4_t v1_1 = vextq_f32(v, v, 2); + Vec v1 = v1_1; + // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] + v = vec_fun(v, v1); + + // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, + // -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, + // -] + v1_1 = vrev64q_f32(v); + v1 = v1_1; + // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, + // a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] + v = vec_fun(v, v1); + + return v[0]; + } +}; + +template <> +struct VecReduceAllSIMD>> { + static inline float apply( + const std::plus>& vec_fun, + const Vectorized& acc_vec) { + return vaddvq_f32(acc_vec); + } +}; +#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) + // && !defined(CPU_CAPABILITY_SVE) + +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + defined(CPU_CAPABILITY_SVE256) +template +struct 
VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 128-bit shuffle + svuint32_t ind = svdupq_n_u32(4, 5, 6, 7); + Vec v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 64-bit shuffle + ind = svdupq_n_u32(2, 3, 0, 1); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 32-bit shuffle + ind = svdupq_n_u32(1, 0, 2, 3); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + return svlasta(svpfalse(), v); + } +}; +#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) + // && defined(CPU_CAPABILITY_SVE256) + +template +inline scalar_t vec_reduce_all( + const Op& vec_fun, + const Vectorized& acc_vec) { + return VecReduceAllSIMD::apply(vec_fun, acc_vec); +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline scalar_t reduce_all( + const Op& vec_fun, + const scalar_t* data, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) + return vec_reduce_all(vec_fun, Vec::loadu(data, size), size); + int64_t d = Vec::size(); + Vec acc_vec = Vec::loadu(data); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + acc_vec = vec_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(vec_fun, acc_vec); +} + +// similar to reduce_all, but reduces into two outputs +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) { + auto loaded_data = Vec::loadu(data, size); + return std::pair( + vec_reduce_all(vec_fun1, loaded_data, size), + vec_reduce_all(vec_fun2, loaded_data, size)); + } + int64_t d = 
Vec::size(); + Vec acc_vec1 = Vec::loadu(data); + Vec acc_vec2 = Vec::loadu(data); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + acc_vec1 = vec_fun1(acc_vec1, data_vec); + acc_vec2 = vec_fun2(acc_vec2, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d); + acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d); + } + return std::pair( + vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2)); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline scalar_t map_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) + return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size); + int64_t d = Vec::size(); + Vec acc_vec = map_fun(Vec::loadu(data)); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + data_vec = map_fun(data_vec); + acc_vec = red_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + data_vec = map_fun(data_vec); + acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(red_fun, acc_vec); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline scalar_t map2_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) { + Vec data_vec = Vec::loadu(data, size); + Vec data2_vec = Vec::loadu(data2, size); + data_vec = map_fun(data_vec, data2_vec); + return vec_reduce_all(red_fun, data_vec, size); + } + int64_t d = Vec::size(); + Vec acc_vec = 
map_fun(Vec::loadu(data), Vec::loadu(data2)); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + Vec data2_vec = Vec::loadu(data2 + d); + data_vec = map_fun(data_vec, data2_vec); + acc_vec = red_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + Vec data2_vec = Vec::loadu(data2 + d, size - d); + data_vec = map_fun(data_vec, data2_vec); + acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(red_fun, acc_vec); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline scalar_t map3_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + const scalar_t* data3, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) { + Vec data_vec = Vec::loadu(data, size); + Vec data2_vec = Vec::loadu(data2, size); + Vec data3_vec = Vec::loadu(data3, size); + data_vec = map_fun(data_vec, data2_vec, data3_vec); + return vec_reduce_all(red_fun, data_vec, size); + } + + int64_t d = Vec::size(); + Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3)); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + Vec data2_vec = Vec::loadu(data2 + d); + Vec data3_vec = Vec::loadu(data3 + d); + data_vec = map_fun(data_vec, data2_vec, data3_vec); + acc_vec = red_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + Vec data2_vec = Vec::loadu(data2 + d, size - d); + Vec data3_vec = Vec::loadu(data3 + d, size - d); + data_vec = map_fun(data_vec, data2_vec, data3_vec); + acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(red_fun, acc_vec); +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + 
!detail::should_prefer_converting_through_float_v && + std::is_invocable_v>, + int> = 0> +inline void map( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec output_vec = vec_fun(Vec::loadu(input_data + d)); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d)); + output_vec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized>, + int> = 0> +inline void map2( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + const scalar_t* input_data2, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(input_data + d); + Vec data_vec2 = Vec::loadu(input_data2 + d); + Vec output_vec = vec_fun(data_vec, data_vec2); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(input_data + d, size - d); + Vec data_vec2 = Vec::loadu(input_data2 + d, size - d); + Vec output_vec = vec_fun(data_vec, data_vec2); + output_vec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>, + int> = 0> +inline void map3( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec1 = Vec::loadu(input_data1 + 
d); + Vec data_vec2 = Vec::loadu(input_data2 + d); + Vec data_vec3 = Vec::loadu(input_data3 + d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec data_vec1 = Vec::loadu(input_data1 + d, size - d); + Vec data_vec2 = Vec::loadu(input_data2 + d, size - d); + Vec data_vec3 = Vec::loadu(input_data3 + d, size - d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3); + output_vec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>, + int> = 0> +inline void map4( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + const scalar_t* input_data4, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec1 = Vec::loadu(input_data1 + d); + Vec data_vec2 = Vec::loadu(input_data2 + d); + Vec data_vec3 = Vec::loadu(input_data3 + d); + Vec data_vec4 = Vec::loadu(input_data4 + d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec data_vec1 = Vec::loadu(input_data1 + d, size - d); + Vec data_vec2 = Vec::loadu(input_data2 + d, size - d); + Vec data_vec3 = Vec::loadu(input_data3 + d, size - d); + Vec data_vec4 = Vec::loadu(input_data4 + d, size - d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4); + output_vec.store(output_data + d, size - d); + } +} + +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h new file mode 100644 
index 0000000000000000000000000000000000000000..5545a65cc733d856b21ae35f704f5540f9c0b8ea --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h @@ -0,0 +1,647 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include + +namespace at::vec { +// BFloat16 specification +template +struct VecScalarType { + using type = scalar_t; +}; +template <> +struct VecScalarType { + using type = float; +}; +template <> +struct VecScalarType { + using type = float; +}; + +// This is different from at::acc_type since we only need to specialize BFloat16 +template +using vec_scalar_t = typename VecScalarType::type; + +// Vector conversion between float and bfloat16/half +template <> +inline std::tuple, Vectorized> convert_to_float< + BFloat16>(const Vectorized& a) { + return convert_bfloat16_float(a); +} + +template <> +inline std::tuple, Vectorized> convert_to_float( + const Vectorized& a) { + return convert_half_float(a); +} + +template <> +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { + return convert_float_bfloat16(a, b); +} + +template <> +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { + return convert_float_half(a, b); +} + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float( + const scalar_t* data, + Vectorized& out1, + Vectorized& out2); + +template <> +inline void load_to_float( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_bf16(data, out1, out2); +} + +template <> +inline void load_to_float( + const Half* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_fp16(data, out1, out2); +} + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float(const scalar_t* data, Vectorized& out); + +template <> +inline void load_to_float( + const BFloat16* 
data, + Vectorized& out) { + load_fp32_from_bf16(data, out); +} + +template <> +inline void load_to_float(const Half* data, Vectorized& out) { + load_fp32_from_fp16(data, out); +} + +// Note that we already have specialized member of Vectorized for +// BFloat16 so the following functions would run smoothly: +// using Vec = Vectorized; +// Vec one = Vec(BFloat16(1)); +// vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); +// +// Then why we still need to specialize "functional"? +// If we do specialization at Vectorized<> level, the above example would need +// 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and +// "/". If we do specialization at vec::map<>() level, we have only 1 pair of +// conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16 +// vector only. +// +// The following BFloat16 functionality will only do data type conversion for +// input and output vector (reduce functionality will only convert the final +// scalar back to bf16). Compared to Vectorized<> specialization, +// 1. better performance since we have less data type conversion; +// 2. less rounding error since immediate results are kept in fp32; +// 3. accumulation done on data type of fp32. 
+// +// If you plan to extend this file, please ensure adding unit tests at +// aten/src/ATen/test/vec_test_all_types.cpp +// +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size > fVec::size()) { + data_fvec0 = fVec::set( + data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(vec_fun, data_fvec0, fVec::size()); + } else { + return vec_reduce_all(vec_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); + acc_fvec1 = vec_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size - d > fVec::size()) { + acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + acc_fvec0 = + fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(vec_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, 
size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size > fVec::size()) { + fVec acc1_fvec = fVec::set( + data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); + fVec acc2_fvec = fVec::set( + data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); + return std::pair( + vec_reduce_all(vec_fun1, acc1_fvec, fVec::size()), + vec_reduce_all(vec_fun2, acc2_fvec, fVec::size())); + } else { + return std::pair( + vec_reduce_all(vec_fun1, data_fvec0, size), + vec_reduce_all(vec_fun2, data_fvec0, size)); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc1_fvec0, acc1_fvec1] = convert_to_float(acc_bvec); + auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc_bvec); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); + acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1); + acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); + acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size - d > fVec::size()) { + acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); + acc1_fvec1 = fVec::set( + acc1_fvec1, + vec_fun1(acc1_fvec1, data_fvec1), + size - d - fVec::size()); + acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); + acc2_fvec1 = fVec::set( + acc2_fvec1, + vec_fun2(acc2_fvec1, data_fvec1), + size - d - fVec::size()); + } else { + acc1_fvec0 = + fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); + acc2_fvec0 = + fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); + } + } + acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1); + acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1); + return std::pair( + vec_reduce_all(vec_fun1, acc1_fvec0), + vec_reduce_all(vec_fun2, acc2_fvec0)); +} + +template < + 
typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline float map_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size > fVec::size()) { + data_fvec0 = map_fun(data_fvec0); + data_fvec1 = map_fun(data_fvec1); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(red_fun, data_fvec0, fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0); + return vec_reduce_all(red_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + acc_fvec0 = map_fun(acc_fvec0); + acc_fvec1 = map_fun(acc_fvec1); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + data_fvec0 = map_fun(data_fvec0); + data_fvec1 = map_fun(data_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = red_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size - d > fVec::size()) { + data_fvec0 = map_fun(data_fvec0); + data_fvec1 = map_fun(data_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(red_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, 
+ typename std::enable_if_t, int> = 0> +inline float map2_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2, size); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + if (size > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(red_fun, data_fvec0, fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + return vec_reduce_all(red_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + bVec acc2_bvec = bVec::loadu(data2); + auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc2_bvec); + acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0); + acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = red_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + if (size - d > fVec::size()) { + data_fvec0 = 
map_fun(data_fvec0, data2_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(red_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline float map3_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + const scalar_t* data3, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2, size); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(data3, size); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + if (size > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(red_fun, data_fvec0, fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + return vec_reduce_all(red_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + bVec acc2_bvec = bVec::loadu(data2); + auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc2_bvec); + bVec acc3_bvec = bVec::loadu(data3); + auto [acc3_fvec0, acc3_fvec1] = convert_to_float(acc3_bvec); + acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, 
acc3_fvec0); + acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(data3 + d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = red_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(data3 + d, size - d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + if (size - d > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(red_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v>), + int> = 0> +inline void map( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - 
(size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(input_data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(input_data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline void map( + const Op& vec_fun, + scalar_t* output_data, + const float* input_data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + fVec data_fvec0 = fVec::loadu(input_data + d); + fVec data_fvec1 = fVec::loadu(input_data + d + fVec::size()); + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + fVec data_fvec0, data_fvec1; + if (size - d > fVec::size()) { + data_fvec0 = fVec::loadu(input_data + d); + data_fvec1 = + fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); + } else { + // choose to align with behaviour of bVec::loadu(ptr, size), + // which leaves data_fvec1 uninitialized + data_fvec0 = fVec::loadu(input_data + d, size - d); + } + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + 
typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized>), + int> = 0> +inline void map2( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + const scalar_t* input_data2, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(input_data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(input_data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>), + int> = 0> +inline void map3( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data1_bvec 
= bVec::loadu(input_data1 + d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0); + fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data1_bvec = bVec::loadu(input_data1 + d, size - d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d, size - d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0); + fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>), + int> = 0> +inline void map4( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + const scalar_t* input_data4, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data1_bvec = bVec::loadu(input_data1 + d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = 
bVec::loadu(input_data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + bVec data4_bvec = bVec::loadu(input_data4 + d); + auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data1_bvec = bVec::loadu(input_data1 + d, size - d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d, size - d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + bVec data4_bvec = bVec::loadu(input_data4 + d, size - d); + auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..f9086f7d3d0b9dda4147a32cea8aa604b53e0038 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h @@ -0,0 +1,55 @@ +#pragma once +#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) +/* GCC or clang-compatible compiler, targeting x86/x86-64 */ +#include 
+#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__)) +/* Clang-compatible compiler, targeting arm neon */ +#include +#if defined(__ARM_FEATURE_SVE) +/* CLANG-compatible compiler, targeting ARM with SVE */ +#include +#endif +#elif defined(_MSC_VER) +/* Microsoft C/C++-compatible compiler */ +#include +#if _MSC_VER <= 1900 +#define _mm256_extract_epi64(X, Y) \ + (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2)) +#define _mm256_extract_epi32(X, Y) \ + (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4)) +#define _mm256_extract_epi16(X, Y) \ + (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) +#define _mm256_extract_epi8(X, Y) \ + (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) +#endif +#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) +/* GCC-compatible compiler, targeting ARM with NEON */ +#include +#if defined(__ARM_FEATURE_SVE) +/* GCC-compatible compiler, targeting ARM with SVE */ +#include +#endif +#if defined(MISSING_ARM_VLD1) +#include +#elif defined(MISSING_ARM_VST1) +#include +#endif +#elif defined(__GNUC__) && defined(__IWMMXT__) +/* GCC-compatible compiler, targeting ARM with WMMX */ +#include +#elif defined(__s390x__) +// targets Z/architecture +// we will include vecintrin later +#elif (defined(__GNUC__) || defined(__xlC__)) && \ + (defined(__VEC__) || defined(__ALTIVEC__)) +/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ +#include +/* We need to undef those tokens defined by to avoid conflicts + with the C++ types. 
=> Can still use __bool/__vector */ +#undef bool +#undef vector +#undef pixel +#elif defined(__GNUC__) && defined(__SPE__) +/* GCC-compatible compiler, targeting PowerPC with SPE */ +#include +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/sve_helper.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/sve_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..f3786019064c1ab71e1d8901edbe86dad61997e2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/sve_helper.h @@ -0,0 +1,80 @@ +#pragma once + +#include + +#include + +#if defined(CPU_CAPABILITY_SVE) + +// Define the data type of VLS(vector-length specific). +typedef svbool_t vls_pred_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint8_t vls_int8_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint16_t vls_int16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint32_t vls_int32_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint64_t vls_int64_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint8_t vls_uint8_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint16_t vls_uint16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint32_t vls_uint32_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint64_t vls_uint64_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat16_t vls_float16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svbfloat16_t vls_bfloat16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat32_t vls_float32_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat64_t vls_float64_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); + +#define ptrue svptrue_b8() +#define ZERO_S8 svdup_n_s8(0) +#define ZERO_S16 
svdup_n_s16(0) +#define ZERO_S32 svdup_n_s32(0) +#define ZERO_S64 svdup_n_s64(0) +#define ZERO_U8 svdup_n_u8(0) +#define ZERO_U16 svdup_n_u16(0) +#define ZERO_U32 svdup_n_u32(0) +#define ZERO_U64 svdup_n_u64(0) +#define ZERO_F16 svdup_n_f16(0.f) +#define ZERO_F32 svdup_n_f32(0.f) +#define ZERO_F64 svdup_n_f64(0.0) +#define ONE_S8 svdup_n_s8(1) +#define ONE_S16 svdup_n_s16(1) +#define ONE_S32 svdup_n_s32(1) +#define ONE_S64 svdup_n_s64(1) +#define ONE_U8 svdup_n_u8(1) +#define ONE_U16 svdup_n_u16(1) +#define ONE_U32 svdup_n_u32(1) +#define ONE_U64 svdup_n_u64(1) +#define ONE_F16 svdup_n_f16(1.f) +#define ONE_BF16 svdup_n_bf16(1.f) +#define ONE_F32 svdup_n_f32(1.f) +#define ONE_F64 svdup_n_f64(1.0) +#define ALL_S8_TRUE_MASK svdup_n_s8(0xff) +#define ALL_S8_FALSE_MASK svdup_n_s8(0x0) +#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff) +#define ALL_S16_FALSE_MASK svdup_n_s16(0x0) +#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff) +#define ALL_S32_FALSE_MASK svdup_n_s32(0x0) +#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff) +#define ALL_S64_FALSE_MASK svdup_n_s64(0x0) +#define ALL_U8_TRUE_MASK svdup_n_u8(0x01) +#define ALL_U8_FALSE_MASK svdup_n_u8(0x00) +#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK) +#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK) +#define ALL_BF16_TRUE_MASK svreinterpret_bf16_s16(ALL_S16_TRUE_MASK) +#define ALL_BF16_FALSE_MASK svreinterpret_bf16_s16(ALL_S16_FALSE_MASK) +#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK) +#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK) +#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK) +#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK) + +#endif // defined(CPU_CAPABILITY_SVE) diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_bfloat16.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_bfloat16.h new file mode 100644 index 
0000000000000000000000000000000000000000..7f05c2ad166f60deed1464d4db3580e99df6a24e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_bfloat16.h @@ -0,0 +1,580 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + vls_bfloat16_t values; + + public: + using value_type = BFloat16; + using size_type = int; + + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(BFloat16); + } + + Vectorized() {} + Vectorized(svbfloat16_t v) : values(v) {} + Vectorized(int val); + Vectorized(BFloat16 val); + + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... 
vals) { + __at_align__ BFloat16 buffer[size()] = {vals...}; + values = svld1_bf16(ptrue, reinterpret_cast(buffer)); + } + + operator svbfloat16_t() const { + return values; + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s16(ptrue, svreinterpret_s16_bf16(mask_), ALL_S16_TRUE_MASK); + return svsel_bf16(mask, b, a); + } + template + static Vectorized arange( + BFloat16 base = 0.f, + step_t step = static_cast(1)) { + __at_align__ BFloat16 buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_bf16(ptrue, reinterpret_cast(buffer)); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_bf16(svwhilelt_b16(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_bf16(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b16(0ull, count); + return svld1_bf16(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + __at_align__ bfloat16_t tmp[size()]; + std::memset(tmp, 0, sizeof(tmp)); + if (count == size()) { + svst1_bf16(ptrue, reinterpret_cast(tmp), values); + } else { + svbool_t pg = svwhilelt_b16(0ull, count); + svst1_bf16(pg, reinterpret_cast(tmp), values); + } + std::memcpy( + reinterpret_cast(ptr), + reinterpret_cast(tmp), + count * sizeof(bfloat16_t)); + } + const BFloat16& operator[](int idx) const = delete; + BFloat16& operator[](int idx) = delete; + int64_t zero_mask() const { + int64_t mask = 0; + // returns an integer mask where all zero elements are translated to + // 1-bit and others are translated to 0-bit int64_t mask = 0; + __at_align__ int16_t mask_array[size()]; + + svbool_t svbool_mask = + svcmpeq_f16(ptrue, svreinterpret_f16_bf16(values), ZERO_F16); + svst1_s16( + 
ptrue, + mask_array, + svsel_s16(svbool_mask, ALL_S16_TRUE_MASK, ALL_S16_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) + mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const; + bool has_inf_nan() const; + Vectorized map(BFloat16 (*f)(BFloat16)) const { + __at_align__ BFloat16 tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = svdup_n_u16(0x7FFF); + auto vals = svreinterpret_u16_bf16(values); + vals = svand_u16_x(ptrue, vals, mask); + return svreinterpret_bf16_u16(vals); + } + Vectorized angle() const; + Vectorized real() const { + return values; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return values; + } + Vectorized acos() const; + Vectorized acosh() const; + Vectorized asin() const; + Vectorized atan() const; + Vectorized atanh() const; + Vectorized atan2(const Vectorized& b) const; + Vectorized copysign(const Vectorized& sign) const; + Vectorized erf() const; + Vectorized erfc() const; + Vectorized erfinv() const; + Vectorized exp() const; + Vectorized exp2() const; + Vectorized expm1() const; + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const; + Vectorized hypot(const Vectorized& b) const; + Vectorized i0() const; + Vectorized i0e() const; + Vectorized digamma() const; + Vectorized igamma(const Vectorized& x) const; + Vectorized igammac(const Vectorized& x) const; + Vectorized nextafter(const Vectorized& b) const; + Vectorized log() const; + Vectorized log2() const; + Vectorized log10() const; + Vectorized log1p() const; + Vectorized frac() const; + Vectorized sin() const; + Vectorized sinh() const; + Vectorized cos() const; + Vectorized cosh() const; + Vectorized ceil() const; + Vectorized floor() const; + Vectorized neg() const { + auto mask = svdup_n_u16(0x8000); + auto vals = svreinterpret_u16_bf16(values); + 
vals = sveor_u16_x(ptrue, vals, mask); + return svreinterpret_bf16_u16(vals); + }; + Vectorized round() const; + Vectorized tan() const; + Vectorized tanh() const; + Vectorized trunc() const; + Vectorized lgamma() const; + Vectorized sqrt() const; + Vectorized reciprocal() const; + Vectorized rsqrt() const; + Vectorized pow(const Vectorized& b) const; + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const; + + Vectorized operator!=(const Vectorized& other) const; + + Vectorized operator<(const Vectorized& other) const; + + Vectorized operator<=(const Vectorized& other) const; + + Vectorized operator>(const Vectorized& other) const; + + Vectorized operator>=(const Vectorized& other) const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +inline std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + auto zero = svreinterpret_bf16_f32(svdup_n_f32(0.0f)); + auto bf16_vec1 = svzip1_bf16(zero, a); + auto bf16_vec2 = svzip2_bf16(zero, a); + auto x1 = svreinterpret_f32_bf16(bf16_vec1); + auto x2 = svreinterpret_f32_bf16(bf16_vec2); + return {Vectorized(x1), Vectorized(x2)}; +} + +inline Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + svbfloat16_t x1 = svcvt_bf16_f32_z(ptrue, a); + svbfloat16_t x2 = svcvt_bf16_f32_z(ptrue, b); + return Vectorized(svuzp1_bf16(x1, x2)); +} + +inline void load_fp32_from_bf16(const BFloat16* data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : 
c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_bf16( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + Vectorized bf16_vec = Vectorized::loadu(data); + auto floats = convert_bfloat16_float(bf16_vec); + out1 = std::get<0>(floats); + out2 = std::get<1>(floats); +} + +template +Vectorized binary_operator_via_float( + Op op, + const Vectorized& a, + const Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); + const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); + return convert_float_bfloat16( + op(a_float_low, b_float_low), op(a_float_high, b_float_high)); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::plus>(), a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::minus>(), a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::multiplies>(), a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::divides>(), a, b); +} + +inline Vectorized::Vectorized(int val) { + auto vals_f = svdup_n_f32(val); + values = convert_float_bfloat16(vals_f, vals_f); +} + +inline Vectorized::Vectorized(BFloat16 val) { + auto vals_f = svdup_n_f32((float)val); + values = convert_float_bfloat16(vals_f, vals_f); +} + +bool inline Vectorized::has_inf_nan() const { + auto [v1, v2] = convert_bfloat16_float(values); + return v1.has_inf_nan() || v2.has_inf_nan(); +} +// frac. 
Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +#define DEFINE_BF16_FUNC_VIA_FLOAT(func_name) \ + Vectorized inline Vectorized::func_name() const { \ + auto [v1, v2] = convert_bfloat16_float(*this); \ + v1 = v1.func_name(); \ + v2 = v2.func_name(); \ + return convert_float_bfloat16(v1, v2); \ + } + +#define DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(func_name) \ + Vectorized inline Vectorized::func_name( \ + const Vectorized& a) const { \ + auto [v1, v2] = convert_bfloat16_float(*this); \ + auto [v3, v4] = convert_bfloat16_float(a); \ + v1 = v1.func_name(v3); \ + v2 = v2.func_name(v4); \ + return convert_float_bfloat16(v1, v2); \ + } + +DEFINE_BF16_FUNC_VIA_FLOAT(isnan); +DEFINE_BF16_FUNC_VIA_FLOAT(angle); +DEFINE_BF16_FUNC_VIA_FLOAT(acos); +DEFINE_BF16_FUNC_VIA_FLOAT(acosh); +DEFINE_BF16_FUNC_VIA_FLOAT(asin); +DEFINE_BF16_FUNC_VIA_FLOAT(atan); +DEFINE_BF16_FUNC_VIA_FLOAT(atanh); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign); +DEFINE_BF16_FUNC_VIA_FLOAT(erf); +DEFINE_BF16_FUNC_VIA_FLOAT(erfc); +DEFINE_BF16_FUNC_VIA_FLOAT(exp); +DEFINE_BF16_FUNC_VIA_FLOAT(exp2); +DEFINE_BF16_FUNC_VIA_FLOAT(expm1); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot); +DEFINE_BF16_FUNC_VIA_FLOAT(i0); +DEFINE_BF16_FUNC_VIA_FLOAT(i0e); +DEFINE_BF16_FUNC_VIA_FLOAT(digamma); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter); +DEFINE_BF16_FUNC_VIA_FLOAT(log); +DEFINE_BF16_FUNC_VIA_FLOAT(log2); +DEFINE_BF16_FUNC_VIA_FLOAT(log10); +DEFINE_BF16_FUNC_VIA_FLOAT(log1p); +DEFINE_BF16_FUNC_VIA_FLOAT(sin); +DEFINE_BF16_FUNC_VIA_FLOAT(sinh); +DEFINE_BF16_FUNC_VIA_FLOAT(cos); +DEFINE_BF16_FUNC_VIA_FLOAT(cosh); +DEFINE_BF16_FUNC_VIA_FLOAT(ceil); +DEFINE_BF16_FUNC_VIA_FLOAT(floor); +DEFINE_BF16_FUNC_VIA_FLOAT(round); +DEFINE_BF16_FUNC_VIA_FLOAT(tan); 
+DEFINE_BF16_FUNC_VIA_FLOAT(tanh); +DEFINE_BF16_FUNC_VIA_FLOAT(trunc); +DEFINE_BF16_FUNC_VIA_FLOAT(lgamma); +DEFINE_BF16_FUNC_VIA_FLOAT(sqrt); +DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal); +DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt); +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow); + +Vectorized inline Vectorized::operator==( + const Vectorized& other) const { + auto [f1, f2] = convert_bfloat16_float(values); + auto [f3, f4] = convert_bfloat16_float(other); + svbool_t mask1 = svcmpeq_f32(ptrue, f1, f3); + svbool_t mask2 = svcmpeq_f32(ptrue, f2, f4); + auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + + auto bf16_1 = svreinterpret_bf16_f32(res1); + auto bf16_2 = svreinterpret_bf16_f32(res2); + return svuzp1_bf16(bf16_1, bf16_2); +} +Vectorized inline Vectorized::operator!=( + const Vectorized& other) const { + auto [f1, f2] = convert_bfloat16_float(values); + auto [f3, f4] = convert_bfloat16_float(other); + svbool_t mask1 = svcmpne_f32(ptrue, f1, f3); + svbool_t mask2 = svcmpne_f32(ptrue, f2, f4); + auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + + auto bf16_1 = svreinterpret_bf16_f32(res1); + auto bf16_2 = svreinterpret_bf16_f32(res2); + return svuzp1_bf16(bf16_1, bf16_2); +} +Vectorized inline Vectorized::operator>( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 > v3, v2 > v4); +} +Vectorized inline Vectorized::operator>=( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 >= v3, v2 >= v4); +} +Vectorized inline Vectorized::operator<( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + 
return convert_float_bfloat16(v1 < v3, v2 < v4); +} +Vectorized inline Vectorized::operator<=( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 <= v3, v2 <= v4); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&maximum), + a, + b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&minimum), + a, + b); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&clamp_max), + a, + max); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&clamp_min), + a, + min); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return clamp_min(clamp_max(a, max), min); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_bf16_u16( + svand_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_bf16_u16( + svorr_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& 
a, + const Vectorized& b) { + return svreinterpret_bf16_u16( + sveor_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); +} + +Vectorized inline Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_bf16( + ptrue, + const_cast(reinterpret_cast(dst)) + i, + svldnt1_bf16( + ptrue, + const_cast(reinterpret_cast(src)) + + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b16(i, n); + svst1_bf16( + pg, + const_cast(reinterpret_cast(dst)) + i, + svldnt1_bf16( + pg, + const_cast(reinterpret_cast(src)) + + i)); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return a * b + c; +} + +#endif // defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16) + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_common_sve.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_common_sve.h new file mode 100644 index 
0000000000000000000000000000000000000000..69ed5d061bd8859d1a12dc01d7873f0de5343b22 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_common_sve.h @@ -0,0 +1,236 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with SVE] + +#include + +#include +#include + +#if defined(CPU_CAPABILITY_SVE) +#include +#include +#include +#include +#include +#endif + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#define DEFINE_SVE_CAST(t1_t, t1_prefix, t2_t, t2_prefix) \ + template <> \ + inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_##t1_prefix##_##t2_prefix(src); \ + } \ + template <> \ + inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_##t2_prefix##_##t1_prefix(src); \ + } + +DEFINE_SVE_CAST(int64_t, s64, double, f64) +DEFINE_SVE_CAST(int32_t, s32, double, f64) +DEFINE_SVE_CAST(int16_t, s16, double, f64) +DEFINE_SVE_CAST(int64_t, s64, float, f32) +DEFINE_SVE_CAST(int32_t, s32, float, f32) +DEFINE_SVE_CAST(int16_t, s16, float, f32) +DEFINE_SVE_CAST(float, f32, double, f64) + +#ifdef __ARM_FEATURE_BF16 +DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16) +DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16) +DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16) +#endif // __ARM_FEATURE_BF16 + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + 
+template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + double>> inline gather(const double* base_addr, const Vectorized& vindex_) { + svint64_t vindex = + svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svld1_gather_s64index_f64(ptrue, base_addr, vindex); +} + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + float>> inline gather(const float* base_addr, const Vectorized& vindex_) { + svint32_t vindex = + svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svld1_gather_s32index_f32(ptrue, base_addr, vindex); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const double* base_addr, + const Vectorized& vindex_, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); + svint64_t vindex = + svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svsel_f64( + mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const float* base_addr, + const Vectorized& vindex_, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK); + svint32_t vindex = + svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svsel_f32( + mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Only works for inputs in the range: [-2^51, 2^51] +// From: https://stackoverflow.com/a/41148578 +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000)); + 
return svsub_s64_x( + ptrue, + svreinterpret_s64_f64(x), + svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000))); +} + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return svcvt_s32_f32_x(ptrue, src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3} + // b = {b0, b1, b2, b3} + // group cols crossing lanes: + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vectorized(svzip1_f64(a, b)), + Vectorized(svzip2_f64(a, b))); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + return std::make_pair( + Vectorized(svzip1_f32(a, b)), Vectorized(svzip2_f32(a, b))); +} + +#ifdef __ARM_FEATURE_BF16 +template <> +std::pair< + Vectorized, + Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + return std::make_pair( + Vectorized(svzip1_bf16(a, b)), + Vectorized(svzip2_bf16(a, b))); +} +#endif // __ARM_FEATURE_BF16 + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + Vectorized(svuzp1_f64(a, b)), + Vectorized(svuzp2_f64(a, b))); +} + 
+template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair( + Vectorized(svuzp1_f32(a, b)), Vectorized(svuzp2_f32(a, b))); +} + +#ifdef __ARM_FEATURE_BF16 +template <> +std::pair< + Vectorized, + Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair( + Vectorized(svuzp1_bf16((svbfloat16_t)a, (svbfloat16_t)b)), + Vectorized(svuzp2_bf16((svbfloat16_t)a, (svbfloat16_t)b))); +} +#endif // __ARM_FEATURE_BF16 + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_double.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_double.h new file mode 100644 index 0000000000000000000000000000000000000000..9f73839995820aeee628d2e8d3f46ca8cc878086 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_double.h @@ -0,0 +1,588 @@ +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. 
So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + vls_float64_t values; + + public: + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(double); + } + Vectorized() {} + Vectorized(svfloat64_t v) : values(v) {} + Vectorized(double val) { + values = svdup_n_f64(val); + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) { + __at_align__ double buffer[size()] = {vals...}; + values = svld1_f64(ptrue, buffer); + } + operator svfloat64_t() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each element is 1 if the corresponding bit in + // 'mask' is set, 0 otherwise. + __at_align__ int64_t flag_arr[size()]; + for (int i = 0; i < size(); i++) { + flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0; + } + // Load the flag array into an SVE int64 vector. + svint64_t int_mask = svld1_s64(svptrue_b64(), flag_arr); + // Compare each lane of int_mask to 0; returns an svbool_t predicate where + // true indicates a nonzero flag. + svbool_t blend_mask = svcmpne_n_s64(svptrue_b64(), int_mask, 0); + + // Use svsel to select elements from b where the predicate is true, else + // from a. 
+ svfloat64_t result = svsel(blend_mask, b.values, a.values); + return Vectorized(result); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); + return svsel_f64(mask, b, a); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + __at_align__ double buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f64(ptrue, buffer); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f64(svwhilelt_b64(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f64(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b64(0ull, count); + return svld1_f64(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f64(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b64(0ull, count); + svst1_f64(pg, reinterpret_cast(ptr), values); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int64_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f64(ptrue, values, ZERO_F64); + svst1_s64( + ptrue, + mask_array, + svsel_s64(svbool_mask, ALL_S64_TRUE_MASK, ALL_S64_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) + mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f64(ptrue, values, ZERO_F64); + return svsel_f64(mask, 
ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any( + ptrue, + svcmpuo_f64(ptrue, svsub_f64_x(ptrue, values, values), ZERO_F64)); + } + Vectorized map(double (*f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f64_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f64(NAN); + const auto nan_mask = svcmpuo_f64(ptrue, values, ZERO_F64); + const auto pi = svdup_n_f64(c10::pi); + + const auto neg_mask = svcmplt_f64(ptrue, values, ZERO_F64); + auto angle = svsel_f64(neg_mask, pi, ZERO_F64); + angle = svsel_f64(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return USE_SLEEF( + Vectorized(Sleef_acosdx_u10sve(values)), map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( + Vectorized(Sleef_acoshdx_u10sve(values)), map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF( + Vectorized(Sleef_asindx_u10sve(values)), map(std::asin)); + } + Vectorized asinh() const { + return USE_SLEEF( + Vectorized(Sleef_asinhdx_u10sve(values)), map(std::asinh)); + } + Vectorized atan() const { + return USE_SLEEF( + Vectorized(Sleef_atandx_u10sve(values)), map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF( + Vectorized(Sleef_atanhdx_u10sve(values)), map(std::atanh)); + } + Vectorized atan2(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_atan2dx_u10sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized copysign(const Vectorized& 
sign) const { + USE_SLEEF( + { return Vectorized(Sleef_copysigndx_sve(values, sign)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + })} Vectorized erf() const { + return USE_SLEEF( + Vectorized(Sleef_erfdx_u10sve(values)), map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF( + Vectorized(Sleef_erfcdx_u15sve(values)), map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF( + Vectorized(Sleef_expdx_u10sve(values)), map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF( + Vectorized(Sleef_exp2dx_u10sve(values)), map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF( + Vectorized(Sleef_expm1dx_u10sve(values)), map(std::expm1)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const {USE_SLEEF( + { return Vectorized(Sleef_fmoddx_sve(values, q)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + })} Vectorized hypot(const Vectorized& b) const { + USE_SLEEF( + { return Vectorized(Sleef_hypotdx_u05sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t 
i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_nextafterdx_sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized log() const { + return USE_SLEEF( + Vectorized(Sleef_logdx_u10sve(values)), map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF( + Vectorized(Sleef_log2dx_u10sve(values)), map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF( + Vectorized(Sleef_log10dx_u10sve(values)), map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF( + Vectorized(Sleef_log1pdx_u10sve(values)), map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( + Vectorized(Sleef_sindx_u10sve(values)), map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF( + Vectorized(Sleef_sinhdx_u10sve(values)), map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF( + Vectorized(Sleef_cosdx_u10sve(values)), map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( + Vectorized(Sleef_coshdx_u10sve(values)), map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f64_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f64_x(ptrue, values); + } + Vectorized neg() const { + return svneg_f64_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f64_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF( + 
Vectorized(Sleef_tandx_u10sve(values)), map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF( + Vectorized(Sleef_tanhdx_u10sve(values)), map(std::tanh)); + } + Vectorized trunc() const { + return svrintz_f64_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF( + Vectorized(Sleef_lgammadx_u10sve(values)), map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f64_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f64_x(ptrue, values, ONE_F64); + } + Vectorized rsqrt() const { + return svdivr_f64_x(ptrue, svsqrt_f64_x(ptrue, values), ONE_F64); + } + Vectorized pow(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_powdx_u10sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized 
operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return svadd_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return svsub_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return svmul_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_f64_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return svmax_f64_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return svmin_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return svmin_f64_x(ptrue, max, svmax_f64_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return svmin_f64_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return svmax_f64_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f64_s64( + svand_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f64_s64( + svorr_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f64_s64( + sveor_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +Vectorized inline Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + const 
int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f64(ptrue, dst + i, svldnt1_f64(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b64(i, n); + svst1_f64(pg, dst + i, svldnt1_f64(pg, src + i)); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svmad_f64_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_float.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_float.h new file mode 100644 index 0000000000000000000000000000000000000000..8008e670378f4d9ddc99da196ac8aceb697b8d31 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_float.h @@ -0,0 +1,759 @@ +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. 
+inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + vls_float32_t values; + + public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(float); + } + Vectorized() {} + Vectorized(svfloat32_t v) : values(v) {} + Vectorized(float val) { + values = svdup_n_f32(val); + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) { + __at_align__ float buffer[size()] = {vals...}; + values = svld1_f32(ptrue, buffer); + } + operator svfloat32_t() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each element is 1 if the corresponding bit in + // 'mask' is set, 0 otherwise. + __at_align__ int32_t flag_arr[size()]; + for (int i = 0; i < size(); i++) { + flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0; + } + // Load the flag array into an SVE int32 vector. + svint32_t int_mask = svld1_s32(svptrue_b32(), flag_arr); + // Compare each lane of int_mask to 0; returns an svbool_t predicate where + // true indicates a nonzero flag. + svbool_t blend_mask = svcmpne_n_s32(svptrue_b32(), int_mask, 0); + // Use svsel to select elements from b where the predicate is true, else + // from a. 
+ svfloat32_t result = svsel_f32(blend_mask, b.values, a.values); + return Vectorized(result); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK); + return svsel_f32(mask, b, a); + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + __at_align__ float buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f32(ptrue, buffer); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f32(svwhilelt_b32(0ull, count), b, a); + } + return b; + } + // Implementation is picked from + // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105 + inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const { + const auto c1 = + svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f + const auto c2 = + svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f + const auto c3 = + svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f + const auto c4 = + svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f + const auto c5 = + svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f + const auto shift = svreinterpret_f32_u32( + svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = svreinterpret_f32_u32( + svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32( + 0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32( + 0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f + const auto inf = svdup_n_f32(std::numeric_limits::infinity()); + const auto 
max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5) + const auto zero = svdup_n_f32(0.f); + const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125) + // Range reduction: + // e^x = 2^n * e^r + // where: + // n = floor(x / ln(2)) + // r = x - n * ln(2) + // + // By adding x / ln(2) with 2^23 + 127 (shift): + // * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 + // forces decimal part + // of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. + // n) + 127 will occupy the whole fraction part of z in FP32 format. + // Subtracting 2^23 + 127 (shift) from z will result in the integer part + // of x / ln(2) (i.e. n) because the decimal part has been pushed out + // and lost. + // * The addition of 127 makes the FP32 fraction part of z ready to be + // used as the exponent + // in FP32 format. Left shifting z by 23 bits will result in 2^n. + const auto z = svmla_f32_z(pg, shift, x, inv_ln2); + const auto n = svsub_f32_z(pg, z, shift); + const auto scale = svreinterpret_f32_u32( + svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n + // The calculation of n * ln(2) is done using 2 steps to achieve accuracy + // beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in + // term of accuracy and performance. + const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi); + const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo); + // Compute the truncated Taylor series of e^r. + // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) + const auto r2 = svmul_f32_z(pg, r, r); + const auto p1 = svmul_f32_z(pg, c1, r); + const auto p23 = svmla_f32_z(pg, c2, c3, r); + const auto p45 = svmla_f32_z(pg, c4, c5, r); + const auto p2345 = svmla_f32_z(pg, p23, p45, r2); + const auto p12345 = svmla_f32_z(pg, p1, p2345, r2); + auto poly = svmla_f32_z(pg, scale, p12345, scale); + // Handle underflow and overflow. 
+ poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly); + poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly); + return poly; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_f32(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f32(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b32(0ull, count); + svst1_f32(pg, reinterpret_cast(ptr), values); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int32_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f32(ptrue, values, ZERO_F32); + svst1_s32( + ptrue, + mask_array, + svsel_s32(svbool_mask, ALL_S32_TRUE_MASK, ALL_S32_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) + mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f32(ptrue, values, ZERO_F32); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any( + ptrue, + svcmpuo_f32(ptrue, svsub_f32_x(ptrue, values, values), ZERO_F32)); + } + Vectorized map(float (*f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f32_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f32(NAN); + const auto nan_mask = svcmpuo_f32(ptrue, values, ZERO_F32); + const auto pi = svdup_n_f32(c10::pi); + + const auto neg_mask = svcmplt_f32(ptrue, values, ZERO_F32); + auto 
angle = svsel_f32(neg_mask, pi, ZERO_F32); + angle = svsel_f32(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return values; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return values; + } + Vectorized acos() const { + return USE_SLEEF( + Vectorized(Sleef_acosfx_u10sve(values)), map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( + Vectorized(Sleef_acoshfx_u10sve(values)), map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF( + Vectorized(Sleef_asinfx_u10sve(values)), map(std::asin)); + } + Vectorized asinh() const { + return USE_SLEEF( + Vectorized(Sleef_asinhfx_u10sve(values)), map(std::asinh)); + } + Vectorized atan() const { + return USE_SLEEF( + Vectorized(Sleef_atanfx_u10sve(values)), map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF( + Vectorized(Sleef_atanhfx_u10sve(values)), map(std::atanh)); + } + Vectorized atan2(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_atan2fx_u10sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized copysign(const Vectorized& sign) const { + + USE_SLEEF( + { return Vectorized(Sleef_copysignfx_sve(values, sign)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + })} Vectorized erf() const { + return USE_SLEEF( + Vectorized(Sleef_erffx_u10sve(values)), map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF( + Vectorized(Sleef_erfcfx_u15sve(values)), map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF( + 
Vectorized(Sleef_expfx_u10sve(values)), map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF( + Vectorized(Sleef_exp2fx_u10sve(values)), map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF( + Vectorized(Sleef_expm1fx_u10sve(values)), map(std::expm1)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const {USE_SLEEF( + { return Vectorized(Sleef_fmodfx_sve(values, q)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + })} Vectorized hypot(const Vectorized& b) const { + USE_SLEEF( + { return Vectorized(Sleef_hypotfx_u05sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_nextafterfx_sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = 
std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized log() const { + return USE_SLEEF( + Vectorized(Sleef_logfx_u10sve(values)), map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF( + Vectorized(Sleef_log2fx_u10sve(values)), map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF( + Vectorized(Sleef_log10fx_u10sve(values)), map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF( + Vectorized(Sleef_log1pfx_u10sve(values)), map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( + Vectorized(Sleef_sinfx_u10sve(values)), map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF( + Vectorized(Sleef_sinhfx_u10sve(values)), map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF( + Vectorized(Sleef_cosfx_u10sve(values)), map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( + Vectorized(Sleef_coshfx_u10sve(values)), map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f32_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f32_x(ptrue, values); + } + Vectorized neg() const { + return svneg_f32_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f32_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF( + Vectorized(Sleef_tanfx_u10sve(values)), map(std::tan)); + } + // Implementation is picked from + // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L179 + Vectorized tanh() const { + // Constants used for the tanh calculation. + const svfloat32_t CONST_1 = + svdup_n_f32(1.f); // Constant 1.0f for the tanh formula. + const svfloat32_t CONST_2 = svdup_n_f32( + 2.f); // Constant 2.0f for the tanh formula (used in exp(2x)). + const svfloat32_t CONST_MIN_TANH = svdup_n_f32( + -10.f); // Minimum threshold for input values to prevent overflow. 
+ const svfloat32_t CONST_MAX_TANH = svdup_n_f32( + 10.f); // Maximum threshold for input values to prevent overflow. + + // Step 1: Clamp the values within the range [-10, 10] to prevent overflow + // during exponentiation. The tanh function approaches ±1 rapidly as the + // input grows large, so we limit the input range to avoid numerical + // instability. svmax_f32_z ensures values are greater than -10, and + // svmin_f32_z ensures they are less than 10. + svfloat32_t x = svmin_f32_z( + ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH); + + // Step 2: Calculate exp(2 * x), where x is the clamped value. + // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of + // the result. + svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x)); + + // Step 3: Calculate the numerator of the tanh function, which is exp(2x) + // - 1. + svfloat32_t num = svsub_f32_z(ptrue, exp2x, CONST_1); + + // Step 4: Calculate the denominator of the tanh function, which is exp(2x) + // + 1. + svfloat32_t den = svadd_f32_z(ptrue, exp2x, CONST_1); + + // Step 5: Calculate the tanh function as the ratio of the numerator and + // denominator: num / den. + svfloat32_t tanh = svdiv_f32_z(ptrue, num, den); + + // Return the calculated tanh values. 
+ return tanh; + } + Vectorized trunc() const { + return svrintz_f32_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF( + Vectorized(Sleef_lgammafx_u10sve(values)), map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f32_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f32_x(ptrue, values, ONE_F32); + } + Vectorized rsqrt() const { + return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, values), ONE_F32); + } + Vectorized pow(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_powfx_u10sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, 
ALL_F32_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return svadd_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return svsub_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return svmul_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_f32_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return svmax_f32_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return svmin_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return svmin_f32_x(ptrue, max, svmax_f32_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return svmin_f32_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return svmax_f32_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f32_s32( + svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f32_s32( + svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f32_s32( + sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +Vectorized inline Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + const 
int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f32(ptrue, dst + i, svldnt1_f32(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b32(i, n); + svst1_f32(pg, dst + i, svldnt1_f32(pg, src + i)); + } +} + +template <> +inline void convert(const float* src, at::Half* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svuzp1_f16( + svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svuzp1_f16( + svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +} + +template <> +inline void convert(const at::Half* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svzip1_f16( + svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svzip1_f16( + svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +} + +template 
<> +inline void convert(const bool* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svmad_f32_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_int.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_int.h new file mode 100644 index 0000000000000000000000000000000000000000..82d2271f05ca9b3a8c21731b0ad409e4ccb8edf2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_int.h @@ -0,0 +1,497 @@ +#pragma once + +#include +#include +#include + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. 
So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +#define VEC_INT_SVE_TEMPLATE(vl, bit) \ + template <> \ + struct is_vec_specialized_for : std::bool_constant {}; \ + \ + template <> \ + class Vectorized { \ + private: \ + vls_int##bit##_t values; \ + \ + public: \ + using value_type = int##bit##_t; \ + using size_type = int; \ + static constexpr size_type size() { \ + return vl; \ + } \ + Vectorized() {} \ + Vectorized(svint##bit##_t v) : values(v) {} \ + Vectorized(int##bit##_t val) { \ + values = svdup_n_s##bit(val); \ + } \ + template < \ + typename... Args, \ + typename = std::enable_if_t<(sizeof...(Args) == size())>> \ + Vectorized(Args... vals) { \ + __at_align__ int##bit##_t buffer[size()] = {vals...}; \ + values = svld1_s##bit(ptrue, buffer); \ + } \ + operator svint##bit##_t() const { \ + return values; \ + } \ + template \ + static Vectorized blend( \ + const Vectorized& a, \ + const Vectorized& b) { \ + __at_align__ int##bit##_t flag_arr[size()]; \ + for (int i = 0; i < size(); ++i) { \ + flag_arr[i] = (i < 64 && (mask & (1ULL << i))) ? 
1 : 0; \ + } \ + svbool_t blend_mask = svcmpne_n_s##bit( \ + svptrue_b##bit(), svld1_s##bit(svptrue_b##bit(), flag_arr), 0); \ + return Vectorized( \ + svsel_s##bit(blend_mask, b.values, a.values)); \ + } \ + static Vectorized blendv( \ + const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& mask_) { \ + svbool_t mask = svcmpeq_s##bit(ptrue, mask_, ALL_S##bit##_TRUE_MASK); \ + return svsel_s##bit(mask, b, a); \ + } \ + /* step sometimes requires a higher precision type (e.g., T=int, \ + * step_t=double) */ \ + template \ + static Vectorized arange( \ + int##bit##_t base = 0, \ + step_t step = static_cast(1)) { \ + __at_align__ int##bit##_t buffer[size()]; \ + for (int64_t i = 0; i < size(); i++) { \ + buffer[i] = base + i * step; \ + } \ + return svld1_s##bit(ptrue, buffer); \ + } \ + static Vectorized set( \ + const Vectorized& a, \ + const Vectorized& b, \ + int##bit##_t count = size()) { \ + if (count == 0) { \ + return a; \ + } else if (count < size()) { \ + return svsel_s##bit(svwhilelt_b##bit(0ull, count), b, a); \ + } \ + return b; \ + } \ + static Vectorized loadu( \ + const void* ptr, \ + int64_t count = size()) { \ + if (count == size()) \ + return svld1_s##bit( \ + ptrue, reinterpret_cast(ptr)); \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + return svld1_s##bit(pg, reinterpret_cast(ptr)); \ + } \ + void store(void* ptr, int64_t count = size()) const { \ + if (count == size()) { \ + svst1_s##bit(ptrue, reinterpret_cast(ptr), values); \ + } else { \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + svst1_s##bit(pg, reinterpret_cast(ptr), values); \ + } \ + } \ + const int##bit##_t& operator[](int idx) const = delete; \ + int##bit##_t& operator[](int idx) = delete; \ + Vectorized abs() const { \ + return svabs_s##bit##_x(ptrue, values); \ + } \ + Vectorized real() const { \ + return values; \ + } \ + Vectorized imag() const { \ + return svdup_n_s##bit(0); \ + } \ + Vectorized conj() const { \ + return values; \ + } \ + Vectorized frac() 
const; \ + Vectorized neg() const { \ + return svneg_s##bit##_x(ptrue, values); \ + } \ + Vectorized operator==( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpeq_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator!=( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpne_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<( \ + const Vectorized& other) const { \ + svbool_t mask = svcmplt_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<=( \ + const Vectorized& other) const { \ + svbool_t mask = svcmple_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator>( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpgt_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator>=( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpge_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized eq(const Vectorized& other) const; \ + Vectorized ne(const Vectorized& other) const; \ + Vectorized gt(const Vectorized& other) const; \ + Vectorized ge(const Vectorized& other) const; \ + Vectorized lt(const Vectorized& other) const; \ + Vectorized le(const Vectorized& other) const; \ + }; \ + template <> \ + Vectorized inline operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return svadd_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return svsub_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline 
operator*( \ + const Vectorized& a, const Vectorized& b) { \ + return svmul_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline maximum( \ + const Vectorized& a, const Vectorized& b) { \ + return svmax_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline minimum( \ + const Vectorized& a, const Vectorized& b) { \ + return svmin_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline clamp( \ + const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, svmax_s##bit##_x(ptrue, min, a)); \ + } \ + template <> \ + Vectorized inline clamp_max( \ + const Vectorized& a, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, a); \ + } \ + template <> \ + Vectorized inline clamp_min( \ + const Vectorized& a, \ + const Vectorized& min) { \ + return svmax_s##bit##_x(ptrue, min, a); \ + } \ + template <> \ + Vectorized inline operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return svand_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return svorr_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return sveor_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + inline Vectorized operator~( \ + const Vectorized& a) { \ + return sveor_s##bit##_x(ptrue, a, svdup_n_s##bit(-1)); \ + } \ + Vectorized inline Vectorized::eq( \ + const Vectorized& other) const { \ + return (*this == other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ne( \ + const Vectorized& other) const { \ + return (*this != other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::gt( \ + const Vectorized& other) const { \ + return (*this > other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ge( \ + const Vectorized& other) const { \ + return (*this >= other) & Vectorized(1); \ + } \ + Vectorized 
inline Vectorized::lt( \ + const Vectorized& other) const { \ + return (*this < other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::le( \ + const Vectorized& other) const { \ + return (*this <= other) & Vectorized(1); \ + } + +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int64_t), 64) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int32_t), 32) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int16_t), 16) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int8_t), 8) + +template +Vectorized inline intdiv_nosve( + const Vectorized& a, + const Vectorized& b) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] /= values_b[i]; + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_s64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_s32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +inline void convert(const int32_t* src, int64_t* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); + } +} + +template <> +inline void convert(const int64_t* src, 
float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = + svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = + svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg = svwhilelt_b32(i, n); + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +} + +template <> +inline void convert(const bool* src, int64_t* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t src_vec_u64 = + svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, 
svsel_s64(mask, ONE_S64, ZERO_S64)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_64 = svwhilelt_b64(i, n); + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t src_vec_u64 = + svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); + } +} + +template <> +inline void convert(const bool* src, int32_t* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +} + +template <> +inline void convert(const uint8_t* src, bool* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b8(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8( + pg, + reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += 
Vectorized::size()) { + pg = svwhilelt_b8(i, n); + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8( + pg, + reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_qint.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_qint.h new file mode 100644 index 0000000000000000000000000000000000000000..61cb63cb1e12ad3e4f3c67dbfb910c7cfe00f4c8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_qint.h @@ -0,0 +1,606 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA 
IN THIS HEADER! +// See Note [Do not compile initializers with SVE] + +#include +#include +#include +#include +#include +#include + +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// NOTE: These are low-performance implementations that we fall back on +// if we are not building with SVE. This may not be an issue, because +// currently for quantization we assume the user has at least SVE +// installed, so these can simply act as a reference implementation. 
+// +// If in the future we relax this requirement (SVE+), we should probably +// revisit these implementations + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + using size_type = int; + static constexpr size_type size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / Vectorized::size(); + } + + static constexpr int int_num_vecs() { + return size() / Vectorized::size(); + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (size_t i = 0; i < size(); ++i) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for (int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = at::native::dequantize_val( + tmp_scale[j], + tmp_zero_point[j], + T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for 
(int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = at::native::dequantize_val( + tmp_scale[j], + tmp_zero_point[j], + T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return svld1_s32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_s32(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store( + &float_vals[i * Vectorized::size()], + Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (size_t i = 0; i < size(); ++i) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (size_t i = 
0; i < size(); ++i) { + retval.vals[i] = + nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store( + &float_vals[i * Vectorized::size()], + Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = 
std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return svld1_u8(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b8(0ull, count); + return svld1_u8(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store( + &float_vals[i * Vectorized::size()], + Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + 
j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h new file mode 100644 index 0000000000000000000000000000000000000000..0bfe65cd195908038aaedfa66f2a4af2c8dbb838 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h @@ -0,0 +1,57 @@ +#pragma once + +#if defined(CPU_CAPABILITY_AVX512) +#include +#else +#include +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline Vectorized convert_to_bool(Vectorized x) { + __at_align__ bool buffer[x.size()]; + x.ne(Vectorized(0)).store(buffer); + + Vectorized ret; + static_assert(x.size() == ret.size()); + std::memcpy(ret, buffer, ret.size() * sizeof(bool)); + return ret; +} + +template <> +inline Vectorized Vectorized::loadu(const void* ptr) { + // See NOTE [Loading boolean values] + return convert_to_bool(Vectorized::loadu(ptr)); +} + +template <> +inline Vectorized Vectorized::loadu( + const void* ptr, + int64_t count) { + // See NOTE [Loading boolean values] + return 
convert_to_bool(Vectorized::loadu(ptr, count)); +} + +template +struct VecHoldType { + using hold_type = typename VT::value_type; +}; + +template <> +struct VecHoldType> { + using hold_type = BFloat16; +}; + +template <> +struct VecHoldType> { + using hold_type = Half; +}; + +template +using vechold_type = typename VecHoldType::hold_type; + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128.h new file mode 100644 index 0000000000000000000000000000000000000000..c49580410aaf421642265e4e62a489ffe92720e2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128.h @@ -0,0 +1,14 @@ +#pragma once +// ARM NEON uses 128-bit vector registers. + +#include + +#ifdef __aarch64__ +#if !defined(CPU_CAPABILITY_SVE) +#include +#include +#include +#endif + +#include +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..e533f5931ad9f6e61acec3518635356313aba018 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h @@ -0,0 +1,567 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] +#include +#include +#include +#include +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Following vec128_half_neon.h, we only support aarch64. +#if !defined(C10_MOBILE) && defined(__aarch64__) +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." 
+#endif + +// Unlike the float16_t family of types, bfloat16_t is not available +// when we're not targeting bfloat16 hardware support on some +// platforms (but not Mac, so we have to be careful not to shadow the +// definitions in case they are actually there!). (See +// https://godbolt.org/z/orv6e94n4 ) So, we need to handle it as +// uint16_t in that case. +#define IMPLEMENT_AT_BF16_SHIM(vec_suffix) \ + inline at_bfloat16x4_t at_vget_low_bf16(at_bfloat16x8_t a) { \ + return vget_low_##vec_suffix(a); \ + } \ + \ + inline at_bfloat16x4_t at_vget_high_bf16(at_bfloat16x8_t a) { \ + return vget_high_##vec_suffix(a); \ + } \ + \ + inline at_bfloat16x8_t at_vcombine_bf16( \ + at_bfloat16x4_t low, at_bfloat16x4_t high) { \ + return vcombine_##vec_suffix(low, high); \ + } \ + \ + inline at_bfloat16x8_t at_vdupq_n_bf16(at_bfloat16_t value) { \ + return vdupq_n_##vec_suffix(value); \ + } \ + \ + inline at_bfloat16x8_t at_vld1q_bf16(const at_bfloat16_t* ptr) { \ + return vld1q_##vec_suffix(ptr); \ + } \ + \ + inline void at_vst1q_bf16(at_bfloat16_t* ptr, at_bfloat16x8_t value) { \ + vst1q_##vec_suffix(ptr, value); \ + } \ + \ + template \ + inline at_bfloat16x8_t at_vreinterpretq_bf16_u16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpretq_bf16_u16(val); \ + } \ + } \ + template \ + inline at_bfloat16x4_t at_vreinterpret_bf16_u16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpret_bf16_u16(val); \ + } \ + } \ + template \ + inline uint16x8_t at_vreinterpretq_u16_bf16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpretq_u16_bf16(val); \ + } \ + } \ + template \ + inline uint16x4_t at_vreinterpret_u16_bf16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpret_u16_bf16(val); \ + } \ + } + +#ifdef __ARM_FEATURE_BF16 +using at_bfloat16x8_t = bfloat16x8_t; +using at_bfloat16x4_t = bfloat16x4_t; 
+using at_bfloat16_t = bfloat16_t; +IMPLEMENT_AT_BF16_SHIM(bf16) +#define at_vsetq_lane_bf16 vsetq_lane_bf16 +#define at_vgetq_lane_bf16 vgetq_lane_bf16 +#else +using at_bfloat16x8_t = uint16x8_t; +using at_bfloat16x4_t = uint16x4_t; +using at_bfloat16_t = uint16_t; +IMPLEMENT_AT_BF16_SHIM(u16) +#define at_vsetq_lane_bf16 vsetq_lane_u16 +#define at_vgetq_lane_bf16 vgetq_lane_u16 +#endif // __ARM_FEATURE_BF16 + +template +struct BlendBFloat16Regs { + static at_bfloat16x8_t impl( + const at_bfloat16x8_t& a, + const at_bfloat16x8_t& b, + at_bfloat16x8_t& res); +}; + +template +struct BlendBFloat16Regs { + static at_bfloat16x8_t impl( + const at_bfloat16x8_t& a, + const at_bfloat16x8_t& b, + at_bfloat16x8_t& res) { + return at_vsetq_lane_bf16(at_vgetq_lane_bf16(b, index), res, index); + } +}; + +template +struct BlendBFloat16Regs { + static at_bfloat16x8_t impl( + const at_bfloat16x8_t& a, + const at_bfloat16x8_t& b, + at_bfloat16x8_t& res) { + return at_vsetq_lane_bf16(at_vgetq_lane_bf16(a, index), res, index); + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16< + at_bfloat16x8_t, + c10::BFloat16, + BlendBFloat16Regs, + Vectorized> { + using Base = Vectorized16< + at_bfloat16x8_t, + c10::BFloat16, + BlendBFloat16Regs, + Vectorized>; + friend Base; + friend std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a); + friend Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b); + + private: + Vectorized map2( + const Vectorized& second, + c10::BFloat16 (*const f)(c10::BFloat16, c10::BFloat16)) const { + __at_align__ c10::BFloat16 tmp_first[size()]; + __at_align__ c10::BFloat16 tmp_second[size()]; + store(tmp_first); // store this to tmp_first + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp_first[i] = f(tmp_first[i], tmp_second[i]); + } + return loadu(tmp_first); + } + + static float32x4_t 
convert_f32_bf16(at_bfloat16x4_t bf16) { +#ifdef __ARM_FEATURE_BF16 + return vcvt_f32_bf16(bf16); +#else + int32x4_t shift = vdupq_n_s32(16); + return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(bf16), shift)); +#endif // __ARM_FEATURE_BF16 + } + + static at_bfloat16x4_t convert_bf16_f32(const Vectorized& f32) { +#ifdef __ARM_FEATURE_BF16 + return vcvt_bf16_f32(f32); +#else + static_assert(std::is_same_v); + uint32x4_t as_uint32 = vreinterpretq_u32_f32(f32); + uint32x4_t rounding_bias = vaddq_u32( + vandq_u32(vshrq_n_u32(as_uint32, 16), vdupq_n_u32(1)), + vdupq_n_u32(0x7FFF)); + at_bfloat16x4_t rounded = + vshrn_n_u32(vaddq_u32(as_uint32, rounding_bias), 16); + const auto bf16_nan = vdup_n_u16(0x7FC0); + return vbsl_u16( + vmovn_u32(vreinterpretq_u32_f32(f32.isnan())), bf16_nan, rounded); +#endif // __ARM_FEATURE_BF16 + } + + Vectorized map_with_vec_float_method( + Vectorized (Vectorized::*m)() const) const { + float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values)); + float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values)); + Vectorized mv0 = (Vectorized(v00).*m)(); + Vectorized mv1 = (Vectorized(v01).*m)(); + at_bfloat16x4_t r00 = convert_bf16_f32(mv0); + at_bfloat16x4_t r01 = convert_bf16_f32(mv1); + return Vectorized(at_vcombine_bf16(r00, r01)); + } + + Vectorized map2_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values)); + float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values)); + float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values)); + float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values)); + Vectorized mv0 = (Vectorized(v00).*m)(second_v00); + Vectorized mv1 = (Vectorized(v01).*m)(second_v01); + at_bfloat16x4_t r00 = convert_bf16_f32(mv0); + at_bfloat16x4_t r01 = convert_bf16_f32(mv1); + return Vectorized(at_vcombine_bf16(r00, r01)); + } + + Vectorized 
map2_bitmask_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values)); + float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values)); + float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values)); + float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values)); + Vectorized mv0 = (Vectorized(v00).*m)(second_v00); + Vectorized mv1 = (Vectorized(v01).*m)(second_v01); + // Assume the operator returns a bitmask, not "real" floats, and + // just narrow the bits. All-ones is a NaN and will get mangled by + // conversion! + at_bfloat16x4_t r00 = + at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0))); + at_bfloat16x4_t r01 = + at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1))); + return Vectorized(at_vcombine_bf16(r00, r01)); + } + + public: + using Vectorized16::Vectorized16; + + Vectorized() = default; + + Vectorized(c10::BFloat16 val) + : Vectorized16(at_vdupq_n_bf16(c10::bit_cast(val.x))) {} + Vectorized(float val) : Vectorized(c10::BFloat16(val)) {} + Vectorized( + value_type val0, + value_type val1, + value_type val2, + value_type val3, + value_type val4, + value_type val5, + value_type val6, + value_type val7) + : Vectorized16(at_bfloat16x8_t{ + c10::bit_cast(val0.x), + c10::bit_cast(val1.x), + c10::bit_cast(val2.x), + c10::bit_cast(val3.x), + c10::bit_cast(val4.x), + c10::bit_cast(val5.x), + c10::bit_cast(val6.x), + c10::bit_cast(val7.x)}) {} + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // NOTE: blendv has the same problems as it does for Half; see comments in + // vec128_half_neon.h. 
+ Vectorized vec(mask.values); + vec.values = at_vreinterpretq_bf16_u16(vbslq_u16( + at_vreinterpretq_u16_bf16(vec.values), + at_vreinterpretq_u16_bf16(b.values), + at_vreinterpretq_u16_bf16(a.values))); + return vec; + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + uint16_t pre_mask[size()] = {0}; + for (int i = 0; i < count; i++) { + pre_mask[i] = 0xFFFF; + } + uint16x8_t mask = vld1q_u16(pre_mask); + + Vectorized vec(at_vreinterpretq_bf16_u16(vbslq_u16( + mask, + at_vreinterpretq_u16_bf16(b.values), + at_vreinterpretq_u16_bf16(a.values)))); + + return vec; + } + static Vectorized loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) { + return at_vld1q_bf16(reinterpret_cast(ptr)); + } + __at_align__ at_bfloat16_t tmp_values[size()]; + std::memset(tmp_values, 0, sizeof(tmp_values)); + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(at_bfloat16_t)); + return at_vld1q_bf16(reinterpret_cast(tmp_values)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + at_vst1q_bf16(reinterpret_cast(ptr), values); + return; + } else { + at_bfloat16_t tmp_values[size()]; + at_vst1q_bf16(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(at_bfloat16_t)); + } + } + Vectorized isnan() const { + // NOTE: we could make this faster by doing vectorized checks of + // exponent/payload bits. 
+ __at_align__ c10::BFloat16 tmp[size()]; + __at_align__ c10::BFloat16 res[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i])) { + std::memset(static_cast(&res[i]), 0xFF, sizeof(c10::BFloat16)); + } else { + std::memset(static_cast(&res[i]), 0, sizeof(c10::BFloat16)); + } + } + return loadu(res); + } + bool has_inf_nan() const { + __at_align__ c10::BFloat16 tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } +#define DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(name) \ + Vectorized name() const { \ + return map_with_vec_float_method(&Vectorized::name); \ + } + +#define DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(name) \ + Vectorized name(const Vectorized& other) const { \ + return map2_bitmask_with_vec_float_method( \ + other, &Vectorized::name); \ + } + + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs) + Vectorized frac() const; + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=) + +#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD +#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const 
Vectorized& other) const; +}; // Vectorized + +inline std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + at_bfloat16x8_t x = a; + float32x4_t x1 = + Vectorized::convert_f32_bf16(at_vget_low_bf16(x)); + float32x4_t x2 = + Vectorized::convert_f32_bf16(at_vget_high_bf16(x)); + return {Vectorized(x1), Vectorized(x2)}; +} +inline Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + at_bfloat16x4_t x1 = Vectorized::convert_bf16_f32(a); + at_bfloat16x4_t x2 = Vectorized::convert_bf16_f32(b); + return Vectorized(at_vcombine_bf16(x1, x2)); +} + +template +Vectorized binary_operator_via_float( + Op op, + const Vectorized& a, + const Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); + const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); + return convert_float_bfloat16( + op(a_float_low, b_float_low), op(a_float_high, b_float_high)); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::plus>(), a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::minus>(), a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::multiplies>(), a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::divides>(), a, b); +} + +// frac. 
Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&maximum), + a, + b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&minimum), + a, + b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(at_vreinterpretq_bf16_u16( + vandq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b)))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(at_vreinterpretq_bf16_u16( + vorrq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b)))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(at_vreinterpretq_bf16_u16( + veorq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b)))); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized 
Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also, + // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered + // elements, not the bottom and top half, so they don't seem + // particularly useful here. Ideally we would include dot product in + // the Vectorized interface... + return a * b + c; +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + // See NOTE [BF16 FMA] above. + return a * b - c; +} + +#endif // !defined(C10_MOBILE) && defined(__aarch64__) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_convert.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..0ad0c892b06c06ab4cbf432c15f5603f4c7b5bf0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_convert.h @@ -0,0 +1,65 @@ +#pragma once +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_half_register_to_float(src[0]); + } +}; +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const 
VectorizedN& src) { + const auto [v0, v1] = convert_int8_to_float(src[0]); + return VectorizedN(v0, v1); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + uint16x8_t u16_8 = vld1q_u16(reinterpret_cast(&src[0])); + auto u16_low1 = vget_low_u16(u16_8); + auto u16_high1 = vget_high_u16(u16_8); + float32x4_t f32x4_0 = + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_low1), 16)); + float32x4_t f32x4_1 = + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_high1), 16)); + result[0] = f32x4_0; + result[1] = f32x4_1; + return result; + } +}; +// Half register to full register. +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + uint16x4_t u16_8 = vld1_u16(reinterpret_cast(&src[0])); + float32x4_t f32x4_0 = + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_8), 16)); + result[0] = f32x4_0; + return result; + } +}; + +#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..4c6e21b28c2c53b76c892acb64dcfe8ad67f0522 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h @@ -0,0 +1,631 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#endif + +// Sleef offers vectorized versions of some transcedentals +// such as sin, cos, tan etc.. +// However for now opting for STL, since we are not building +// with Sleef for mobile yet. 
+ +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Right now contains only aarch64 implementation. +// Due to follow two reasons aarch32 is not currently supported. +// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics +// that work for aarch64 dont work for aarch32. +// 2. Android NDK r21 has problems with compiling aarch32. +// Clang seg faults. +// https://github.com/android/ndk/issues/1248 +// https://bugs.llvm.org/show_bug.cgi?id=45824 +// Most likely we will do aarch32 support with inline asm. +#if defined(__aarch64__) + +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." +#endif + +#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, + const float32x4_t& b, + float32x4_t& res); +}; + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, + const float32x4_t& b, + float32x4_t& res) { + return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index); + } +}; + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, + const float32x4_t& b, + float32x4_t& res) { + return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index); + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + float32x4_t values; + + public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(float32x4_t v) : values(v) {} + Vectorized(float val) : values{vdupq_n_f32(val)} {} + Vectorized(float val0, float val1, float val2, float val3) + : values{val0, val1, val2, val3} {} + Vectorized(float (&arr)[4]) : Vectorized(arr[0], arr[1], arr[2], arr[3]) {} + operator float32x4_t() const { + return values; + } + template 
+ static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + Vectorized vec; + vec.values = BlendRegs < 0, + (mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 1, + (mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 2, + (mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 3, + (mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values); + return vec; + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // TODO + // NB: This requires that each value, i.e., each uint value, + // of the mask either all be zeros or all be 1s. + // We perhaps need some kind of an assert? + // But that will affect performance. + Vectorized vec(mask.values); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + const Vectorized base_vec(base); + const Vectorized step_vec(step); + const Vectorized step_sizes(0, 1, 2, 3); + return fmadd(step_sizes, step_vec, base_vec); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + case 2: { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + case 3: { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = + 
vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) { + return vld1q_f32(reinterpret_cast(ptr)); + } else { + __at_align__ float tmp_values[size()]; + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(float)); + return vld1q_f32(reinterpret_cast(tmp_values)); + } + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + vst1q_f32(reinterpret_cast(ptr), values); + } else { + float tmp_values[size()]; + vst1q_f32(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float)); + } + } + // Very slow implementation of indexing. + // Only required because vec256_qint refers to this. + // Once we specialize that implementation for ARM + // this should be removed. TODO (kimishpatel) + float operator[](int idx) const { + __at_align__ float tmp[size()]; + store(tmp); + return tmp[idx]; + } + float operator[](int idx) { + __at_align__ float tmp[size()]; + store(tmp); + return tmp[idx]; + } + // For boolean version where we want to if any 1/all zero + // etc. can be done faster in a different way. 
+ int zero_mask() const { + __at_align__ float tmp[size()]; + store(tmp); + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (tmp[i] == 0.f) { + mask |= (1 << i); + } + } + return mask; + } + Vectorized isnan() const { + return vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, values))); + } + bool has_inf_nan() const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized map2( + const Vectorized& second, + float (*const f)(float, float)) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_second[size()]; + store(tmp); + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i], tmp_second[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return Vectorized(vabsq_f32(values)); + } + Vectorized angle() const { + auto zero = Vectorized(0); + auto pi = Vectorized(c10::pi); + auto tmp = blendv(zero, pi, *this < zero); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return *this; + } +#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, sleef_name) \ + Vectorized name() const { \ + return USE_SLEEF(Vectorized(sleef_name(values)), map(std::name)); \ + } + +#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(name) \ + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, Sleef_##name##f4_u10) + + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acos) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acosh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asin) + 
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asinh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atan) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atanh) + +#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, sleef_name) \ + Vectorized name(const Vectorized& arg) const { \ + return USE_SLEEF( \ + Vectorized(sleef_name(values, arg.values)), \ + map2(arg, std::name)); \ + } + +#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(name) \ + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, Sleef_##name##f4_u10) + + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(atan2) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + copysign, + Sleef_copysignf4) + Vectorized erf() const; + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + erfc, + Sleef_erfcf4_u15) + Vectorized erfinv() const { + return map(calc_erfinv); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) + Vectorized exp_u20() const { + return exp(); + } + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + fmod, + Sleef_fmodf4) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + hypot, + Sleef_hypotf4_u05) + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + return map2(x, calc_igamma); + } + Vectorized igammac(const Vectorized& x) const { + return map2(x, calc_igammac); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log10) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log1p) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log2) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + nextafter, + Sleef_nextafterf4) + 
Vectorized frac() const; + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sin) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sinh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cos) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cosh) + Vectorized ceil() const { + return map(at::native::ceil_impl); + } + Vectorized floor() const { + return map(at::native::floor_impl); + } + Vectorized neg() const { + return Vectorized(vnegq_f32(values)); + } + Vectorized round() const { + // We do not use std::round because we would like to round midway numbers to + // the nearest even integer. + return map(at::native::round_impl); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tan) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tanh) + Vectorized trunc() const { + return Vectorized(vrndq_f32(values)); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(lgamma) + Vectorized sqrt() const { + return Vectorized(vsqrtq_f32(values)); + } + Vectorized reciprocal() const { + return Vectorized(vdivq_f32(vdupq_n_f32(1.0f), values)); + } + Vectorized rsqrt() const { + return this->sqrt().reciprocal(); + } + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(pow) + Vectorized operator==(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vceqq_f32(values, other.values))); + } + + Vectorized operator!=(const Vectorized& other) const { + float32x4_t r0 = + vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, other.values))); + return Vectorized(r0); + } + + Vectorized operator<(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcltq_f32(values, other.values))); + } + + Vectorized operator<=(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcleq_f32(values, other.values))); + } + + Vectorized operator>(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcgtq_f32(values, other.values))); + } + + Vectorized operator>=(const Vectorized& other) const { + return Vectorized( + 
vreinterpretq_f32_u32(vcgeq_f32(values, other.values))); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vaddq_f32(a, b)); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vsubq_f32(a, b)); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vmulq_f32(a, b)); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vdivq_f32(a, b)); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vmaxq_f32(a, b)); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vminq_f32(a, b)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32( + vandq_u32(vreinterpretq_u32_f32(a), vreinterpretq_u32_f32(b)))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32( + vorrq_u32(vreinterpretq_u32_f32(a), vreinterpretq_u32_f32(b)))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32( + veorq_u32(vreinterpretq_u32_f32(a), vreinterpretq_u32_f32(b)))); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, int32_t* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ 
+#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized(vfmaq_f32(c, a, b)); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized(vnegq_f32(vfmsq_f32(c, a, b))); +} + +inline Vectorized Vectorized::erf() const { + // constants + const Vectorized neg_zero_vec(-0.f); + const Vectorized one_vec(1.0f); + const Vectorized p(0.3275911f); + const Vectorized p1(0.254829592f); + const Vectorized p2(-0.284496736f); + const Vectorized p3(1.421413741f); + const Vectorized p4(-1.453152027f); + const Vectorized p5(1.061405429f); + // sign(x) + auto sign_mask = neg_zero_vec & *this; + auto abs_vec = this->abs(); + // t = 1 / (p * abs(x) + 1) + auto tmp0 = fmadd(p, abs_vec, one_vec); + auto t = one_vec / tmp0; + // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = fmadd(p5, t, p4); + auto tmp2 = fmadd(tmp1, t, p3); + auto tmp3 = fmadd(tmp2, t, p2); + auto r = fmadd(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = (*this) * (*this); + auto neg_pow_2 = pow_2 ^ neg_zero_vec; + auto tmp4 = neg_pow_2.map( + std::exp); // This can be swapped for a faster implementation of exp. 
+ auto tmp5 = tmp4 ^ neg_zero_vec; + // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = t * tmp5; + auto tmp7 = fmadd(tmp6, r, one_vec); + return tmp7 ^ sign_mask; +} +#undef DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC +#undef DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC +#endif /* defined(aarch64) */ + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..e468244968c08d8ac01a06519a062e3ee334e8a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h @@ -0,0 +1,614 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#include +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Right now contains only aarch64 implementation. +// Due to follow two reasons aarch32 is not currently supported. +// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics +// that work for aarch64 dont work for aarch32. +// 2. Android NDK r21 has problems with compiling aarch32. +// Clang seg faults. +// https://github.com/android/ndk/issues/1248 +// https://bugs.llvm.org/show_bug.cgi?id=45824 +// Most likely we will do aarch32 support with inline asm. +#if !defined(C10_MOBILE) && defined(__aarch64__) + +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." 
+#endif + +template +struct BlendHalfRegs { + static float16x8_t impl( + const float16x8_t& a, + const float16x8_t& b, + float16x8_t& res); +}; + +template +struct BlendHalfRegs { + static float16x8_t impl( + const float16x8_t& a, + const float16x8_t& b, + float16x8_t& res) { + return vsetq_lane_f16(vgetq_lane_f16(b, index), res, index); + } +}; + +template +struct BlendHalfRegs { + static float16x8_t impl( + const float16x8_t& a, + const float16x8_t& b, + float16x8_t& res) { + return vsetq_lane_f16(vgetq_lane_f16(a, index), res, index); + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +// On ARM, Half type supports float16_t->Half constructor and Half->float16_t +// conversion +template <> +class Vectorized : public Vectorized16< + float16x8_t, + c10::Half, + BlendHalfRegs, + Vectorized> { + using Base = Vectorized16< + float16x8_t, + c10::Half, + BlendHalfRegs, + Vectorized>; + friend Base; + + private: + // We use these private map functions to implement various methods + Vectorized map_with_vec_float_method( + Vectorized (Vectorized::*m)() const) const { + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + Vectorized mv0 = (Vectorized(v00).*m)(); + Vectorized mv1 = (Vectorized(v01).*m)(); + float16x4_t r00 = vcvt_f16_f32(mv0); + float16x4_t r01 = vcvt_f16_f32(mv1); + return Vectorized(vcombine_f16(r00, r01)); + } + + Vectorized map2_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values)); + float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values)); + Vectorized mv0 = + (Vectorized(v00).*m)(Vectorized(second_v00)); + Vectorized mv1 = + (Vectorized(v01).*m)(Vectorized(second_v01)); + float16x4_t r00 = vcvt_f16_f32(mv0); + 
float16x4_t r01 = vcvt_f16_f32(mv1); + + // Pack result into Vectorized + return Vectorized(vcombine_f16(r00, r01)); + } + + Vectorized map2_bitmask_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values)); + float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values)); + Vectorized mv0 = + (Vectorized(v00).*m)(Vectorized(second_v00)); + Vectorized mv1 = + (Vectorized(v01).*m)(Vectorized(second_v01)); + // Assume the operator returns a bitmask, not "real" floats, and + // just narrow the bits. All-ones is a NaN and will get mangled by + // conversion! + float16x4_t r00 = + vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0))); + float16x4_t r01 = + vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1))); + + // Pack result into Vectorized + return Vectorized(vcombine_f16(r00, r01)); + } + + public: + using Vectorized16::Vectorized16; + + Vectorized() = default; + + // A ctor that accepts c10::Half is needed to fit interface with vec_base.h + // A second constructor that takes float16_t is also included + Vectorized(c10::Half val) : Vectorized((float16_t)val) {} + Vectorized(float16_t val) : Vectorized16(vdupq_n_f16(val)) {} + Vectorized( + value_type val0, + value_type val1, + value_type val2, + value_type val3, + value_type val4, + value_type val5, + value_type val6, + value_type val7) + : Vectorized16( + float16x8_t{val0, val1, val2, val3, val4, val5, val6, val7}) {} + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // Note: using blendv is very awkward because 0xFFFF is one of + // many NaN's in FP16 It's unfortunate that the mask has type Half + // (required from vec_base) + + // TODO + // NB: This requires that each value, i.e., each uint value, + // 
of the mask either all be zeros or all be 1s. + // We perhaps need some kind of an assert? + // But that will affect performance. + + // NOTE [vbslq_f16]: vbslq_f16 doesn't work on clang without + // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC. vbslq_u16 generates the + // same instruction anyway. see https://godbolt.org/z/cY4a55Y7P + Vectorized vec(mask.values); + vec.values = vreinterpretq_f16_u16(vbslq_u16( + vreinterpretq_u16_f16(vec.values), + vreinterpretq_u16_f16(b.values), + vreinterpretq_u16_f16(a.values))); + return vec; + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + uint16_t pre_mask[size()] = {0}; + for (int i = 0; i < count; i++) { + pre_mask[i] = 0xFFFF; + } + uint16x8_t mask = vld1q_u16(pre_mask); + + // Using blendv is awkward because 0xFFFF is one of many NaN's in FP16 + // so we directly use vbslq_u16 instead. (See NOTE [vbslq_f16] above.) + Vectorized vec(vreinterpretq_f16_u16(vbslq_u16( + mask, + vreinterpretq_u16_f16(b.values), + vreinterpretq_u16_f16(a.values)))); + + return vec; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) { + return vld1q_f16(reinterpret_cast(ptr)); + } + __at_align__ float16_t tmp_values[size()]; + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(float16_t)); + return vld1q_f16(reinterpret_cast(tmp_values)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + vst1q_f16(reinterpret_cast(ptr), values); + return; + } else { + float16_t tmp_values[size()]; + vst1q_f16(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float16_t)); + } + } + // For boolean version where we want to if any 1/all zero + // etc. can be done faster in a different way. 
+ Vectorized isnan() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values))); +#else + // NOTE: we could make this faster by doing vectorized checks of + // exponent/payload bits. + __at_align__ c10::Half tmp[size()]; + __at_align__ c10::Half res[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i])) { + std::memset(static_cast(&res[i]), 0xFF, sizeof(c10::Half)); + } else { + std::memset(static_cast(&res[i]), 0, sizeof(c10::Half)); + } + } + return loadu(res); +#endif + } + bool has_inf_nan() const { + __at_align__ c10::Half tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } + Vectorized abs() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vabsq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::abs); +#endif + } + Vectorized frac() const; + Vectorized neg() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vnegq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::neg); +#endif + } + Vectorized trunc() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vrndq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::trunc); +#endif + } + Vectorized sqrt() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vsqrtq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::sqrt); +#endif + } + Vectorized reciprocal() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + auto ones = vdupq_n_f16(1.0f); + return Vectorized(vdivq_f16(ones, values)); +#else + return map_with_vec_float_method(&Vectorized::reciprocal); +#endif + } + Vectorized operator==(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vceqq_f16(values, other.values))); +#else + return 
map2_bitmask_with_vec_float_method( + other, &Vectorized::operator==); +#endif + } + + Vectorized operator!=(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, other.values)))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator!=); +#endif + } + + Vectorized operator<(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcltq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator<); +#endif + } + + Vectorized operator<=(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcleq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator<=); +#endif + } + + Vectorized operator>(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcgtq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator>); +#endif + } + + Vectorized operator>=(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcgeq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator>=); +#endif + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; // Vectorized + +inline std::tuple, Vectorized> convert_half_float( + const Vectorized& a) { + static_assert(Vectorized::size() == 2 * Vectorized::size()); + float16x8_t x = a; + float32x4_t x1 = 
vcvt_f32_f16(vget_low_f16(x)); + float32x4_t x2 = vcvt_f32_f16(vget_high_f16(x)); + return {Vectorized(x1), Vectorized(x2)}; +} +inline Vectorized convert_float_half( + const Vectorized& a, + const Vectorized& b) { + static_assert(Vectorized::size() == 2 * Vectorized::size()); + float32x4_t x = a; + float32x4_t y = b; + float16x4_t x1 = vcvt_f16_f32(x); + float16x4_t x2 = vcvt_f16_f32(y); + return Vectorized(vcombine_f16(x1, x2)); +} + +template +Vectorized binary_operator_via_float( + Op op, + const Vectorized& a, + const Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_half_float(a); + const auto [b_float_low, b_float_high] = convert_half_float(b); + return convert_float_half( + op(a_float_low, b_float_low), op(a_float_high, b_float_high)); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vaddq_f16(a, b)); +#else + return binary_operator_via_float(std::plus>(), a, b); +#endif +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vsubq_f16(a, b)); +#else + return binary_operator_via_float(std::minus>(), a, b); +#endif +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vmulq_f16(a, b)); +#else + return binary_operator_via_float(std::multiplies>(), a, b); +#endif +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vdivq_f16(a, b)); +#else + return binary_operator_via_float(std::divides>(), a, b); +#endif +} + +// frac. 
Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vmaxq_f16(a, b)); +#else + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&maximum), + a, + b); +#endif +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vminq_f16(a, b)); +#else + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&minimum), + a, + b); +#endif +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f16_u16( + vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f16_u16( + vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f16_u16( + veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); +} + +inline 
Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +// These are global functions, so the defaults in vec_base.h should +// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available. +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +inline void convert(const float16_t* src, int16_t* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int16_t* src, float16_t* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vfmaq_f16(c, a, b)); +#else + return a * b + c; +#endif +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, 
+ const Vectorized& b, + const Vectorized& c) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vnegq_f16(vfmsq_f16(c, a, b))); +#else + return a * b - c; +#endif +} +#endif // !defined(C10_MOBILE) && defined(__aarch64__) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..4e5ff2efa70c80a90e2d0f897e921803a51b1f6e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h @@ -0,0 +1,307 @@ +#pragma once +// Shared code for bfloat16 and float16. + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +// Shared implementation between Vectorized and +// Vectorized. Uses CRTP to allow derived class +// customization. 
+template < + typename VecT, + typename ValueT, + template typename BlendRegs, + typename Derived> +struct Vectorized16 { + protected: + VecT values; + + public: + using value_type = ValueT; + using size_type = int; + static constexpr size_type size() { + static_assert(sizeof(VecT) == 8 * sizeof(value_type)); + return 8; + } + + protected: + Derived map2( + const Derived& second, + value_type (*const f)(value_type, value_type)) const { + __at_align__ value_type tmp_first[size()]; + __at_align__ value_type tmp_second[size()]; + static_cast(this)->store( + tmp_first); // store this to tmp_first + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp_first[i] = f(tmp_first[i], tmp_second[i]); + } + return Derived::loadu(tmp_first); + } + + public: + Vectorized16() = default; + Vectorized16(VecT v) : values(v) {} + + operator VecT() const { + return values; + } + + template + static Derived blend(const Derived& a, const Derived& b) { + Derived vec; + vec.values = BlendRegs < 0, + (mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 1, + (mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 2, + (mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 3, + (mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values); + + vec.values = BlendRegs < 4, + (mask & 0x10) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 5, + (mask & 0x20) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 6, + (mask & 0x40) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 7, + (mask & 0x80) != 0 > ::impl(a.values, b.values, vec.values); + + return vec; + } + + template + static Derived arange( + value_type base = 0, + step_t step = static_cast(1)) { + const Derived base_vec(base); + const Derived step_vec(step); + const Derived step_sizes( + value_type(0), + value_type(1), + value_type(2), + 
value_type(3), + value_type(4), + value_type(5), + value_type(6), + value_type(7)); + return fmadd(step_sizes, step_vec, base_vec); + } + + // Very slow implementation of indexing. + // Only required because vec256_qint refers to this. + // Once we specialize that implementation for ARM + // this should be removed. TODO (kimishpatel) + value_type operator[](int idx) const { + __at_align__ value_type tmp[size()]; + static_cast(this)->store(tmp); + return tmp[idx]; + } + + int zero_mask() const { + __at_align__ value_type tmp[size()]; + static_cast(this)->store(tmp); + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (tmp[i] == 0) { + mask |= (1 << i); + } + } + return mask; + } + + Derived map(value_type (*const f)(value_type)) const { + __at_align__ value_type tmp[size()]; + static_cast(this)->store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return Derived::loadu(tmp); + } + + Derived angle() const { + auto zero = Derived(0); + auto pi = Derived(c10::pi); + auto tmp = + Derived::blendv(zero, pi, *static_cast(this) < zero); + return Derived::blendv( + tmp, + *static_cast(this), + static_cast(this)->isnan()); + } + Derived real() const { + return *this; + } + Derived imag() const { + return Derived(0); + } + Derived conj() const { + return *this; + } + + // Sleef does not support FP16/BF16, so many math functions are applied by + // converting to FP32, applying the math function, and then converting back to + // FP16/BF16. 
+ Derived acos() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::acos); + } + Derived acosh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::acosh); + } + Derived asin() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::asin); + } + Derived asinh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::asinh); + } + Derived atan() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::atan); + } + Derived atanh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::atanh); + } + Derived atan2(const Derived& exp) const { + return static_cast(this)->map2_with_vec_float_method( + exp, &Vectorized::atan2); + } + Derived copysign(const Derived& sign) const { + return static_cast(this)->map2_with_vec_float_method( + sign, &Vectorized::copysign); + } + Derived erf() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::erf); + } + Derived erfc() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::erfc); + } + Derived erfinv() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::erfinv); + } + Derived exp() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp); + } + Derived exp2() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp2); + } + Derived expm1() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::expm1); + } + Derived exp_u20() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp_u20); + } + Derived fmod(const Derived& q) const { + // This function is questionable with a conversion, so we use map2 + return map2(q, std::fmod); + } + Derived hypot(const Derived& b) const { + return static_cast(this)->map2_with_vec_float_method( + b, &Vectorized::hypot); + } + Derived i0() const { + return 
static_cast(this)->map_with_vec_float_method( + &Vectorized::i0); + } + Derived i0e() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::i0e); + } + Derived digamma() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::digamma); + } + Derived igamma(const Derived& x) const { + return static_cast(this)->map2_with_vec_float_method( + x, &Vectorized::igamma); + } + Derived igammac(const Derived& x) const { + return static_cast(this)->map2_with_vec_float_method( + x, &Vectorized::igammac); + } + Derived log() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log); + } + Derived log10() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log10); + } + Derived log1p() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log1p); + } + Derived log2() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log2); + } + Derived nextafter(const Derived& b) const { + // This function does not make sense with conversion, so we use map2 + return map2(b, std::nextafter); + } + Derived sin() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::sin); + } + Derived sinh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::sinh); + } + Derived cos() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::cos); + } + Derived cosh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::cosh); + } + Derived ceil() const { + // This function is questionable with a conversion, so we use map + return map(at::native::ceil_impl); + } + Derived floor() const { + // This function is questionable with a conversion, so we use map + return map(at::native::floor_impl); + } + Derived round() const { + // This function is questionable with a conversion, so we use map + return map(at::native::round_impl); + } + Derived tan() const { + return 
static_cast(this)->map_with_vec_float_method( + &Vectorized::tan); + } + Derived tanh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::tanh); + } + Derived lgamma() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::lgamma); + } + Derived rsqrt() const { + return static_cast(this)->sqrt().reciprocal(); + } + Derived pow(const Derived& exp) const { + return static_cast(this)->map2_with_vec_float_method( + exp, &Vectorized::pow); + } +}; + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..b78841ead92e96ee8de032458eab0e38c41a3c60 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h @@ -0,0 +1,396 @@ +/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. 
*/ + +__extension__ extern __inline uint8x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u8_x2(const uint8_t* __a) { + uint8x8x2_t ret; + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int8x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s8_x2(const int8_t* __a) { + int8x8x2_t ret; + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u16_x2(const uint16_t* __a) { + uint16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s16_x2(const int16_t* __a) { + int16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint32x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u32_x2(const uint32_t* __a) { + uint32x2x2_t ret; + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int32x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s32_x2(const int32_t* __a) { + int32x2x2_t ret; + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u64_x2(const uint64_t* __a) { + uint64x1x2_t ret; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s64_x2(const int64_t* __a) { + int64x1x2_t ret; + __builtin_aarch64_simd_oi __o; + asm 
volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_f16_x2(const float16_t* __a) { + float16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float32x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_f32_x2(const float32_t* __a) { + float32x2x2_t ret; + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_f64_x2(const float64_t* __a) { + float64x1x2_t ret; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly8x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_p8_x2(const poly8_t* __a) { + poly8x8x2_t ret; + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_p16_x2(const poly16_t* __a) { + poly16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_p64_x2(const poly64_t* __a) { + poly64x1x2_t ret; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint8x16x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u8_x2(const uint8_t* __a) { + uint8x16x2_t ret; + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int8x16x2_t + __attribute__((__always_inline__, __gnu_inline__, 
__artificial__)) + vld1q_s8_x2(const int8_t* __a) { + int8x16x2_t ret; + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u16_x2(const uint16_t* __a) { + uint16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s16_x2(const int16_t* __a) { + int16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint32x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u32_x2(const uint32_t* __a) { + uint32x4x2_t ret; + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int32x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s32_x2(const int32_t* __a) { + int32x4x2_t ret; + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u64_x2(const uint64_t* __a) { + uint64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s64_x2(const int64_t* __a) { + int64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_f16_x2(const float16_t* __a) { + float16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline 
float32x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_f32_x2(const float32_t* __a) { + float32x4x2_t ret; + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_f64_x2(const float64_t* __a) { + float64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly8x16x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_p8_x2(const poly8_t* __a) { + poly8x16x2_t ret; + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_p16_x2(const poly16_t* __a) { + poly16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_p64_x2(const poly64_t* __a) { + poly64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +/* vst1x2 */ + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s64_x2(int64_t* __a, int64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_f64_x2(float64_t* __a, float64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern 
__inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s8_x2(int8_t* __a, int8x8x2_t val) { + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) { + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s16_x2(int16_t* __a, int16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s32_x2(int32_t* __a, int32x2x2_t val) { + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u8_x2(uint8_t* __a, uint8x8x2_t val) { + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) { + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_f16_x2(float16_t* __a, float16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) 
: "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_f32_x2(float32_t* __a, float32x2x2_t val) { + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s8_x2(int8_t* __a, int8x16x2_t val) { + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) { + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s16_x2(int16_t* __a, int16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s32_x2(int32_t* __a, int32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s64_x2(int64_t* __a, int64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) { + 
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f16_x2(float16_t* __a, float16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f64_x2(float64_t* __a, float64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..93f1110d808c61c47567984c17b47c3c6100e322 --- /dev/null 
+++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h @@ -0,0 +1,7 @@ +/* Workaround for missing vst1q_f32_x2 in gcc-8. */ + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h new file mode 100644 index 0000000000000000000000000000000000000000..50c3cc31a6c488b58491948bae88039d88794500 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h @@ -0,0 +1,430 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include + +#include +#if !( \ + defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || \ + defined(CPU_CAPABILITY_ZVECTOR)) +#if defined(CPU_CAPABILITY_SVE256) +#include +#else +// clang-format off +#include +#include +#include +#include +#endif +#if !defined(CPU_CAPABILITY_SVE256) || !defined(__ARM_FEATURE_BF16) +#include +#endif +#include +#include +#include +// clang-format on +#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX) +#include +#else +// clang-format off +#include +#include +#include +// clang-format on +#endif + +#include +#include + +#include +#include +#include +#include +#include + +namespace at::vec { + +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. 
+inline namespace CPU_CAPABILITY { + +inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { + stream << val.val_; + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { + stream << static_cast(val.val_); + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { + stream << static_cast(val.val_); + return stream; +} + +template +std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { + T buf[Vectorized::size()]; + vec.store(buf); + stream << "vec["; + for (int i = 0; i != Vectorized::size(); i++) { + if (i != 0) { + stream << ", "; + } + stream << buf[i]; + } + stream << "]"; + return stream; +} + +#if defined(CPU_CAPABILITY_AVX2) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm256_castpd_ps(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm256_castps_pd(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm256_castsi256_ps(src); +} + +template <> +inline Vectorized cast( + const Vectorized& src) { + return _mm256_castsi256_pd(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. 
+template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + double>> inline gather(const double* base_addr, const Vectorized& vindex) { + return _mm256_i64gather_pd(base_addr, vindex, scale); +} + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + float>> inline gather(const float* base_addr, const Vectorized& vindex) { + return _mm256_i32gather_ps(base_addr, vindex, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const double* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const float* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Only works for inputs in the range: [-2^51, 2^51] +// From: https://stackoverflow.com/a/41148578 +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000)); + return _mm256_sub_epi64( + _mm256_castpd_si256(x), + _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))); +} + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return _mm256_cvttps_epi32(src); +} + +// From: https://stackoverflow.com/a/41148578 +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + __m256i magic_i_lo = _mm256_set1_epi64x(0x4330000000000000); /* 2^52 */ + __m256i magic_i_hi32 = + _mm256_set1_epi64x(0x4530000080000000); /* 2^84 + 2^63 */ + __m256i 
magic_i_all = + _mm256_set1_epi64x(0x4530000080100000); /* 2^84 + 2^63 + 2^52 */ + __m256d magic_d_all = _mm256_castsi256_pd(magic_i_all); + + __m256i v_lo = _mm256_blend_epi32( + magic_i_lo, src, 0b01010101); /* v_low = low32 + 2^52 */ + __m256i v_hi = _mm256_srli_epi64(src, 32); + v_hi = _mm256_xor_si256( + v_hi, magic_i_hi32); /* v_hi = high32*2^32 + 2^84 + 2^63 */ + /* int64 = low32 + high32*2^32 = v_hi + v_lo - 2^52 - 2^63 - 2^84 */ + __m256d v_hi_dbl = _mm256_sub_pd(_mm256_castsi256_pd(v_hi), magic_d_all); + __m256d result = _mm256_add_pd(v_hi_dbl, _mm256_castsi256_pd(v_lo)); + return result; +} + +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + return _mm256_cvtepi32_ps(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + + // swap lanes: + // a_swapped = {a0, a1, b0, b1} + // b_swapped = {a2, a3, b2, b3} + auto a_swapped = + _mm256_permute2f128_pd(a, b, 0b0100000); // 0, 2. 4 bits apart + auto b_swapped = + _mm256_permute2f128_pd(a, b, 0b0110001); // 1, 3. 4 bits apart + + // group cols crossing lanes: + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + _mm256_permute4x64_pd(a_swapped, 0b11011000), // 0, 2, 1, 3 + _mm256_permute4x64_pd(b_swapped, 0b11011000)); // 0, 2, 1, 3 +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + + // swap lanes: + // a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3} + // b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7} + // TODO: can we support caching this? + auto a_swapped = + _mm256_permute2f128_ps(a, b, 0b0100000); // 0, 2. 
4 bits apart + auto b_swapped = + _mm256_permute2f128_ps(a, b, 0b0110001); // 1, 3. 4 bits apart + + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + return std::make_pair( + _mm256_permutevar8x32_ps(a_swapped, group_ctrl), + _mm256_permutevar8x32_ps(b_swapped, group_ctrl)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + + // group cols crossing lanes: + // a_grouped = {a0, a1, b0, b1} + // b_grouped = {a2, a3, b2, b3} + auto a_grouped = _mm256_permute4x64_pd(a, 0b11011000); // 0, 2, 1, 3 + auto b_grouped = _mm256_permute4x64_pd(b, 0b11011000); // 0, 2, 1, 3 + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + _mm256_permute2f128_pd( + a_grouped, b_grouped, 0b0100000), // 0, 2. 4 bits apart + _mm256_permute2f128_pd( + a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + + // group cols crossing lanes: + // a_grouped = {a0, a1, a2, a3, b0, b1, b2, b3} + // b_grouped = {a4, a5, a6, a7, b4, b5, b6, b7} + // TODO: can we support caching this? + const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); + auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl); + auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl); + + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair( + _mm256_permute2f128_ps( + a_grouped, b_grouped, 0b0100000), // 0, 2. 
4 bits apart + _mm256_permute2f128_ps( + a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_ps(v, mask_float); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return _mm256_permute4x64_pd(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return _mm256_permute4x64_epi64(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_epi32(v, mask_int32); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m256i mask = _mm256_set_epi8( + 1, + 0, + 3, + 2, + 5, + 4, + 7, + 6, + 9, + 8, + 11, + 10, + 13, + 12, + 15, + 14, + 1, + 0, + 3, + 2, + 5, + 4, + 7, + 6, + 9, + 8, + 11, + 10, + 13, + 12, + 15, + 14); + auto reversed = _mm256_shuffle_epi8(v, mask); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + +inline __m256i flip8(const __m256i& v) { + const __m256i mask_int8 = _mm256_set_epi8( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15); + auto reversed = _mm256_shuffle_epi8(v, mask_int8); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m256i* self_ = reinterpret_cast(self.as_bytes()); + const __m256i* other_ = reinterpret_cast(other.as_bytes()); + __m256i out = 
_mm256_and_si256(*self_, *other_); + Vectorized ret; + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + +#endif // (defined(CPU_CAPABILITY_AVX2) + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h new file mode 100644 index 0000000000000000000000000000000000000000..47f50991d76d445cf76c2b5c87dd40db721cf0f8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h @@ -0,0 +1,829 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +// Used for shared functions and classes for vec256_bfloat16.h and +// vec256_half.h. Any functions/classes that are common between those two files +// should be defined here. Any non-shared functions/classes should be defined in +// the respective files. 
+ +#include +#include + +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +#ifndef SLEEF_CONST +#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER) +#define SLEEF_CONST const +#else +#define SLEEF_CONST +#endif +#define SLEEF_CONST_OLD SLEEF_CONST +#else +#define SLEEF_CONST_OLD +#endif + +// bfloat16 conversion +static inline void cvtbf16_fp32(const __m128i& a, __m256& o) { + o = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16)); +} + +static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) { + __m128i lo = _mm256_extractf128_si256(a, 0); + __m128i hi = _mm256_extractf128_si256(a, 1); + cvtbf16_fp32(lo, o1); + cvtbf16_fp32(hi, o2); +} + +static inline __m128i cvtfp32_bf16(const __m256& src) { + __m256i value = _mm256_castps_si256(src); + __m256i nan = _mm256_set1_epi32(0xffff); + __m256i mask = _mm256_castps_si256(_mm256_cmp_ps(src, src, _CMP_ORD_Q)); + __m256i ones = _mm256_set1_epi32(0x1); + __m256i vec_bias = _mm256_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm256_and_si256(_mm256_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm256_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm256_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm256_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm256_blendv_epi8(nan, t_value, mask); + t_value = + _mm256_packus_epi32(t_value, t_value); // t[4-7] t[4-7] t[0-4] t[0-4] + t_value = _mm256_permute4x64_epi64(t_value, 0xd8); // 11 01 10 00 + return _mm256_castsi256_si128(t_value); +} + +static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) { + __m256i lo = _mm256_castps_si256(a); + __m256i hi = _mm256_castps_si256(b); + __m256i nan = 
_mm256_set1_epi32(0xffff); + __m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q)); + __m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q)); + __m256i ones = _mm256_set1_epi32(0x1); + __m256i vec_bias = _mm256_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_lo = _mm256_and_si256(_mm256_srli_epi32(lo, 16), ones); + auto t_hi = _mm256_and_si256(_mm256_srli_epi32(hi, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_lo = _mm256_add_epi32(t_lo, vec_bias); + t_hi = _mm256_add_epi32(t_hi, vec_bias); + // input += rounding_bias; + t_lo = _mm256_add_epi32(t_lo, lo); + t_hi = _mm256_add_epi32(t_hi, hi); + // input = input >> 16; + t_lo = _mm256_srli_epi32(t_lo, 16); + t_hi = _mm256_srli_epi32(t_hi, 16); + // Check NaN before converting back to bf16 + t_lo = _mm256_blendv_epi8(nan, t_lo, mask_lo); + t_hi = _mm256_blendv_epi8(nan, t_hi, mask_hi); + + t_lo = _mm256_packus_epi32( + t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4] + return _mm256_permute4x64_epi64(t_lo, 0xd8); // 11 01 10 00 +} + +static inline __m256i merge_compare_result(const __m256& a, const __m256& b) { + __m256i lo = _mm256_castps_si256(a); + __m256i hi = _mm256_castps_si256(b); + lo = _mm256_srli_epi32(lo, 16); + hi = _mm256_srli_epi32(hi, 16); + auto out = _mm256_packus_epi32(lo, hi); + return _mm256_permute4x64_epi64(out, 0xd8); +} + +// float16 conversion +static inline void cvtfp16_fp32(const __m128i& a, __m256& o) { + o = _mm256_cvtph_ps(a); +} + +static inline void cvtfp16_fp32(const __m256i& a, __m256& o1, __m256& o2) { + __m128i lo = _mm256_extractf128_si256(a, 0); + __m128i hi = _mm256_extractf128_si256(a, 1); + cvtfp16_fp32(lo, o1); + cvtfp16_fp32(hi, o2); +} + +static inline __m128i cvtfp32_fp16(const __m256& src) { + return _mm256_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +static inline __m256i cvtfp32_fp16(const __m256& a, const __m256& b) { + __m128i lo = + _mm256_cvtps_ph(a, 
(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i hi = + _mm256_cvtps_ph(b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1); +} + +// dtype conversion between float16/bfloat16 and float32 +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m128i& a, __m256& o); +template <> +inline void cvt_to_fp32(const __m128i& a, __m256& o) { + cvtbf16_fp32(a, o); +} +template <> +inline void cvt_to_fp32(const __m128i& a, __m256& o) { + cvtfp16_fp32(a, o); +} + +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2); +template <> +inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2) { + cvtbf16_fp32(a, o1, o2); +} +template <> +inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2) { + cvtfp16_fp32(a, o1, o2); +} + +template < + typename T, + bool is_compare_op = false, + typename std::enable_if_t, int> = 0> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b); +template <> +inline __m256i cvt_from_fp32( + const __m256& a, + const __m256& b) { + return cvtfp32_bf16(a, b); +} +template <> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b) { + return merge_compare_result(a, b); +} +template <> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b) { + return cvtfp32_fp16(a, b); +} +template <> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b) { + return cvtfp32_fp16(a, b); +} + +template +class Vectorized16 { + static_assert( + is_reduced_floating_point_v, + "Support only float16 and bfloat16."); + + protected: + __m256i values; + + public: + using value_type = uint16_t; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized16() {} + Vectorized16(__m256i v) : values(v) {} + Vectorized16(T val) { + value_type uw = val.x; + values = _mm256_set1_epi16(uw); + } + Vectorized16( + 
T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16) { + values = _mm256_setr_epi16( + val1.x, + val2.x, + val3.x, + val4.x, + val5.x, + val6.x, + val7.x, + val8.x, + val9.x, + val10.x, + val11.x, + val12.x, + val13.x, + val14.x, + val15.x, + val16.x); + } + operator __m256i() const { + return values; + } + T& operator[](int idx) = delete; + const T& operator[](int idx) const = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __m256i cmp = _mm256_cmpeq_epi16(values, _mm256_set1_epi16(0)); + return _mm256_movemask_epi8(cmp); + } + static Vectorized loadu(const void* ptr, int16_t count = size()) { + if (count == size()) + return _mm256_loadu_si256(reinterpret_cast(ptr)); + + __at_align__ int16_t tmp_values[size()]; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (const auto i : c10::irange(count, size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return _mm256_loadu_si256(reinterpret_cast(tmp_values)); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + __at_align__ int16_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi16(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi16(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi16(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi16(b.values, 3); + if (mask & 0x10) + tmp_values[4] = 
_mm256_extract_epi16(b.values, 4); + if (mask & 0x20) + tmp_values[5] = _mm256_extract_epi16(b.values, 5); + if (mask & 0x40) + tmp_values[6] = _mm256_extract_epi16(b.values, 6); + if (mask & 0x80) + tmp_values[7] = _mm256_extract_epi16(b.values, 7); + if (mask & 0x100) + tmp_values[8] = _mm256_extract_epi16(b.values, 8); + if (mask & 0x200) + tmp_values[9] = _mm256_extract_epi16(b.values, 9); + if (mask & 0x400) + tmp_values[10] = _mm256_extract_epi16(b.values, 10); + if (mask & 0x800) + tmp_values[11] = _mm256_extract_epi16(b.values, 11); + if (mask & 0x1000) + tmp_values[12] = _mm256_extract_epi16(b.values, 12); + if (mask & 0x2000) + tmp_values[13] = _mm256_extract_epi16(b.values, 13); + if (mask & 0x4000) + tmp_values[14] = _mm256_extract_epi16(b.values, 14); + if (mask & 0x8000) + tmp_values[15] = _mm256_extract_epi16(b.values, 15); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + T base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, 
b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + + // 'const' type qualifier on return type has no effect, but sleef defines this + // this way For example `Sleef_exp2f8_u10` signature is `const __m256 + // (__m256)` + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wignored-qualifiers") + Vectorized map(SLEEF_CONST __m256 (*SLEEF_CONST_OLD vop)(__m256)) const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + const auto o1 = vop(lo); + const auto o2 = vop(hi); + return cvt_from_fp32(o1, o2); + } + C10_DIAGNOSTIC_POP() + Vectorized isnan() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + lo = _mm256_cmp_ps(lo, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); + hi = _mm256_cmp_ps(hi, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); + return merge_compare_result(lo, hi); + } + Vectorized abs() const { + return _mm256_andnot_si256(_mm256_set1_epi16(0x8000), values); + } + Vectorized angle() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto angle_lambda = [](__m256 values_2) { + const auto zero_vec = _mm256_set1_ps(0.f); + const auto nan_vec = _mm256_set1_ps(NAN); + const auto not_nan_mask = _mm256_cmp_ps(values_2, values_2, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_ps(c10::pi); + + const auto neg_mask = _mm256_cmp_ps(values_2, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask); + angle = _mm256_blendv_ps(angle, nan_vec, nan_mask); + return angle; + }; + auto o1 = angle_lambda(lo); + auto o2 = angle_lambda(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return map(Sleef_acosf8_u10); + } + Vectorized acosh() const { + return map(Sleef_acoshf8_u10); + } + 
Vectorized asin() const { + return map(Sleef_asinf8_u10); + } + Vectorized atan() const { + return map(Sleef_atanf8_u10); + } + Vectorized atanh() const { + return map(Sleef_atanhf8_u10); + } + Vectorized atan2(const Vectorized& b) const { + __m256 lo, hi; + __m256 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_atan2f8_u10(lo, b1); + auto o2 = Sleef_atan2f8_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized copysign(const Vectorized& sign) const { + // copy sign bit (0x8000) from sign and remaining bits from values + __m256i mask_value = _mm256_set1_epi32(~0x80008000); + __m256i mask_signbit = _mm256_set1_epi32(0x80008000); + return Vectorized(_mm256_or_si256( + _mm256_and_si256(values, mask_value), + _mm256_and_si256(sign, mask_signbit))); + } + Vectorized erf() const { + return map(Sleef_erff8_u10); + } + Vectorized erfc() const { + return map(Sleef_erfcf8_u15); + } + Vectorized erfinv() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_erfinv(tmp1[i]); + tmp2[i] = calc_erfinv(tmp2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized exp() const { + return map(Sleef_expf8_u10); + } + Vectorized exp2() const { + return map(Sleef_exp2f8_u10); + } + Vectorized expm1() const { + return map(Sleef_expm1f8_u10); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + __m256 x_lo, x_hi; + cvt_to_fp32(values, x_lo, x_hi); + __m256 q_lo, q_hi; + cvt_to_fp32(q.values, q_lo, q_hi); + auto o1 = Sleef_fmodf8(x_lo, q_lo); + auto o2 = Sleef_fmodf8(x_hi, q_hi); + return cvt_from_fp32(o1, o2); + } + Vectorized hypot(const Vectorized& b) const { + __m256 lo, hi; + __m256 b1, b2; + 
cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_hypotf8_u05(lo, b1); + auto o2 = Sleef_hypotf8_u05(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_i0(tmp1[i]); + tmp2[i] = calc_i0(tmp2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0e() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_i0e(tmp1[i]); + tmp2[i] = calc_i0e(tmp2[i]); + } + const auto o1 = _mm256_loadu_ps(tmp1); + const auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized digamma() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_digamma(tmp1[i]); + tmp2[i] = calc_digamma(tmp2[i]); + } + const auto o1 = _mm256_loadu_ps(tmp1); + const auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized igamma(const Vectorized& x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + 
_mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + + Vectorized igammac(const Vectorized& x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized log() const { + return map(Sleef_logf8_u10); + } + Vectorized log2() const { + return map(Sleef_log2f8_u10); + } + Vectorized log10() const { + return map(Sleef_log10f8_u10); + } + Vectorized log1p() const { + return map(Sleef_log1pf8_u10); + } + Vectorized sin() const { + return map(Sleef_sinf8_u10); + } + Vectorized sinh() const { + return map(Sleef_sinhf8_u10); + } + Vectorized cos() const { + return map(Sleef_cosf8_u10); + } + Vectorized cosh() const { + return map(Sleef_coshf8_u10); + } + Vectorized ceil() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_ceil_ps(lo); + auto o2 = _mm256_ceil_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized floor() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_floor_ps(lo); + auto o2 = _mm256_floor_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized neg() const { + return _mm256_xor_si256(values, 
_mm256_set1_epi16(0x8000)); + } + Vectorized round() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = + _mm256_round_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + auto o2 = + _mm256_round_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized tan() const { + return map(Sleef_tanf8_u10); + } + Vectorized tanh() const { + return map(Sleef_tanhf8_u10); + } + Vectorized trunc() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_round_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + auto o2 = _mm256_round_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized lgamma() const { + return map(Sleef_lgammaf8_u10); + } + Vectorized sqrt() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_sqrt_ps(lo); + auto o2 = _mm256_sqrt_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized reciprocal() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm256_set1_ps(1); + auto o1 = _mm256_div_ps(ones, lo); + auto o2 = _mm256_div_ps(ones, hi); + return cvt_from_fp32(o1, o2); + } + Vectorized rsqrt() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm256_set1_ps(1); + auto o1 = _mm256_div_ps(ones, _mm256_sqrt_ps(lo)); + auto o2 = _mm256_div_ps(ones, _mm256_sqrt_ps(hi)); + return cvt_from_fp32(o1, o2); + } + Vectorized pow(const Vectorized& b) const { + __m256 lo, hi; + __m256 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_powf8_u10(lo, b1); + auto o2 = Sleef_powf8_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + + private: + template + Vectorized inline binary_compare(const VectorizedType& b, Op op) const { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvt_to_fp32(values, a_lo, a_hi); + cvt_to_fp32(b.values, b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); + } + + public: + 
Vectorized inline operator>(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_GT_OQ); + }); + } + Vectorized inline operator<(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_LT_OQ); + }); + } + Vectorized inline operator>=(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_GE_OQ); + }); + } + Vectorized inline operator<=(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_LE_OQ); + }); + } + Vectorized inline operator==(const Vectorized16& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_EQ_OQ); + }); + } + Vectorized inline operator!=(const Vectorized16& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ); + }); + } +}; + +template +static inline Vectorized binary_op_as_fp32( + const Vectorized& a, + const Vectorized& b, + Op op) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvt_to_fp32(__m256i(a), a_lo, a_hi); + cvt_to_fp32(__m256i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); +} + +#define CONVERT_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + __m256 o1, o2; \ + cvt_to_fp32(__m256i(a), o1, o2); \ + return std::make_tuple(o1, o2); \ + } \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + return cvt_from_fp32(__m256(a), __m256(b)); \ + } + +#define LOAD_FP32_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + auto values = _mm_loadu_si128(reinterpret_cast(data)); \ + __m256 out_values; \ + cvt_to_fp32(values, out_values); \ + out = 
out_values; \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + auto vec = Vectorized::loadu(data); \ + __m256 out1_values, out2_values; \ + cvt_to_fp32(vec, out1_values, out2_values); \ + out1 = out1_values; \ + out2 = out2_values; \ + } + +#else // CPU_CAPABILITY_AVX2 + +#define CONVERT_NON_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr2); \ + convert(arr2, arr, K); \ + return std::make_tuple( \ + Vectorized::loadu(arr), \ + Vectorized::loadu(arr + Vectorized::size())); \ + } \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr); \ + b.store(arr + Vectorized::size()); \ + convert(arr, arr2, K); \ + return Vectorized::loadu(arr2); \ + } + +#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + __at_align__ float values[Vectorized::size()]; \ + for (const auto k : c10::irange(Vectorized::size())) { \ + values[k] = data[k]; \ + } \ + out = Vectorized::loadu(values); \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + load_fp32_from_##name(data, out1); \ + data += Vectorized::size(); \ + load_fp32_from_##name(data, out2); \ + } + +#endif // CPU_CAPABILITY_AVX2 +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..1306270de7147b317f8b98b435cc22366f7ff36a --- /dev/null 
+++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -0,0 +1,280 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = BFloat16; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_div_ps(x, y); + }); +} +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_si256(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_si256(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return 
_mm256_xor_si256(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(b), b_lo, b_hi); + auto max_lo = _mm256_max_ps(a_lo, b_lo); + auto max_hi = _mm256_max_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(max_lo, nan_lo); + auto o2 = _mm256_or_ps(max_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(b), b_lo, b_hi); + auto min_lo = _mm256_min_ps(a_lo, b_lo); + auto min_hi = _mm256_min_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(min_lo, nan_lo); + auto o2 = _mm256_or_ps(min_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + __m256 max_lo, max_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(min), min_lo, min_hi); + cvtbf16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo)); + auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi)); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 max_lo, max_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, a_lo); + auto o2 = _mm256_min_ps(max_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(min), min_lo, min_hi); + auto o1 = _mm256_max_ps(min_lo, a_lo); + auto o2 = _mm256_max_ps(min_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + 
_mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, BFloat16* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = _mm256_loadu_ps(&src[i]); + __m256 b = _mm256_loadu_ps(&src[i + 8]); + + __m256i bf = cvtfp32_bf16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, BFloat16* dst, int64_t n) { + auto load_float = [](const double* src) -> __m256 { + // Load one float vector from an array of doubles + __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src)); + __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4)); + return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = load_float(&src[i]); + __m256 b = load_float(&src[i + 8]); + + __m256i bf = cvtfp32_bf16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + __m256 c_lo, c_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(b), b_lo, b_hi); + cvtbf16_fp32(__m256i(c), c_lo, c_hi); + auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_bf16(o1, o2); +} + +CONVERT_VECTORIZED_INIT(BFloat16, bfloat16) +LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16) + +#else // defined(CPU_CAPABILITY_AVX2) + +#if !( \ + defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + 
!defined(CPU_CAPABILITY_SVE256)) +CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16) +#endif + +LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16) +#endif // defined(CPU_CAPABILITY_AVX2) +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h new file mode 100644 index 0000000000000000000000000000000000000000..fbaf2418cfa2eb5bc8c9ce441dc2f8bc5d44f7de --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h @@ -0,0 +1,536 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include + +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m256d values; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 2; + } + Vectorized() {} + Vectorized(__m256d v) : values(v) {} + Vectorized(c10::complex val) { + double real_value = val.real(); + double imag_value = val.imag(); + values = _mm256_setr_pd(real_value, imag_value, real_value, imag_value); + } + Vectorized(c10::complex val1, c10::complex val2) { + values = _mm256_setr_pd(val1.real(), val1.imag(), val2.real(), val2.imag()); + } + operator __m256d() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + static_assert(mask > -1 && mask < 4, "Unexpected mask value"); + switch (mask) { + case 0: + return a; + 
case 1: + return _mm256_blend_pd(a.values, b.values, 0x03); + case 2: + return _mm256_blend_pd(a.values, b.values, 0x0c); + case 3: + break; + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm256_unpacklo_pd(mask.values, mask.values); + return _mm256_blendv_pd(a.values, b.values, mask_); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>(base, base + step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm256_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm256_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[2 * size()]; + _mm256_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m256d abs_2_() const { + auto val_2 = _mm256_mul_pd(values, values); // a*a b*b + return _mm256_hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b + } + __m256d abs_() const { + auto real = _mm256_movedup_pd(values); // real real + // movehdup_pd does not exist... 
+ auto imag = _mm256_permute_pd(values, 0xf); // imag imag + return Sleef_hypotd4_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm256_and_pd(abs_(), real_mask); // abs 0 + } + __m256d angle_() const { + // angle = atan2(b/a) + auto b_a = _mm256_permute_pd(values, 0x05); // b a + return Sleef_atan2d4_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + auto angle = _mm256_permute_pd(angle_(), 0x05); // angle 90-angle + return _mm256_and_pd(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm256_setzero_pd(); + auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ); + auto div = _mm256_div_pd(values, abs); + return _mm256_blendv_pd(div, zero, mask); + } + __m256d real_() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm256_and_pd(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m256d imag_() const { + const __m256d imag_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF)); + return _mm256_and_pd(values, imag_mask); + } + Vectorized> imag() const { + return _mm256_permute_pd(imag_(), 0x05); // b a + } + __m256d conj_() const { + const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); + return _mm256_xor_pd(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. 
+ return map(std::log); + } + Vectorized> log2() const { + const __m256d log2_ = _mm256_set1_pd(std::log(2)); + return _mm256_div_pd(log(), log2_); + } + Vectorized> log10() const { + const __m256d log10_ = _mm256_set1_pd(std::log(10)); + return _mm256_div_pd(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m256d one = _mm256_set1_pd(1); + + // auto conj = conj_(); + // auto b_a = _mm256_permute_pd(conj, 0x05); //-b a + // auto ab = _mm256_mul_pd(conj, b_a); //-ab + // -ab auto im = _mm256_add_pd(ab, ab); //-2ab -2ab + + // auto val_2 = _mm256_mul_pd(values, values); // a*a + // b*b auto re = _mm256_hsub_pd(val_2, _mm256_permute_pd(val_2, 0x05)); // + // a*a-b*b b*b-a*a re = _mm256_sub_pd(one, re); + + // auto root = Vectorized(_mm256_blend_pd(re, im, 0x0A)).sqrt(); //sqrt(re + + // i*im) auto ln = Vectorized(_mm256_add_pd(b_a, root)).log(); //ln(iz + + // sqrt()) return Vectorized(_mm256_permute_pd(ln.values, 0x05)).conj(); + // //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + // acos(x) = pi/2 - asin(x) + constexpr auto pi_2d = c10::pi / 2; + const __m256d pi_2 = _mm256_setr_pd(pi_2d, 0.0, pi_2d, 0.0); + return _mm256_sub_pd(pi_2, asin()); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. 
+ // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expd4_u10(values); //exp(a) exp(b) exp = + // _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosd4_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm256_blend_pd(_mm256_permute_pd(sin_cos.y, + // 0x05), + // sin_cos.x, 0x0A); //cos(b) sin(b) + // return _mm256_mul_pd(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m256d ln_2 = _mm256_set1_pd(c10::ln_2); + Vectorized> scaled_values = + _mm256_mul_pd(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm256_ceil_pd(values); + } + Vectorized> floor() const { + return _mm256_floor_pd(values); + } + Vectorized> neg() const { + auto zero = _mm256_setzero_pd(); + return _mm256_sub_pd(zero, values); + } + Vectorized> round() const { + return _mm256_round_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return 
loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); + } + Vectorized> operator!=( + const Vectorized>& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); + } + Vectorized> operator<( + const Vectorized>&) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>&) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>&) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>&) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_add_pd(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_sub_pd(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm256_mul_pd(a, b); // ac bd + + auto d_c = _mm256_permute_pd(b, 0x05); // d c + d_c = _mm256_xor_pd(sign_mask, d_c); // d -c + auto ad_bc = _mm256_mul_pd(a, d_c); // ad -bc + + auto ret = _mm256_hsub_pd(ac_bd, ad_bc); // ac - bd ad + bc + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. 
+ // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm256_set1_pd(-0.f); + // auto fabs_cd = _mm256_andnot_pd(mask, b); // |c| |d| + // auto fabs_dc = _mm256_permute_pd(fabs_cd, 0x05); // |d| |c| + // auto scale = _mm256_div_pd(_mm256_set1_pd(1.0f), _mm256_max_pd(fabs_cd, + // fabs_dc)); // 1/sc 1/sc auto a2 = _mm256_mul_pd(a, scale); // + // a/sc b/sc auto b2 = _mm256_mul_pd(b, scale); // c/sc d/sc + // auto acbd2 = _mm256_mul_pd(a2, b2); + + // const __m256d sign_mask = _mm256_setr_pd(-0.0, 0.0, -0.0, 0.0); + // auto dc2 = _mm256_permute_pd(b2, 0x05); // d/sc c/sc + // dc2 = _mm256_xor_pd(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm256_mul_pd(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = _mm256_hadd_pd(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2 + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm256_div_pd(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm256_loadu_pd(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. 
+ // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); + // auto c_d = _mm256_xor_pd(sign_mask, values); //c -d + // return _mm256_div_pd(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m256d i = _mm256_setr_pd(0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm256_setr_pd(0.0, 0.5, 0.0, 0.5); + + // auto sum = Vectorized(_mm256_add_pd(i, values)); // a + // 1+b auto sub = Vectorized(_mm256_sub_pd(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm256_blendv_pd(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_pd(max, isnan); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm256_blendv_pd(a, b, mask); + // Exploit the fact that all-ones is a NaN. 
+ auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_pd(min, isnan); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_and_pd(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_or_pd(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_xor_pd(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm256_set1_pd(1.0)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm256_set1_pd(1.0)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h new file mode 100644 index 0000000000000000000000000000000000000000..7a20ff9b319bbeb4351044065be6f3637c9b9074 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h @@ -0,0 +1,618 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m256 values; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(__m256 v) : values(v) {} + Vectorized(c10::complex val) { + float real_value = val.real(); + float imag_value = val.imag(); + values = _mm256_setr_ps( + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value); + } + Vectorized( + c10::complex val1, + c10::complex val2, + c10::complex val3, + c10::complex val4) { + values = _mm256_setr_ps( + val1.real(), + val1.imag(), + val2.real(), + val2.imag(), + val3.real(), + val3.imag(), + val4.real(), + val4.imag()); + } + operator __m256() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + static_assert(mask > -1 && mask < 16, "Unexpected mask range"); + switch (mask) { + case 0: + return a; + case 1: + return _mm256_blend_ps( + a.values, b.values, 0x03); // b0000 0001 = b0000 0011 + case 2: + return _mm256_blend_ps( + a.values, b.values, 0x0C); // b0000 0010 = b0000 1100 + case 3: + return _mm256_blend_ps( + a.values, b.values, 0x0F); // b0000 0011 = b0000 1111 + case 4: + return _mm256_blend_ps( + a.values, b.values, 0x30); // b0000 0100 = b0011 0000 + case 5: + return _mm256_blend_ps( + a.values, b.values, 0x33); // b0000 0101 = b0011 0011 + case 6: + return _mm256_blend_ps( + a.values, b.values, 0x3C); // b0000 0110 = b0011 1100 + case 7: + 
return _mm256_blend_ps( + a.values, b.values, 0x3F); // b0000 0111 = b0011 1111 + case 8: + return _mm256_blend_ps( + a.values, b.values, 0xC0); // b0000 1000 = b1100 0000 + case 9: + return _mm256_blend_ps( + a.values, b.values, 0xC3); // b0000 1001 = b1100 0011 + case 10: + return _mm256_blend_ps( + a.values, b.values, 0xCC); // b0000 1010 = b1100 1100 + case 11: + return _mm256_blend_ps( + a.values, b.values, 0xCF); // b0000 1011 = b1100 1111 + case 12: + return _mm256_blend_ps( + a.values, b.values, 0xF0); // b0000 1100 = b1111 0000 + case 13: + return _mm256_blend_ps( + a.values, b.values, 0xF3); // b0000 1101 = b1111 0011 + case 14: + return _mm256_blend_ps( + a.values, b.values, 0xFC); // b0000 1110 = b1111 1100 + default: + break; + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm256_unpacklo_ps(mask.values, mask.values); + return _mm256_blendv_ps(a.values, b.values, mask_); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>( + base, + base + step, + base + c10::complex(2) * step, + base + c10::complex(3) * step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm256_loadu_ps(reinterpret_cast(ptr)); + + __at_align__ float tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. 
We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm256_load_ps(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[2 * size()]; + _mm256_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m256 abs_2_() const { + auto val_2 = _mm256_mul_ps(values, values); // a*a b*b + auto ret = _mm256_hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b + return _mm256_permute_ps(ret, 0xD8); + } + __m256 abs_() const { + auto real = _mm256_moveldup_ps(values); // real real + auto imag = _mm256_movehdup_ps(values); // imag imag + return Sleef_hypotf8_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm256_and_ps(abs_(), real_mask); // abs 0 + } + __m256 angle_() const { + // angle = atan2(b/a) + auto b_a = _mm256_permute_ps(values, 0xB1); // b a + return Sleef_atan2f8_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 
0xFFFFFFFF, + 0x00000000)); + auto angle = _mm256_permute_ps(angle_(), 0xB1); // angle 90-angle + return _mm256_and_ps(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm256_setzero_ps(); + auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ); + auto div = _mm256_div_ps(values, abs); + return _mm256_blendv_ps(div, zero, mask); + } + __m256 real_() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm256_and_ps(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m256 imag_() const { + const __m256 imag_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF)); + return _mm256_and_ps(values, imag_mask); + } + Vectorized> imag() const { + return _mm256_permute_ps(imag_(), 0xB1); // b a + } + __m256 conj_() const { + const __m256 sign_mask = + _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm256_xor_ps(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + Vectorized> log2() const { + const __m256 log2_ = _mm256_set1_ps(std::log(2)); + return _mm256_div_ps(log(), log2_); + } + Vectorized> log10() const { + const __m256 log10_ = _mm256_set1_ps(std::log(10)); + return _mm256_div_ps(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. 
+ // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m256 one = _mm256_set1_ps(1); + + // auto conj = conj_(); + // auto b_a = _mm256_permute_ps(conj, 0xB1); //-b a + // auto ab = _mm256_mul_ps(conj, b_a); //-ab + // -ab auto im = _mm256_add_ps(ab, ab); //-2ab -2ab + + // auto val_2 = _mm256_mul_ps(values, values); // a*a + // b*b auto re = _mm256_hsub_ps(val_2, _mm256_permute_ps(val_2, 0xB1)); // + // a*a-b*b b*b-a*a re = _mm256_permute_ps(re, 0xD8); re = + // _mm256_sub_ps(one, re); + + // auto root = Vectorized(_mm256_blend_ps(re, im, 0xAA)).sqrt(); //sqrt(re + + // i*im) auto ln = Vectorized(_mm256_add_ps(b_a, root)).log(); //ln(iz + + // sqrt()) return Vectorized(_mm256_permute_ps(ln.values, 0xB1)).conj(); + // //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + return map(std::acos); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. 
+ // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expf8_u10(values); //exp(a) exp(b) exp = + // _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosf8_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm256_blend_ps(_mm256_permute_ps(sin_cos.y, + // 0xB1), + // sin_cos.x, 0xAA); //cos(b) sin(b) + // return _mm256_mul_ps(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m256 ln_2 = _mm256_set1_ps(c10::ln_2); + Vectorized> scaled_values = _mm256_mul_ps(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm256_ceil_ps(values); + } + Vectorized> floor() const { + return _mm256_floor_ps(values); + } + Vectorized> neg() const { + auto zero = _mm256_setzero_ps(); + return _mm256_sub_ps(zero, values); + } + Vectorized> round() const { + return _mm256_round_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return 
loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ); + } + Vectorized> operator!=( + const Vectorized>& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ); + } + Vectorized> operator<( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_add_ps(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_sub_ps(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m256 sign_mask = + _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm256_mul_ps(a, b); // ac bd + + auto d_c = _mm256_permute_ps(b, 0xB1); // d c + d_c = _mm256_xor_ps(sign_mask, d_c); // d -c + auto ad_bc = _mm256_mul_ps(a, d_c); // ad -bc + + auto ret = _mm256_hsub_ps(ac_bd, ad_bc); // ac - bd ad + bc + ret = _mm256_permute_ps(ret, 0xD8); + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the 
case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm256_set1_ps(-0.f); + // auto fabs_cd = _mm256_andnot_ps(mask, b); // |c| |d| + // auto fabs_dc = _mm256_permute_ps(fabs_cd, 0xB1); // |d| |c| + // auto scale = _mm256_rcp_ps(_mm256_max_ps(fabs_cd, fabs_dc)); // 1/sc 1/sc + // auto a2 = _mm256_mul_ps(a, scale); // a/sc b/sc + // auto b2 = _mm256_mul_ps(b, scale); // c/sc d/sc + // auto acbd2 = _mm256_mul_ps(a2, b2); + + // const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0); auto dc2 = _mm256_permute_ps(b2, 0xB1); // d/sc c/sc + // dc2 = _mm256_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm256_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = _mm256_hadd_ps(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2 + // res2 = _mm256_permute_ps(res2, 0xD8); + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm256_div_ps(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm256_loadu_ps(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. 
+ // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0); auto c_d = _mm256_xor_ps(sign_mask, values); //c -d + // return _mm256_div_ps(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m256 i = _mm256_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm256_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5); + + // auto sum = Vectorized(_mm256_add_ps(i, values)); // a + // 1+b auto sub = Vectorized(_mm256_sub_ps(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm256_blendv_ps(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_ps(max, isnan); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm256_blendv_ps(a, b, mask); + // Exploit the fact that all-ones is a NaN. 
+ auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_and_ps(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_or_ps(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_xor_ps(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm256_set1_ps(1.0f)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm256_set1_ps(1.0f)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_convert.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..41425639a75f6e8cbc9b797216c564bd181137d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_convert.h @@ -0,0 +1,365 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + __m256 value; + cvtbf16_fp32(_mm256_castsi256_si128(src[0]), value); + result[0] = value; + return 
result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + __m256 value; + cvtfp16_fp32(_mm256_castsi256_si128(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = _mm256_castsi128_si256(cvtfp32_bf16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = _mm256_castsi128_si256(cvtfp32_fp16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + +template <> +inline Vectorized convert_to_fp_of_same_size( + const Vectorized& src); + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low_double = at::vec::convert_to_fp_of_same_size(src[0]); + auto low = _mm256_cvtpd_ps(low_double); + auto high_double = at::vec::convert_to_fp_of_same_size(src[1]); + auto high = _mm256_cvtpd_ps(high_double); + return Vectorized( + _mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1)); + } +}; + +template <> +struct 
VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + // Scalarization is the most reliable way of converting fp to int64 on AVX2. + // Check: https://stackoverflow.com/questions/41144668 + float buffer[8]; + src.store(buffer); + at::vec::VectorizedN result; + result[0] = Vectorized( + static_cast(buffer[0]), + static_cast(buffer[1]), + static_cast(buffer[2]), + static_cast(buffer[3])); + result[1] = Vectorized( + static_cast(buffer[4]), + static_cast(buffer[5]), + static_cast(buffer[6]), + static_cast(buffer[7])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low = _mm256_shuffle_epi32(src[0], _MM_SHUFFLE(2, 0, 2, 0)); + auto high = _mm256_shuffle_epi32(src[1], _MM_SHUFFLE(2, 0, 2, 0)); + auto low_perm = _mm256_permute4x64_epi64(low, _MM_SHUFFLE(3, 1, 2, 0)); + auto high_perm = _mm256_permute4x64_epi64(high, _MM_SHUFFLE(3, 1, 2, 0)); + return Vectorized(_mm256_blend_epi32(low_perm, high_perm, 0xF0)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + at::vec::VectorizedN result; + result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(src[0])); + result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(src[0], 1)); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm256_castsi256_si128(src[0]); + return Vectorized(_mm256_cvtepi8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm256_castsi256_si128(src[0]); + return Vectorized(_mm256_cvtepu8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm256_cvttps_epi32(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + 
return Vectorized(_mm256_cvtepi32_ps(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm256_castsi256_si128(src[0]); + return Vectorized(_mm256_cvtepu8_epi16(src128)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + src_t, + 1, + typename std::enable_if_t< + (is_reduced_floating_point_v && is_8bit_integer_v) || + (is_reduced_floating_point_v && is_8bit_integer_v), + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN tmp_fp32 = VecConvert::apply(src); + return VecConvert::apply(tmp_fp32); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 2, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + at::vec::Vectorized vec1 = convert_float_to_int8(src[0]); + at::vec::Vectorized vec2 = convert_float_to_int8(src[1]); + __m128 lane2 = _mm256_castps256_ps128(_mm256_castsi256_ps(vec2)); + __m256 combined = _mm256_insertf128_ps(_mm256_castsi256_ps(vec1), lane2, 1); + // Shuffle [191:128] bit from combined in to [127:64] bit of result + __m256i result = + _mm256_permute4x64_epi64(_mm256_castps_si256(combined), 0b11011000); + return at::vec::Vectorized(result); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_float_to_int8(src[0]); + } +}; + +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + // Shuffle [127:64] bit from src[0] in to [191:128] bit of shuffled + __m256i shuffled = _mm256_permute4x64_epi64(src[0], 0b11011000); + __m256i src2 = + _mm256_castsi128_si256(_mm_castps_si128(_mm256_extractf128_ps( + _mm256_castsi256_ps(shuffled), 1) // Extract the second 128-bit lane + )); + return VectorizedN( + convert_int8_to_float(src[0]), + 
convert_int8_to_float(src2)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + int64_t, + 2, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const VectorizedN& src) { + return VecConvert::apply( + VecConvert::apply(src)); + } +}; + +#endif /* defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) */ + +#if (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)) +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_to_float(src[0]); + } +}; +#endif + +#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN res; + // Load 16-bit unsigned integers from src into an SVE vector + svuint16_t u16x4 = + svld1_u16(svptrue_b16(), reinterpret_cast(&src[0])); + // Zero-extend to 32-bit SVE does not have direct vmovl_u16 equivalent. 
+ vls_uint32_t u32x4 = + svreinterpret_u32_u16(svzip1_u16(svdup_n_u16(0), u16x4)); + // Reinterpret as float32 + vls_float32_t f32x4 = svreinterpret_f32_u32(u32x4); + res[0] = Vectorized(f32x4); + return res; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN res; + std::tie(res[0], res[1]) = convert_bfloat16_float(src[0]); + return res; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN res; + res[0] = convert_float_bfloat16(src[0], src[1]); + return res; + } +}; + +#endif // defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) + +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + auto [res_vec1, res_vec2] = convert_to_float(src[0]); + return res_vec1; + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_from_float(src[0], src[0]); + } +}; + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h new file mode 100644 index 0000000000000000000000000000000000000000..528e56d93a1d20c3f95819655e69afdb6c7e213d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h @@ -0,0 +1,505 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + __m256d values; + + public: + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(__m256d v) : values(v) {} + Vectorized(double val) { + values = _mm256_set1_pd(val); + } + Vectorized(double val1, double val2, double val3, double val4) { + values = _mm256_setr_pd(val1, val2, val3, val4); + } + operator __m256d() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm256_blend_pd(a.values, b.values, mask); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_pd(a.values, b.values, mask.values); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm256_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. 
We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(double)); + return _mm256_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[size()]; + _mm256_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(double)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __m256d cmp = _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_EQ_OQ); + return _mm256_movemask_pd(cmp); + } + Vectorized isnan() const { + return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q); + } + bool has_inf_nan() const { + __m256d self_sub = _mm256_sub_pd(values, values); + return (_mm256_movemask_epi8(_mm256_castpd_si256(self_sub)) & 0x77777777) != + 0; + } + Vectorized map(double (*const f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm256_set1_pd(-0.f); + return _mm256_andnot_pd(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm256_set1_pd(0.f); + const auto nan_vec = _mm256_set1_pd(NAN); + const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_pd(c10::pi); + + const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_pd(zero_vec, 
pi, neg_mask); + angle = _mm256_blendv_pd(angle, nan_vec, nan_mask); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_pd(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosd4_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshd4_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asind4_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhd4_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atand4_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhd4_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2d4_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignd4(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erfd4_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcd4_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expd4_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2d4_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1d4_u10(values)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodd4(values, q)); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotd4_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : 
c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized log() const { + return Vectorized(Sleef_logd4_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2d4_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10d4_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pd4_u10(values)); + } + Vectorized sin() const { + return Vectorized(Sleef_sind4_u10(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhd4_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosd4_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshd4_u10(values)); + } + Vectorized ceil() const { + return _mm256_ceil_pd(values); + } + Vectorized floor() const { + return _mm256_floor_pd(values); + } + Vectorized frac() const; + Vectorized neg() const { + return _mm256_xor_pd(_mm256_set1_pd(-0.), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterd4(values, b)); + } + Vectorized round() const { + return _mm256_round_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tand4_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhd4_u10(values)); + } + Vectorized trunc() const { + return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammad4_u10(values)); + } + Vectorized sqrt() const { + return _mm256_sqrt_pd(values); + } + Vectorized reciprocal() const { + return _mm256_div_pd(_mm256_set1_pd(1), values); + } + Vectorized rsqrt() const { + 
return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powd4_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); + } + + Vectorized operator!=(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); + } + + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ); + } + + Vectorized operator<=(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ); + } + + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ); + } + + Vectorized operator>=(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_pd(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_pd(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mul_pd(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm256_div_pd(a, b); +} + +// frac. Implement this here so we can use subtraction. 
+inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + Vectorized max = _mm256_max_pd(a, b); + Vectorized isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_pd(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + Vectorized min = _mm256_min_pd(a, b); + Vectorized isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_pd(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm256_min_pd(max, _mm256_max_pd(min, a)); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm256_max_pd(min, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm256_min_pd(max, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_pd(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_pd(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_pd(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} 
+ +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +#ifdef CPU_CAPABILITY_AVX2 +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmadd_pd(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmsub_pd(a, b, c); +} +#endif + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h new file mode 100644 index 0000000000000000000000000000000000000000..ac2f1dd919f6acfea8f4af36a7d375ad7530cf15 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h @@ -0,0 +1,768 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + __m256 values; + + public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m256 v) : values(v) {} + Vectorized(float val) { + values = _mm256_set1_ps(val); + } + Vectorized( + float val1, + float val2, + float val3, + float val4, + float val5, + float val6, + float val7, + float val8) { + values = _mm256_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8); + } + Vectorized(const float (&arr)[8]) + : Vectorized( + arr[0], + arr[1], + arr[2], + arr[3], + arr[4], + arr[5], + arr[6], + arr[7]) {} + operator __m256() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm256_blend_ps(a.values, b.values, mask); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_ps(a.values, b.values, mask.values); + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return 
blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm256_loadu_ps(reinterpret_cast(ptr)); + __at_align__ float tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, reinterpret_cast(ptr), count * sizeof(float)); + return _mm256_loadu_ps(tmp_values); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + _mm256_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[size()]; + _mm256_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float)); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ); + return _mm256_movemask_ps(cmp); + } + Vectorized isnan() const { + return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); + } + + bool has_inf_nan() const { + __m256 self_sub = _mm256_sub_ps(values, values); + return (_mm256_movemask_epi8(_mm256_castps_si256(self_sub)) & 0x77777777) != + 0; + } + + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm256_set1_ps(-0.f); + return _mm256_andnot_ps(mask, values); + } + Vectorized angle() const { + const auto zero_vec = 
_mm256_set1_ps(0.f); + const auto nan_vec = _mm256_set1_ps(NAN); + const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_ps(c10::pi); + + const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask); + angle = _mm256_blendv_ps(angle, nan_vec, nan_mask); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_ps(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosf8_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshf8_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asinf8_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhf8_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atanf8_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhf8_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2f8_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignf8(values, sign)); + } + Vectorized erf() const { + // constants + const auto neg_zero_vec = _mm256_set1_ps(-0.f); + const auto one_vec = _mm256_set1_ps(1.0f); + const auto p = _mm256_set1_ps(0.3275911f); + const auto p1 = _mm256_set1_ps(0.254829592f); + const auto p2 = _mm256_set1_ps(-0.284496736f); + const auto p3 = _mm256_set1_ps(1.421413741f); + const auto p4 = _mm256_set1_ps(-1.453152027f); + const auto p5 = _mm256_set1_ps(1.061405429f); + // sign(x) + auto sign_mask = _mm256_and_ps(neg_zero_vec, values); + auto abs_vec = _mm256_xor_ps(sign_mask, values); + // t = 1 / (p * abs(x) + 1) + auto tmp0 = _mm256_fmadd_ps(p, abs_vec, one_vec); + auto t = _mm256_div_ps(one_vec, tmp0); + // r = p5 * t ^ 4 + p4 * 
t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = _mm256_fmadd_ps(p5, t, p4); + auto tmp2 = _mm256_fmadd_ps(tmp1, t, p3); + auto tmp3 = _mm256_fmadd_ps(tmp2, t, p2); + auto r = _mm256_fmadd_ps(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = _mm256_mul_ps(values, values); + auto neg_pow_2 = _mm256_xor_ps(neg_zero_vec, pow_2); + // auto tmp4 = exp(neg_pow_2); + auto tmp4 = Vectorized(Sleef_expf8_u10(neg_pow_2)); + auto tmp5 = _mm256_xor_ps(neg_zero_vec, tmp4); + // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = _mm256_mul_ps(tmp5, t); + auto tmp7 = _mm256_fmadd_ps(tmp6, r, one_vec); + return _mm256_xor_ps(sign_mask, tmp7); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcf8_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expf8_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2f8_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1f8_u10(values)); + } + Vectorized exp_u20() const { + // A faster version of exp with ULP=20 + const __m256 vec_factorial_1 = + _mm256_set1_ps(0.999999701f); // 1/factorial(1) + const __m256 vec_factorial_2 = + _mm256_set1_ps(0.499991506f); // 1/factorial(2) + const __m256 vec_factorial_3 = + _mm256_set1_ps(0.166676521f); // 1/factorial(3) + const __m256 vec_factorial_4 = + _mm256_set1_ps(0.0418978221f); // 1/factorial(4) + const __m256 vec_factorial_5 = + _mm256_set1_ps(0.00828929059f); // 1/factorial(5) + const __m256 vec_exp_log2ef = + _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e) + const __m256 vec_half = _mm256_set1_ps(0.5f); + const __m256 vec_one = _mm256_set1_ps(1.f); + const __m256 vec_zero = _mm256_set1_ps(0.f); + const __m256 vec_two = _mm256_set1_ps(2.f); + const __m256 vec_ln2f = + _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2) + const __m256 vec_ln_flt_min = + _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50)); + const __m256 vec_ln_flt_max 
= + _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218)); + const __m256i vec_127 = _mm256_set1_epi32(0x0000007f); + const int n_mantissa_bits = 23; + + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + auto less_ln_flt_min_mask = + _mm256_cmp_ps(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); + auto vec_src = _mm256_min_ps(values, vec_ln_flt_max); + vec_src = _mm256_max_ps(vec_src, vec_ln_flt_min); + + // fx = floorf(x * log2ef + 0.5) + auto vec_fx = _mm256_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); + vec_fx = _mm256_floor_ps(vec_fx); + + // x = x - fx * ln2 + auto vec_exp_poly = _mm256_fnmadd_ps(vec_fx, vec_ln2f, vec_src); + + // compute polynomial + auto vec_res = + _mm256_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_one); + + // compute 2^(n-1) + auto vec_exp_number = _mm256_sub_ps(vec_fx, vec_one); + auto vec_exp_number_i = _mm256_cvtps_epi32(vec_exp_number); + auto vec_two_pow_n_i = _mm256_add_epi32(vec_exp_number_i, vec_127); + vec_two_pow_n_i = _mm256_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); + auto vec_two_pow_n = _mm256_castsi256_ps(vec_two_pow_n_i); + vec_two_pow_n = + _mm256_blendv_ps(vec_two_pow_n, vec_zero, less_ln_flt_min_mask); + + // y = y * 2^n + vec_res = _mm256_mul_ps(vec_res, vec_two_pow_n); + vec_res = _mm256_mul_ps(vec_res, vec_two); + return vec_res; + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodf8(values, q)); + } + Vectorized log() const { + return Vectorized(Sleef_logf8_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2f8_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10f8_u10(values)); + 
} + Vectorized log1p() const { + return Vectorized(Sleef_log1pf8_u10(values)); + } + Vectorized frac() const; + Vectorized sin() const { + return Vectorized(Sleef_sinf8_u35(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhf8_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosf8_u35(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshf8_u10(values)); + } + Vectorized ceil() const { + return _mm256_ceil_ps(values); + } + Vectorized floor() const { + return _mm256_floor_ps(values); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotf8_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized neg() const { + return _mm256_xor_ps(_mm256_set1_ps(-0.f), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterf8(values, b)); + } + Vectorized round() const { + return _mm256_round_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tanf8_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhf8_u10(values)); + } + Vectorized trunc() const { + return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return 
Vectorized(Sleef_lgammaf8_u10(values)); + } + Vectorized sqrt() const { + return _mm256_sqrt_ps(values); + } + Vectorized reciprocal() const { + return _mm256_div_ps(_mm256_set1_ps(1), values); + } + Vectorized rsqrt() const { + return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powf8_u10(values, b)); + } + float reduce_add() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_ps(v, v, 0x1); + v = _mm256_add_ps(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0x4E); + v = _mm256_add_ps(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0xB1); + v = _mm256_add_ps(v, v1); + return _mm256_cvtss_f32(v); + } + float reduce_max() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_ps(v, v, 0x1); + v = _mm256_max_ps(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0x4E); + v = _mm256_max_ps(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0xB1); + v = _mm256_max_ps(v, v1); + return _mm256_cvtss_f32(v); + } + // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ); + } + + Vectorized operator!=(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ); + } + + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_LT_OQ); + } + + Vectorized operator<=(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_LE_OQ); + } + + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_GT_OQ); + } + + Vectorized operator>=(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_GE_OQ); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_ps(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_ps(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mul_ps(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm256_div_ps(a, b); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. 
+template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + Vectorized max = _mm256_max_ps(a, b); + Vectorized isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_ps(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + Vectorized min = _mm256_min_ps(a, b); + Vectorized isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm256_min_ps(max, _mm256_max_ps(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm256_min_ps(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm256_max_ps(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_ps(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_ps(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_ps(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) 
const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmadd_ps(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmsub_ps(a, b, c); +} + +// TODO: rewrite with ATEN vectorized (need to add unpack and shuffle) +// Used by Inductor CPP codegen for micro gemm +inline void transpose_block(at::vec::VectorizedN& input) { + __m256 temp0[8]; + // unpacking and interleaving 32-bit elements + // a0 b0 a1 b1 a4 b4 a5 b5 + // a2 b2 a3 b3 a6 b6 a7 b7 + // c0 d0 c1 d1 ... + // c2 d2 c3 d3 ... + // e0 f0 e1 f1 ... + // e2 f2 e3 f3 ... + // g0 h0 g1 h1 ... + // g2 h2 g3 h3 ... + temp0[0] = _mm256_unpacklo_ps(input[0], input[1]); + temp0[1] = _mm256_unpackhi_ps(input[0], input[1]); + temp0[2] = _mm256_unpacklo_ps(input[2], input[3]); + temp0[3] = _mm256_unpackhi_ps(input[2], input[3]); + temp0[4] = _mm256_unpacklo_ps(input[4], input[5]); + temp0[5] = _mm256_unpackhi_ps(input[4], input[5]); + temp0[6] = _mm256_unpacklo_ps(input[6], input[7]); + temp0[7] = _mm256_unpackhi_ps(input[6], input[7]); + + __m256 temp1[8]; + // unpacking and interleaving 64-bit elements + // a0 b0 c0 d0 a4 b4 c4 d4 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // e0 f0 g0 h0 e4 f4 g4 h4 + // e1 f1 g1 h1 ... + // e2 f2 g2 h2 ... + // e3 f3 g3 h3 ... 
+ temp1[0] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[0]), _mm256_castps_pd(temp0[2]))); + temp1[1] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[0]), _mm256_castps_pd(temp0[2]))); + temp1[2] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[1]), _mm256_castps_pd(temp0[3]))); + temp1[3] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[1]), _mm256_castps_pd(temp0[3]))); + temp1[4] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[4]), _mm256_castps_pd(temp0[6]))); + temp1[5] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[4]), _mm256_castps_pd(temp0[6]))); + temp1[6] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[5]), _mm256_castps_pd(temp0[7]))); + temp1[7] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[5]), _mm256_castps_pd(temp0[7]))); + + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 e0 f0 g0 h0 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // a4 b4 c4 d4 ... + // a5 b5 c5 d5 ... + // a6 b6 c6 d6 ... + // a7 b7 c7 d7 ... 
+ input[0] = _mm256_permute2f128_ps(temp1[0], temp1[4], 0x20); + input[1] = _mm256_permute2f128_ps(temp1[1], temp1[5], 0x20); + input[2] = _mm256_permute2f128_ps(temp1[2], temp1[6], 0x20); + input[3] = _mm256_permute2f128_ps(temp1[3], temp1[7], 0x20); + input[4] = _mm256_permute2f128_ps(temp1[0], temp1[4], 0x31); + input[5] = _mm256_permute2f128_ps(temp1[1], temp1[5], 0x31); + input[6] = _mm256_permute2f128_ps(temp1[2], temp1[6], 0x31); + input[7] = _mm256_permute2f128_ps(temp1[3], temp1[7], 0x31); +} + +// Used by Inductor CPP codegen +template <> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + // load from src to registers + at::vec::VectorizedN input; + // a: a0 a1 a2 a3 a4 a5 a6 a7 + // b: b0 b1 b2 b3 b4 b5 b6 b7 + // c: c0 c1 c2 c3 c4 c5 c6 c7 + // d: d0 d1 d2 d3 d4 d5 d6 d7 + // e: e0 e1 e2 e3 e4 e5 e6 e7 + // f: f0 f1 f2 f3 f4 f5 f6 f7 + // g: g0 g1 g2 g3 g4 g5 g6 g7 + // h: h0 h1 h2 h3 h4 h5 h6 h7 + int i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i < 8; i++) { + input[i] = _mm256_loadu_ps(&src[i * ld_src]); + } + + transpose_block(input); + + // store from registers to dst +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i < 8; i++) { + _mm256_storeu_ps(&dst[i * ld_dst], input[i]); + } +} + +template <> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst); + transpose_mxn(src + 8, ld_src, dst + 8 * ld_dst, ld_dst); + transpose_mxn(src + 8 * ld_src, ld_src, dst + 8, ld_dst); + transpose_mxn( + src + 8 * ld_src + 8, ld_src, dst + 8 * ld_dst + 8, ld_dst); +} +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_half.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_half.h new file mode 100644 index 
0000000000000000000000000000000000000000..3022d265b398b7faadea69251ec38521d8a77117 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_half.h @@ -0,0 +1,280 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#ifdef CPU_CAPABILITY_AVX2 + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = Half; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_div_ps(x, y); + }); +} +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_si256(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_si256(a, b); +} +Vectorized inline 
operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_si256(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(b), b_lo, b_hi); + auto max_lo = _mm256_max_ps(a_lo, b_lo); + auto max_hi = _mm256_max_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(max_lo, nan_lo); + auto o2 = _mm256_or_ps(max_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. 
+template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(b), b_lo, b_hi); + auto min_lo = _mm256_min_ps(a_lo, b_lo); + auto min_hi = _mm256_min_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(min_lo, nan_lo); + auto o2 = _mm256_or_ps(min_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + __m256 max_lo, max_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(min), min_lo, min_hi); + cvtfp16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo)); + auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi)); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 max_lo, max_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, a_lo); + auto o2 = _mm256_min_ps(max_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(min), min_lo, min_hi); + auto o1 = _mm256_max_ps(min_lo, a_lo); + auto o2 = _mm256_max_ps(min_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +inline void convert(const Half* src, Half* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + 
_mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, Half* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = _mm256_loadu_ps(&src[i]); + __m256 b = _mm256_loadu_ps(&src[i + 8]); + + __m256i c = cvtfp32_fp16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), c); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, Half* dst, int64_t n) { + auto load_float = [](const double* src) -> __m256 { + // Load one float vector from an array of doubles + __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src)); + __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4)); + return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = load_float(&src[i]); + __m256 b = load_float(&src[i + 8]); + + __m256i c = cvtfp32_fp16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), c); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + __m256 c_lo, c_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(b), b_lo, b_hi); + cvtfp16_fp32(__m256i(c), c_lo, c_hi); + auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_fp16(o1, o2); +} + +CONVERT_VECTORIZED_INIT(Half, half) +LOAD_FP32_VECTORIZED_INIT(Half, fp16) + +#else // defined(CPU_CAPABILITY_AVX2) + +#if !( \ + defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + 
!defined(CPU_CAPABILITY_SVE256)) +CONVERT_NON_VECTORIZED_INIT(Half, half) +#endif + +LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16) +#endif // defined(CPU_CAPABILITY_AVX2) +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h new file mode 100644 index 0000000000000000000000000000000000000000..1ff76d4d2789eb1ed222dbae83c2437df14ac608 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h @@ -0,0 +1,2318 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#ifdef CPU_CAPABILITY_AVX2 + +struct Vectorizedi { + protected: + __m256i values; + + static inline __m256i invert(const __m256i& v) { + const auto ones = _mm256_set1_epi64x(-1); + return _mm256_xor_si256(ones, v); + } + + public: + Vectorizedi() {} + Vectorizedi(__m256i v) : values(v) {} + operator __m256i() const { + return values; + } +}; + +#else + +struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined + +#endif // CPU_CAPABILITY_AVX2 + +#ifdef CPU_CAPABILITY_AVX2 + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int64_t; + using size_type = int; + static constexpr size_type size() { + return 4; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int64_t v) { + values = _mm256_set1_epi64x(v); + } + Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4) { + values = _mm256_setr_epi64x(val1, val2, val3, val4); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + __at_align__ int64_t tmp_values[size()]; + 
a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi64(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi64(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi64(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi64(b.values, 3); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + int64_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ int64_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int64_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int64_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int64_t)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + Vectorized abs() const { + auto zero = _mm256_set1_epi64x(0); + auto is_larger = _mm256_cmpgt_epi64(zero, values); + auto inverse = _mm256_xor_si256(values, is_larger); + return _mm256_sub_epi64(inverse, is_larger); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi64x(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi64(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi64(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi64(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi64(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmpgt_epi64(values, other.values); + } + Vectorized operator>=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi64(other.values, values)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const 
Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int32_t; + static constexpr int size() { + return 8; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int32_t v) { + values = _mm256_set1_epi32(v); + } + Vectorized( + int32_t val1, + int32_t val2, + int32_t val3, + int32_t val4, + int32_t val5, + int32_t val6, + int32_t val7, + int32_t val8) { + values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm256_blend_epi32(a, b, mask); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + int32_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int32_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int32_t count) { + __at_align__ int32_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. 
We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int32_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int32_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm256_abs_epi32(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi32(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + int32_t reduce_add() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_si256(v, v, 0x1); + v = _mm256_add_epi32(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0x4E); + v = _mm256_add_epi32(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0xB1); + v = _mm256_add_epi32(v, v1); + __m128i lo = _mm256_castsi256_si128(v); + return _mm_cvtsi128_si32(lo); + } + int32_t reduce_max() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_si256(v, v, 0x1); + v = _mm256_max_epi32(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0x4E); + v = _mm256_max_epi32(v, v1); + // 32-bit shuffle + v1 = 
_mm256_shuffle_epi32(v, 0xB1); + v = _mm256_max_epi32(v, v1); + __m128i lo = _mm256_castsi256_si128(v); + return _mm_cvtsi128_si32(lo); + } + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi32(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi32(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi32(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi32(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmpgt_epi32(values, other.values); + } + Vectorized operator>=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi32(other.values, values)); + } + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + int64_t i; + // int32_t and float have same size +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_vec = + _mm256_loadu_si256(reinterpret_cast(src + i)); + auto output_vec = _mm256_cvtepi32_ps(input_vec); + _mm256_storeu_ps(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t* src, double* dst, int64_t n) { + int64_t i; + // int32_t has half the size of double +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_128_vec = + _mm_loadu_si128(reinterpret_cast(src + i)); + auto output_vec 
= _mm256_cvtepi32_pd(input_128_vec); + _mm256_storeu_pd(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int16_t; + static constexpr int size() { + return 16; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int16_t v) { + values = _mm256_set1_epi16(v); + } + Vectorized( + int16_t val1, + int16_t val2, + int16_t val3, + int16_t val4, + int16_t val5, + int16_t val6, + int16_t val7, + int16_t val8, + int16_t val9, + int16_t val10, + int16_t val11, + int16_t val12, + int16_t val13, + int16_t val14, + int16_t val15, + int16_t val16) { + values = _mm256_setr_epi16( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + __at_align__ int16_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi16(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi16(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi16(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi16(b.values, 3); + if (mask & 0x10) + tmp_values[4] = _mm256_extract_epi16(b.values, 4); + if (mask & 0x20) + tmp_values[5] = _mm256_extract_epi16(b.values, 5); + if (mask & 0x40) + tmp_values[6] = _mm256_extract_epi16(b.values, 6); + if (mask & 0x80) + tmp_values[7] = _mm256_extract_epi16(b.values, 7); + if (mask & 0x100) + tmp_values[8] = _mm256_extract_epi16(b.values, 8); + if (mask & 0x200) + tmp_values[9] = _mm256_extract_epi16(b.values, 9); + if (mask & 0x400) + tmp_values[10] = _mm256_extract_epi16(b.values, 10); + if (mask & 0x800) + 
tmp_values[11] = _mm256_extract_epi16(b.values, 11); + if (mask & 0x1000) + tmp_values[12] = _mm256_extract_epi16(b.values, 12); + if (mask & 0x2000) + tmp_values[13] = _mm256_extract_epi16(b.values, 13); + if (mask & 0x4000) + tmp_values[14] = _mm256_extract_epi16(b.values, 14); + if (mask & 0x8000) + tmp_values[15] = _mm256_extract_epi16(b.values, 15); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + int16_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int16_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + __at_align__ int16_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // 
https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm256_abs_epi16(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi16(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi16(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi16(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi16(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmpgt_epi16(values, other.values); + } + Vectorized operator>=(const Vectorized& 
other) const { + return invert(_mm256_cmpgt_epi16(other.values, values)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template +class Vectorized8 : public Vectorizedi { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + + protected: + static const Vectorized ones; + + public: + using value_type = T; + static constexpr int size() { + return 32; + } + using Vectorizedi::Vectorizedi; + Vectorized8() {} + Vectorized8(T v) { + values = _mm256_set1_epi8(v); + } + Vectorized8( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16, + T val17, + T val18, + T val19, + T val20, + T val21, + T val22, + T val23, + T val24, + T val25, + T val26, + T val27, + T val28, + T val29, + T val30, + T val31, + T val32) { + values = _mm256_setr_epi8( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16, + val17, + val18, + val19, + val20, + val21, + val22, + val23, + val24, + val25, + val26, + val27, + val28, + val29, + val30, + val31, + val32); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + __at_align__ T tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi8(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi8(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi8(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi8(b.values, 3); + if (mask & 0x10) + tmp_values[4] = _mm256_extract_epi8(b.values, 4); + if (mask & 0x20) + tmp_values[5] = _mm256_extract_epi8(b.values, 5); + if (mask & 
0x40) + tmp_values[6] = _mm256_extract_epi8(b.values, 6); + if (mask & 0x80) + tmp_values[7] = _mm256_extract_epi8(b.values, 7); + if (mask & 0x100) + tmp_values[8] = _mm256_extract_epi8(b.values, 8); + if (mask & 0x200) + tmp_values[9] = _mm256_extract_epi8(b.values, 9); + if (mask & 0x400) + tmp_values[10] = _mm256_extract_epi8(b.values, 10); + if (mask & 0x800) + tmp_values[11] = _mm256_extract_epi8(b.values, 11); + if (mask & 0x1000) + tmp_values[12] = _mm256_extract_epi8(b.values, 12); + if (mask & 0x2000) + tmp_values[13] = _mm256_extract_epi8(b.values, 13); + if (mask & 0x4000) + tmp_values[14] = _mm256_extract_epi8(b.values, 14); + if (mask & 0x8000) + tmp_values[15] = _mm256_extract_epi8(b.values, 15); + if (mask & 0x010000) + tmp_values[16] = _mm256_extract_epi8(b.values, 16); + if (mask & 0x020000) + tmp_values[17] = _mm256_extract_epi8(b.values, 17); + if (mask & 0x040000) + tmp_values[18] = _mm256_extract_epi8(b.values, 18); + if (mask & 0x080000) + tmp_values[19] = _mm256_extract_epi8(b.values, 19); + if (mask & 0x100000) + tmp_values[20] = _mm256_extract_epi8(b.values, 20); + if (mask & 0x200000) + tmp_values[21] = _mm256_extract_epi8(b.values, 21); + if (mask & 0x400000) + tmp_values[22] = _mm256_extract_epi8(b.values, 22); + if (mask & 0x800000) + tmp_values[23] = _mm256_extract_epi8(b.values, 23); + if (mask & 0x1000000) + tmp_values[24] = _mm256_extract_epi8(b.values, 24); + if (mask & 0x2000000) + tmp_values[25] = _mm256_extract_epi8(b.values, 25); + if (mask & 0x4000000) + tmp_values[26] = _mm256_extract_epi8(b.values, 26); + if (mask & 0x8000000) + tmp_values[27] = _mm256_extract_epi8(b.values, 27); + if (mask & 0x10000000) + tmp_values[28] = _mm256_extract_epi8(b.values, 28); + if (mask & 0x20000000) + tmp_values[29] = _mm256_extract_epi8(b.values, 29); + if (mask & 0x40000000) + tmp_values[30] = _mm256_extract_epi8(b.values, 30); + if (mask & 0x80000000) + tmp_values[31] = _mm256_extract_epi8(b.values, 31); + return loadu(tmp_values); + } + 
static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + static Vectorized set(Vectorized a, Vectorized b, T count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return 
blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu_one_fourth(const void* ptr) { + // Fast path if only load element number of 8. + // Note: We didn't merge it as fast path of loadu(const void* ptr, T count), + // Because loadu(const void* ptr, T count) requires zero initialization for + // upper 128 bits. However, by using _mm256_castsi128_si256, the upper 128 + // bits of the result are undefined. + // TODO We can use _mm256_zextsi128_si256 in the furture, + // since gcc 9.3 doesn't support it now. + __m128i input_128 = _mm_loadl_epi64(reinterpret_cast(ptr)); + return _mm256_castsi128_si256(input_128); + } + static Vectorized loadu(const void* ptr, T count) { + __at_align__ T tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(T)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + if (count == 8) { + // Fast path if only store element number of 8 + _mm_storel_epi64( + reinterpret_cast<__m128i*>(ptr), _mm256_castsi256_si128(values)); + } else { + __at_align__ T tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(T)); + } + } + } + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi8(0); + } + Vectorized conj() const { + return *this; + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + Vectorized neg() const; + + Vectorized abs() const { + return _mm256_abs_epi8(values); + } + + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi8(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi8(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi8(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi8(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized 
ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + Vectorized neg() const; + + Vectorized abs() const { + return *this; + } + + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi8(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi8(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + __m256i max = _mm256_max_epu8(values, other.values); + return invert(_mm256_cmpeq_epi8(max, values)); + } + Vectorized operator<=(const Vectorized& other) const { + __m256i max = _mm256_max_epu8(values, other.values); + return _mm256_cmpeq_epi8(max, other.values); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi64(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi16(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi8(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const 
Vectorized& b) { + return _mm256_add_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi64(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi32(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi16(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi8(a, b); +} + +// Negation. Defined here so we can utilize operator- +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +// Emulate operations with no native 64-bit support in avx, +// by extracting each element, performing the operation pointwise, +// then combining the results into a vector. 
+template +Vectorized inline emulate( + const Vectorized& a, + const Vectorized& b, + const op_t& op) { + int64_t a0 = _mm256_extract_epi64(a, 0); + int64_t a1 = _mm256_extract_epi64(a, 1); + int64_t a2 = _mm256_extract_epi64(a, 2); + int64_t a3 = _mm256_extract_epi64(a, 3); + + int64_t b0 = _mm256_extract_epi64(b, 0); + int64_t b1 = _mm256_extract_epi64(b, 1); + int64_t b2 = _mm256_extract_epi64(b, 2); + int64_t b3 = _mm256_extract_epi64(b, 3); + + int64_t c0 = op(a0, b0); + int64_t c1 = op(a1, b1); + int64_t c2 = op(a2, b2); + int64_t c3 = op(a3, b3); + + return _mm256_set_epi64x(c3, c2, c1, c0); +} + +template +Vectorized inline emulate( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c, + const op_t& op) { + int64_t a0 = _mm256_extract_epi64(a, 0); + int64_t a1 = _mm256_extract_epi64(a, 1); + int64_t a2 = _mm256_extract_epi64(a, 2); + int64_t a3 = _mm256_extract_epi64(a, 3); + + int64_t b0 = _mm256_extract_epi64(b, 0); + int64_t b1 = _mm256_extract_epi64(b, 1); + int64_t b2 = _mm256_extract_epi64(b, 2); + int64_t b3 = _mm256_extract_epi64(b, 3); + + int64_t c0 = _mm256_extract_epi64(c, 0); + int64_t c1 = _mm256_extract_epi64(c, 1); + int64_t c2 = _mm256_extract_epi64(c, 2); + int64_t c3 = _mm256_extract_epi64(c, 3); + + int64_t d0 = op(a0, b0, c0); + int64_t d1 = op(a1, b1, c1); + int64_t d2 = op(a2, b2, c2); + int64_t d3 = op(a3, b3, c3); + + return _mm256_set_epi64x(d3, d2, d1, d0); +} + +// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated +// This could be implemented more efficiently using epi32 instructions +// This is also technically avx compatible, but then we'll need AVX +// code for add as well. +// Note: intentionally ignores undefined behavior like (-lowest * -1). 
+template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return emulate( + a, b, [](int64_t a_point, int64_t b_point) __ubsan_ignore_undefined__ { + return a_point * b_point; + }); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mullo_epi16(a, b); +} + +template +Vectorized inline int_elementwise_binary_256( + const Vectorized& a, + const Vectorized& b, + Op op) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] = op(values_a[i], values_b[i]); + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying int8_t +#ifndef CPU_CAPABILITY_AVX2 + return int_elementwise_binary_256(a, b, std::multiplies()); +#else + __m256i mask00FF = _mm256_set1_epi16(0x00FF); + __m256i a_lo = _mm256_srai_epi16(_mm256_slli_epi16(a, 8), 8); + __m256i b_lo = _mm256_srai_epi16(_mm256_slli_epi16(b, 8), 8); + __m256i a_hi = _mm256_srai_epi16(a, 8); + __m256i b_hi = _mm256_srai_epi16(b, 8); + __m256i res_lo = _mm256_and_si256(_mm256_mullo_epi16(a_lo, b_lo), mask00FF); + __m256i res_hi = _mm256_slli_epi16(_mm256_mullo_epi16(a_hi, b_hi), 8); + __m256i res = _mm256_or_si256(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying uint8_t +#ifndef CPU_CAPABILITY_AVX2 + return int_elementwise_binary_256(a, b, std::multiplies()); +#else + __m256i mask00FF = _mm256_set1_epi16(0x00FF); + __m256i a_lo = _mm256_and_si256(a, mask00FF); + __m256i b_lo = _mm256_and_si256(b, mask00FF); + __m256i a_hi = 
_mm256_srli_epi16(a, 8); + __m256i b_hi = _mm256_srli_epi16(b, 8); + __m256i res_lo = _mm256_and_si256(_mm256_mullo_epi16(a_lo, b_lo), mask00FF); + __m256i res_hi = _mm256_slli_epi16(_mm256_mullo_epi16(a_hi, b_hi), 8); + __m256i res = _mm256_or_si256(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, b, [](int64_t a_point, int64_t b_point) { + return std::min(a_point, b_point); + }); +#else + __m256i cmp = _mm256_cmpgt_epi64(a, b); + return _mm256_blendv_epi8(a, b, cmp); +#endif +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epi32(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epi16(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epi8(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epu8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, b, [](int64_t a_point, int64_t b_point) { + return std::max(a_point, b_point); + }); +#else + __m256i cmp = _mm256_cmpgt_epi64(a, b); + return _mm256_blendv_epi8(b, a, cmp); +#endif +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epi32(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epi16(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epi8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epu8(a, b); +} + +template <> +Vectorized inline clamp( + 
const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate( + a, + min_val, + max_val, + [](int64_t a_point, int64_t min_point, int64_t max_point) { + return std::min(max_point, std::max(a_point, min_point)); + }); +#else + return minimum(maximum(a, min_val), max_val); +#endif +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epi32(max_val, _mm256_max_epi32(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epi16(max_val, _mm256_max_epi16(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epi8(max_val, _mm256_max_epi8(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epu8(max_val, _mm256_max_epu8(a, min_val)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, max_val, [](int64_t a_point, int64_t max_point) { + return std::min(max_point, a_point); + }); +#else + return minimum(max_val, a); +#endif +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epi32(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epi16(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epi8(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epu8(max_val, a); +} + +template <> +Vectorized 
inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, min_val, [](int64_t a_point, int64_t min_point) { + return std::max(min_point, a_point); + }); +#else + return maximum(min_val, a); +#endif +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epi32(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epi16(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epi8(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epu8(min_val, a); +} + +template +std::enable_if_t< + !(std::is_same_v || std::is_same_v), + Vectorized< + int32_t>> inline convert_to_int32(const T* ptr, int count = Vectorized::size()) { + return Vectorized::loadu(ptr, count); +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const int8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm256_cvtepi8_epi32( + _mm_loadl_epi64(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm256_cvtepi8_epi32(_mm256_castsi256_si128(a)); + } +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const uint8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm256_cvtepu8_epi32(_mm256_castsi256_si128(a)); + } +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + 
return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} + +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return _mm256_and_si256(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return _mm256_or_si256(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return _mm256_xor_si256(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator~(const Vectorized& a) { + return _mm256_xor_si256(a, _mm256_set1_epi32(-1)); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= 
other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < 
other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template +Vectorized inline shift_256_16( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for shifting int16_t, so emulating it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 16-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. 
+ __m256i ctl_0_1 = _mm256_set_epi8( + 29, + 28, + 0x80, + 0x80, + 25, + 24, + 0x80, + 0x80, + 21, + 20, + 0x80, + 0x80, + 17, + 16, + 0x80, + 0x80, + 13, + 12, + 0x80, + 0x80, + 9, + 8, + 0x80, + 0x80, + 5, + 4, + 0x80, + 0x80, + 1, + 0, + 0x80, + 0x80); + __m256i ctl_1_0 = _mm256_set_epi8( + 0x80, + 0x80, + 31, + 30, + 0x80, + 0x80, + 27, + 26, + 0x80, + 0x80, + 23, + 22, + 0x80, + 0x80, + 19, + 18, + 0x80, + 0x80, + 15, + 14, + 0x80, + 0x80, + 11, + 10, + 0x80, + 0x80, + 7, + 6, + 0x80, + 0x80, + 3, + 2); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 16-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFFFF); + __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000); + + // Take each 16-bit element with idx%2==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 16 + // bits will be proper result of shifting original 16-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. 
+ __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + else + c0 = _mm256_srav_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. + __m256i a1 = _mm256_and_si256(a, keep_1); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + else + c1 = _mm256_srav_epi32(a1, b1); + c1 = _mm256_and_si256(c1, keep_1); + + // Merge partial results into the final result. + __m256i c = _mm256_or_si256(c0, c1); + + return c; +} + +template < + bool left_shift, + typename T, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + int> = 0> +Vectorized inline shift_256_8( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for shifting int8_t/uint8_t, so emulating + // it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 8-bit elements, and considering quadruples of + // neighboring elements. Specifially, a mask named "ctl_M_N" (M,N + // in [0,1,2,3], and M!=N) is set so that shuffle will move element + // with index M from input quadruple into element with index N in + // output quadruple, and other elements in output quadruple will be + // set to all 0s. 
+ __m256i ctl_0_3 = _mm256_set_epi8( + 28, + 0x80, + 0x80, + 0x80, + 24, + 0x80, + 0x80, + 0x80, + 20, + 0x80, + 0x80, + 0x80, + 16, + 0x80, + 0x80, + 0x80, + 12, + 0x80, + 0x80, + 0x80, + 8, + 0x80, + 0x80, + 0x80, + 4, + 0x80, + 0x80, + 0x80, + 0, + 0x80, + 0x80, + 0x80); + __m256i ctl_1_0 = _mm256_set_epi8( + 0x80, + 0x80, + 0x80, + 29, + 0x80, + 0x80, + 0x80, + 25, + 0x80, + 0x80, + 0x80, + 21, + 0x80, + 0x80, + 0x80, + 17, + 0x80, + 0x80, + 0x80, + 13, + 0x80, + 0x80, + 0x80, + 9, + 0x80, + 0x80, + 0x80, + 5, + 0x80, + 0x80, + 0x80, + 1); + __m256i ctl_1_3 = _mm256_set_epi8( + 29, + 0x80, + 0x80, + 0x80, + 25, + 0x80, + 0x80, + 0x80, + 21, + 0x80, + 0x80, + 0x80, + 17, + 0x80, + 0x80, + 0x80, + 13, + 0x80, + 0x80, + 0x80, + 9, + 0x80, + 0x80, + 0x80, + 5, + 0x80, + 0x80, + 0x80, + 1, + 0x80, + 0x80, + 0x80); + __m256i ctl_2_0 = _mm256_set_epi8( + 0x80, + 0x80, + 0x80, + 30, + 0x80, + 0x80, + 0x80, + 26, + 0x80, + 0x80, + 0x80, + 22, + 0x80, + 0x80, + 0x80, + 18, + 0x80, + 0x80, + 0x80, + 14, + 0x80, + 0x80, + 0x80, + 10, + 0x80, + 0x80, + 0x80, + 6, + 0x80, + 0x80, + 0x80, + 2); + __m256i ctl_2_3 = _mm256_set_epi8( + 30, + 0x80, + 0x80, + 0x80, + 26, + 0x80, + 0x80, + 0x80, + 22, + 0x80, + 0x80, + 0x80, + 18, + 0x80, + 0x80, + 0x80, + 14, + 0x80, + 0x80, + 0x80, + 10, + 0x80, + 0x80, + 0x80, + 6, + 0x80, + 0x80, + 0x80, + 2, + 0x80, + 0x80, + 0x80); + __m256i ctl_3_0 = _mm256_set_epi8( + 0x80, + 0x80, + 0x80, + 31, + 0x80, + 0x80, + 0x80, + 27, + 0x80, + 0x80, + 0x80, + 23, + 0x80, + 0x80, + 0x80, + 19, + 0x80, + 0x80, + 0x80, + 15, + 0x80, + 0x80, + 0x80, + 11, + 0x80, + 0x80, + 0x80, + 7, + 0x80, + 0x80, + 0x80, + 3); + __m256i ctl_3_1 = _mm256_set_epi8( + 0x80, + 0x80, + 31, + 0x80, + 0x80, + 0x80, + 27, + 0x80, + 0x80, + 0x80, + 23, + 0x80, + 0x80, + 0x80, + 19, + 0x80, + 0x80, + 0x80, + 15, + 0x80, + 0x80, + 0x80, + 11, + 0x80, + 0x80, + 0x80, + 7, + 0x80, + 0x80, + 0x80, + 3, + 0x80); + __m256i ctl_3_2 = _mm256_set_epi8( + 0x80, + 31, + 0x80, + 0x80, + 
0x80, + 27, + 0x80, + 0x80, + 0x80, + 23, + 0x80, + 0x80, + 0x80, + 19, + 0x80, + 0x80, + 0x80, + 15, + 0x80, + 0x80, + 0x80, + 11, + 0x80, + 0x80, + 0x80, + 7, + 0x80, + 0x80, + 0x80, + 3, + 0x80, + 0x80); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 8-bit elements, and considering them in quadruples of neighboring + // elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that + // bitwise and will copy element with index M from input quadruple + // into element with the same index in output quadruple, while the + // other elements in output quadruple will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFF); + __m256i keep_3 = _mm256_set1_epi32(0xFF000000); + + // Take each 8-bit element with idx%4==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%4!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. 
+ __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + else if constexpr (std::is_same_v) + c0 = _mm256_srav_epi32(a0, b0); + else + c0 = _mm256_srlv_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_3_0); + + // Peform shifting the same way for input array elements with + // idx%4==1. + __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + else if constexpr (std::is_same_v) + c1 = _mm256_srav_epi32(a1, b1); + else + c1 = _mm256_srlv_epi32(a1, b1); + c1 = _mm256_shuffle_epi8(c1, ctl_3_1); + + // Peform shifting the same way for input array elements with + // idx%4==2. + __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3); + __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0); + __m256i c2; + if (left_shift) + c2 = _mm256_sllv_epi32(a2, b2); + else if constexpr (std::is_same_v) + c2 = _mm256_srav_epi32(a2, b2); + else + c2 = _mm256_srlv_epi32(a2, b2); + c2 = _mm256_shuffle_epi8(c2, ctl_3_2); + + // Peform shifting the same way for input array elements with + // idx%4==3. + __m256i a3 = _mm256_and_si256(a, keep_3); + __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0); + __m256i c3; + if (left_shift) + c3 = _mm256_sllv_epi32(a3, b3); + else if constexpr (std::is_same_v) + c3 = _mm256_srav_epi32(a3, b3); + else + c3 = _mm256_srlv_epi32(a3, b3); + c3 = _mm256_and_si256(c3, keep_3); + + // Merge partial results into the final result. 
+ __m256i c01 = _mm256_or_si256(c0, c1); + __m256i c23 = _mm256_or_si256(c2, c3); + __m256i c = _mm256_or_si256(c01, c23); + + return c; +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for right arithmetic shifting int64_t, so emulating + // it instead. + + // Clamp the shift values such that shift values < 0 and > 64 are changed to + // 64 which results in -1 for negative input and 0 for non-negative input. + __m256i zero = _mm256_set1_epi64x(0); + __m256i max_shift = _mm256_set1_epi64x(64); + __m256i mask = _mm256_or_si256( + _mm256_cmpgt_epi64(zero, b), _mm256_cmpgt_epi64(b, max_shift)); + __m256i shift = _mm256_blendv_epi8(b, max_shift, mask); + // Shift the number logically to the right, thus filling the most + // significant bits with 0s. Then, replace these bits with the sign + // bit. 
+ __m256i sign_bits = _mm256_cmpgt_epi64(zero, a); + __m256i sign_shift = _mm256_sub_epi64(max_shift, shift); + __m256i sign_ext = _mm256_sllv_epi64(sign_bits, sign_shift); + __m256i c = _mm256_srlv_epi64(a, shift); + c = _mm256_or_si256(c, sign_ext); + + return c; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm256_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_mask.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_mask.h new file mode 100644 index 0000000000000000000000000000000000000000..3460abe17e159d821d51c421e120117126761434 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_mask.h @@ -0,0 +1,298 @@ +#pragma once + +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + 
_mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + 2, + mask_t, + 1, + typename std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + auto int64_mask = vec_mask.template cast(); + auto result = at::vec::VectorizedN(); + if constexpr (std::is_same_v) { + result[0] = _mm256_maskload_pd(ptr, int64_mask[0]); + result[1] = _mm256_maskload_pd( + ptr + at::vec::Vectorized::size(), int64_mask[1]); + } else { + result[0] = _mm256_maskload_epi64( + reinterpret_cast(ptr), int64_mask[0]); + result[1] = _mm256_maskload_epi64( + reinterpret_cast( + ptr + at::vec::Vectorized::size()), + int64_mask[1]); + } + return result; + } +}; + +// TODO: add specialization of VecMaskLoad for bfloat16/half and int8/uint8 + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_ps(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + 
VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castps_si256(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castpd_si256(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& 
vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; + +template <> +inline bool VecMask::all_zero() const { + return _mm256_testz_si256(mask_[0], mask_[0]); +} + +template <> +inline bool VecMask::is_masked(int i) const { + return _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0])) & (1 << i); +} + +template <> +inline bool VecMask::all_masked() const { + int mask = _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0])); + return mask == 0xff; +} + +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = all_zero && (_mm256_testz_si256(vec_mask[i], vec_mask[i]) > 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 4) { + return _mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[j])) & + (1 << (i - j * 4)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + for (int i = 0; i < N; ++i) { + all_masked = all_masked && + (_mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[i])) == 0x0f); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + +#define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ + T, N, return_type, method, args_def, args) \ + template <> \ + inline return_type VecMask::method args_def const { \ + return cast().method args; \ + } + +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ()) + +#undef 
VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h new file mode 100644 index 0000000000000000000000000000000000000000..399cf206ab09ec8f7e01860ca1f63da79d1ec5c2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h @@ -0,0 +1,1397 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. 
+ +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +#ifdef _MSC_VER +__declspec(align(64)) struct Vectorizedqi { + protected: + __m256i vals; +#else +struct Vectorizedqi { + protected: + __m256i vals __attribute__((aligned(64))); +#endif + + public: + Vectorizedqi() {} + Vectorizedqi(__m256i v) : vals(v) {} + operator __m256i() const { + return vals; + } +}; + +template +__m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + T min_val, + T max_val); + +template <> +inline __m256i pack_saturate_and_clamp( + __m256i /*first*/, + __m256i /*second*/, + int32_t /*min_val*/, + int32_t /*max_val*/) { + // This function is for linkage only, will not be used + TORCH_CHECK(false, "pack_saturate_and_clamp is not supported"); +} + +template <> +inline __m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + int8_t min_val, + int8_t max_val) { + __m256i packed_and_sat = _mm256_packs_epi16(first, second); + return _mm256_max_epi8( + _mm256_set1_epi8(min_val), + _mm256_min_epi8(packed_and_sat, _mm256_set1_epi8(max_val))); +} + +template <> +inline __m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + uint8_t min_val, + uint8_t max_val) { + __m256i packed_and_sat = _mm256_packus_epi16(first, second); + return _mm256_max_epu8( + _mm256_set1_epi8(min_val), + _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val))); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + float>> inline convert_int8_to_float(at::vec::Vectorized src) { + // Note: this function only convert inputs number of elements equal to + // at::vec::Vectorized.size() Only handle first 8*8 bits + __m128i input_128 = _mm256_castsi256_si128(src); + // Convert from 8*uint8/int8 to 8*int32 + __m256i input_256_int32; + if constexpr (std::is_same_v) + input_256_int32 = _mm256_cvtepu8_epi32(input_128); + else + input_256_int32 = _mm256_cvtepi8_epi32(input_128); + // Convert from 
8*int32 to 8*float + return _mm256_cvtepi32_ps(input_256_int32); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + T>> inline convert_float_to_int8(at::vec::Vectorized src) { + // Convert from float32 to int32 with truncation + __m256i x_values_int32 = _mm256_cvttps_epi32(src); + + // Convert from int32 to int16 using signed saturation + __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32); + + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + // Convert from int16 to uint8/int8 using unsigned saturation + __m256i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, xy_packed_v, min_val, max_val); + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); +} + +template +__FORCE_INLINE void QuantizeAvx2( + const float* src, + T* dst, + int len, + float inverse_scale, + int64_t zero_point) { + constexpr int VLEN = 8; + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + const __m256i min_v = _mm256_set1_epi32(min_val); + const __m256i max_v = _mm256_set1_epi32(max_val); + // This is the largest int32 value < int32_max exactly representable in float + constexpr int32_t int32_float_max_val = + std::numeric_limits::max() - 127; + int i = 0; + __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); + // clang-format off + static const __m256i shuffle_mask_v = _mm256_set_epi8( + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00); + // clang-format on + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + __m256i permute_mask_l8_v = + _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x04, 0x00); + int len_aligned = len / (VLEN * 4) * (VLEN * 4); + for (; i < len_aligned; i += 4 * VLEN) { + // x + __m256 x_vals = _mm256_load_ps(src + i); + __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v); + // If the floating point value is greater than int32_max, + // _mm256_cvtps_epi32 converts them to -ve. Clip at int32_float_max_val to + // Clip at int32_float_max_val to avoid this. + x_transformed_v = + _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val)); + // y + __m256 y_vals = _mm256_load_ps(src + i + VLEN); + __m256 y_transformed_v = _mm256_mul_ps(y_vals, inverse_scale_v); + y_transformed_v = + _mm256_min_ps(y_transformed_v, _mm256_set1_ps(int32_float_max_val)); + // z + __m256 z_vals = _mm256_load_ps(src + i + 2 * VLEN); + __m256 z_transformed_v = _mm256_mul_ps(z_vals, inverse_scale_v); + z_transformed_v = + _mm256_min_ps(z_transformed_v, _mm256_set1_ps(int32_float_max_val)); + // w + __m256 w_vals = _mm256_load_ps(src + i + 3 * VLEN); + __m256 w_transformed_v = _mm256_mul_ps(w_vals, inverse_scale_v); + w_transformed_v = + _mm256_min_ps(w_transformed_v, _mm256_set1_ps(int32_float_max_val)); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v); + + // add zero point + x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point)); + y_rounded_v = _mm256_add_epi32(y_rounded_v, _mm256_set1_epi32(zero_point)); + z_rounded_v = _mm256_add_epi32(z_rounded_v, _mm256_set1_epi32(zero_point)); + w_rounded_v = _mm256_add_epi32(w_rounded_v, _mm256_set1_epi32(zero_point)); + + __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v); + __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v); + __m256i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + xyzw_clamped_v = 
+ _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v); + } + + // Additional 8-lane AVX2 version to take advantage when len is smaller + // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM) + for (; i < len / VLEN * VLEN; i += VLEN) { + __m256 x_vals = _mm256_load_ps(src + i); + __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v); + x_transformed_v = + _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val)); + __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v); + x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point)); + __m256i x_clipped_v = + _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, x_rounded_v)); + + x_clipped_v = _mm256_shuffle_epi8(x_clipped_v, shuffle_mask_v); + x_clipped_v = _mm256_permutevar8x32_epi32(x_clipped_v, permute_mask_l8_v); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + i), + _mm256_castsi256_si128(x_clipped_v)); + } + + for (; i < len; ++i) { + float transformed = src[i] * inverse_scale; + + // Not exactly the same behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. + // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. 
+ transformed = zero_point + std::nearbyint(transformed); + float clipped = + std::min(std::max(transformed, float(min_val)), float(max_val)); + dst[i] = clipped; + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + using size_type = int; + static constexpr size_type kSize = Vectorized::size(); + static constexpr size_type size() { + return kSize; + } + + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); + static constexpr int float_num_vecs() { + return kFloatNumVecs; + } + + static constexpr int int_num_vecs() { + return 1; + } + + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m256i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint32& val) { + value_type uw = val.val_; + vals = _mm256_set1_epi32(uw); + } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm256_storeu_si256((__m256i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return _mm256_loadu_si256((const __m256i*)tmp_values); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized /*zero_point*/, + Vectorized scale_zp_premul) const { + __m256 float_vals = _mm256_cvtepi32_ps(vals); + return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m256 float_vals = _mm256_cvtepi32_ps(vals); + return {(Vectorized(float_vals) - zero_point) * scale}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + Vectorized retval; + auto rhs_data = (__m256)rhs[0]; + at::native::quantize_vec( + scale, + zero_point, + (float*)&rhs_data, + (c10::qint32*)&retval.vals, + size()); + return retval; + } + + Vectorized maximum(Vectorized b) const { + return _mm256_max_epi32(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm256_min_epi32(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm256_min_epi32( + _mm256_max_epi32(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {_mm256_sub_epi32(vals, b)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m256 multiplier_v = _mm256_set1_ps(multiplier); + __m256i zero_point_v = _mm256_set1_epi32(zero_point); + + __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v); + __m256i rounded = _mm256_cvtps_epi32(scaled); + return _mm256_add_epi32(rounded, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) 
{ + vals = _mm256_loadu_si256((const __m256i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi32(a, b); +} + +/* + * Convert values from int32 back to int8/uint8 + */ +template +__m256i RequantizeAvx2( + const std::array, 4>& inp, + __m256 multiplier, + __m256i zp) { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier); + __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[1]), multiplier); + __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[2]), multiplier); + __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[3]), multiplier); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); + + /* Add zero point */ + __m256i x_v = _mm256_add_epi32(x_rounded_v, zp); + __m256i y_v = _mm256_add_epi32(y_rounded_v, zp); + __m256i z_v = _mm256_add_epi32(z_rounded_v, zp); + __m256i w_v = _mm256_add_epi32(w_rounded_v, zp); + + /* Pack to int16_t and saturate */ + __m256i xy_packed_v = _mm256_packs_epi32(x_v, y_v); + __m256i zw_packed_v = _mm256_packs_epi32(z_v, w_v); + + __m256i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + /* + * xyzw_clamped_v has results in the following layout so we need to + * permute: x0-3 
y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 + */ + xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + return xyzw_clamped_v; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int kSize = VECTOR_WIDTH; + static constexpr int size() { + return kSize; + } + + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); + static constexpr int float_num_vecs() { + return kFloatNumVecs; + } + + static constexpr int kIntNumVecs = kSize / Vectorized::size(); + static constexpr int int_num_vecs() { + return kIntNumVecs; + } + + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, kIntNumVecs>; + using value_type = typename c10::qint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + + Vectorized() {} + Vectorized(__m256i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint8& val) { + value_type uw = val.val_; + vals = _mm256_set1_epi8(uw); + } + + // This is needed because the compiler emits awful code for the default + // constructor for moving the enum + // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy) + C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy") + C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy") +#endif + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + C10_CLANG_DIAGNOSTIC_POP() + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm256_storeu_si256((__m256i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. 
We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return _mm256_loadu_si256((const __m256i*)tmp_values); + } + + private: + __m256i cvtepi8_epi32(__m128i epi8_vals) const { + return _mm256_cvtepi8_epi32(epi8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized /*zero_point*/, + Vectorized scale_neg_zp_premul) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul); + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m256 float_val1 = 
_mm256_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float /*scale*/, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + int8_t quantized_values[32]; + QuantizeAvx2( + rhs_data, quantized_values, 32, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm256_max_epi8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm256_min_epi8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm256_min_epi8(_mm256_max_epi8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256i int32_val0 = cvtepi8_epi32(int_val0); + __m256i int32_val1 = cvtepi8_epi32(int_val1); + __m256i int32_val2 = cvtepi8_epi32(int_val2); + __m256i int32_val3 = cvtepi8_epi32(int_val3); + + __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0)); + __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1)); + __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2)); + __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 
3)); + + __m256i int32_b0 = cvtepi8_epi32(int_b0); + __m256i int32_b1 = cvtepi8_epi32(int_b1); + __m256i int32_b2 = cvtepi8_epi32(int_b2); + __m256i int32_b3 = cvtepi8_epi32(int_b3); + + __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0); + __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1); + __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2); + __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3); + + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m256 multiplier_v = _mm256_set1_ps(multiplier); + __m256i zero_point_v = _mm256_set1_epi32(zero_point); + return RequantizeAvx2(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm256_loadu_si256((const __m256i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int kSize = VECTOR_WIDTH; + static constexpr int size() { + return kSize; + } + + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); + static constexpr int float_num_vecs() { + return kFloatNumVecs; + } + + static constexpr int kIntNumVecs = kSize / Vectorized::size(); + static constexpr int int_num_vecs() { + return kIntNumVecs; + } + + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, kIntNumVecs>; + using value_type = typename c10::quint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m256i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::quint8& val) { + value_type uw = val.val_; + vals = _mm256_set1_epi8(uw); + } + + // 
NOLINTNEXTLINE(clang-diagnostic-deprecated-copy) + C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy") + C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy") +#endif + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + C10_CLANG_DIAGNOSTIC_POP() + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm256_storeu_si256((__m256i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return _mm256_loadu_si256((const __m256i*)tmp_values); + } + + private: + __m256i cvtepu8_epi32(__m128i epu8_vals) const { + return _mm256_cvtepu8_epi32(epu8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized /*zero_point*/, + Vectorized scale_zp_premul) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = + 
vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float /*scale*/, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + uint8_t quantized_values[32]; + QuantizeAvx2( + rhs_data, quantized_values, 32, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm256_max_epu8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm256_min_epu8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm256_min_epu8(_mm256_max_epu8(vals, zero_point.vals), q_six.vals); + } + + 
int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256i int32_val0 = cvtepu8_epi32(int_val0); + __m256i int32_val1 = cvtepu8_epi32(int_val1); + __m256i int32_val2 = cvtepu8_epi32(int_val2); + __m256i int32_val3 = cvtepu8_epi32(int_val3); + + __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0)); + __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1)); + __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2)); + __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3)); + + __m256i int32_b0 = cvtepu8_epi32(int_b0); + __m256i int32_b1 = cvtepu8_epi32(int_b1); + __m256i int32_b2 = cvtepu8_epi32(int_b2); + __m256i int32_b3 = cvtepu8_epi32(int_b3); + + __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0); + __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1); + __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2); + __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3); + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m256 multiplier_v = _mm256_set1_ps(multiplier); + __m256i zero_point_v = _mm256_set1_epi32(zero_point); + return RequantizeAvx2(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm256_loadu_si256((const __m256i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#elif !defined(CPU_CAPABILITY_SVE256) + +// NOTE: These are low-performance implementations that we fall back on +// if we are not building with AVX2. 
This may not be an issue, because +// currently for quantization we assume the user has at least AVX512 +// installed, so these can simply act as a reference implementation. +// +// If in the future we relax this requirement (AVX2+), we should probably +// revisit these implementations + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + static constexpr int size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size_ / Vectorized::size(); + } + + static constexpr int int_num_vecs() { + return size_ / Vectorized::size(); + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (const auto i : c10::irange(size())) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized /*scale_zp_premul*/) const { + float_vec_return_type rv; + for (const auto i : c10::irange(float_num_vecs())) { + float tmp_vals[Vectorized::size()]; + for (const auto j : c10::irange(Vectorized::size())) { + tmp_vals[j] = at::native::dequantize_val( + scale[j], + zero_point[j], + T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized(tmp_vals); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + Vectorized scale_zp_premul; + return dequantize(scale, zero_point, scale_zp_premul); + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + 
std::array, 1>, + Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return Vectorized(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + std::array qvals; + std::array::size()> float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * Vectorized::size()]); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + float_vals.size()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type 
widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (const auto i : c10::irange(size())) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = + std::nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 4 * Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return Vectorized(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + std::array qvals; + std::array::size()> float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * Vectorized::size()]); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + float_vals.size()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = 
std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 4 * Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return Vectorized(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + std::array qvals; + std::array::size()> float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * Vectorized::size()]); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + float_vals.size()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = 
std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#endif // if defined(CPU_CAPABILITY_AVX2) + +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) +std::pair, Vectorized> inline convert_int8_to_float( + at::vec::Vectorized src) { + auto s8x8 = vld1_s8(src.operator const int8_t*()); + auto s16x8 = vmovl_s8(s8x8); + + auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8)); + auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); + + return std::make_pair( + Vectorized(vcvtq_f32_s32(s32x4_lo)), + Vectorized(vcvtq_f32_s32(s32x4_hi))); +} + +std::pair, Vectorized> inline convert_int8_to_float( + at::vec::Vectorized src) { + auto u8x8 = vld1_u8(src.operator const uint8_t*()); + auto u16x8 = vmovl_u8(u8x8); + auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); + auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); + + return std::make_pair( + Vectorized(vcvtq_f32_u32(u32x4_lo)), + Vectorized(vcvtq_f32_u32(u32x4_hi))); +} + +Vectorized inline convert_int8_half_register_to_float( + at::vec::Vectorized src) { + auto s8x8 = vld1_s8(src.operator const int8_t*()); + auto s16x8 = vmovl_s8(s8x8); + + auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); + + return Vectorized(vcvtq_f32_s32(s32x4_lo)); +} + +Vectorized inline convert_int8_half_register_to_float( + at::vec::Vectorized src) { + auto u8x8 = vld1_u8(src.operator const uint8_t*()); + auto u16x8 = vmovl_u8(u8x8); + auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); + + return Vectorized(vcvtq_f32_u32(u32x4_lo)); +} + +#endif +} // namespace 
CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..7dbd82ca756987dc8429223152a0b95705e628e1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr2); + convert(arr2, arr, K); + return std::make_tuple( + Vectorized::loadu(arr), + Vectorized::loadu(arr + Vectorized::size())); +} + +inline Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr); + b.store(arr + Vectorized::size()); + convert(arr, arr2, K); + return Vectorized::loadu(arr2); +} + +inline void load_fp32_from_bf16( + const c10::BFloat16* data, + Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_bf16( + const c10::BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_bf16(data, out1); + data += Vectorized::size(); + load_fp32_from_bf16(data, out2); +} + +inline void load_fp32_from_fp16(const c10::Half* data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = 
Vectorized::loadu(values); +} + +inline void load_fp32_from_fp16( + const c10::Half* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_fp16(data, out1); + data += Vectorized::size(); + load_fp32_from_fp16(data, out2); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..ec485860cd2fbfd2e7646cda9463686896e76cc7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h @@ -0,0 +1,249 @@ +#pragma once + +#include +#include +#include + +// Note: header order is important here +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace at { +namespace vec { + +inline namespace CPU_CAPABILITY { + +DEFINE_CLAMP_FUNCS(c10::quint8) +DEFINE_CLAMP_FUNCS(c10::qint8) +DEFINE_CLAMP_FUNCS(c10::qint32) +DEFINE_CLAMP_FUNCS(int16_t) +DEFINE_CLAMP_FUNCS(int32_t) +DEFINE_CLAMP_FUNCS(int64_t) +DEFINE_CLAMP_FUNCS(float) +DEFINE_CLAMP_FUNCS(double) + +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + vec_madd(a.vec0(), b.vec0(), c.vec0()), + vec_madd(a.vec1(), b.vec1(), c.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + 
const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} + +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t) + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + // int32_t and float have same size + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int32_t* src_a = src + i; + float* dst_a = dst + i; + vint32 input_vec0 = + vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint32 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat32 c0 = vec_float(input_vec0); + vfloat32 c1 = vec_float(input_vec1); + vec_vsx_st(c0, offset0, dst_a); + vec_vsx_st(c1, offset16, dst_a); + } + + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int64_t* src, double* dst, int64_t n) { + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int64_t* src_a = src + i; + double* dst_a = dst + i; + vint64 input_vec0 = + vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint64 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat64 c0 = vec_double(input_vec0); + vfloat64 c1 = vec_double(input_vec1); + vec_vsx_st(c0, offset0, reinterpret_cast(dst_a)); + vec_vsx_st(c1, offset16, reinterpret_cast(dst_a)); + } + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} +// Generic implementation to fix compiler 
error +// TO-DO : Add optimized version for ppc64 +inline std::tuple, Vectorized> convert_half_float( + const Vectorized& a) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ Half arr2[K]; + a.store(arr2); + convert(arr2, arr, K); + return std::make_tuple( + Vectorized::loadu(arr), + Vectorized::loadu(arr + Vectorized::size())); +} + +inline Vectorized convert_float_half( + const Vectorized& a, + const Vectorized& b) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ Half arr2[K]; + a.store(arr); + b.store(arr + Vectorized::size()); + convert(arr, arr2, K); + return Vectorized::loadu(arr2); +}; + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + + vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0); + vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3); + vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0); + vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3); + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vectorized{ab00, ab11}, Vectorized{ab2_00, ab2_11}); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0); + vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0); + + vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3); + vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3); + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + Vectorized{aa01, aa23}, Vectorized{bb_01, bb_23}); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3,, a4, a5, a6, a7} + // b = {b0, b1, b2, b3,, b4, b5, b6, b7} + + vfloat32 ab0011 = 
vec_mergeh(a.vec0(), b.vec0()); + vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0()); + + vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1()); + vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1()); + // group cols crossing lanes: + // return {a0, b0, a1, b1,, a2, b2, a3, b3} + // {a4, b4, a5, b5,, a6, b6, a7, b7} + + return std::make_pair( + Vectorized{ab0011, ab2233}, Vectorized{ab2_0011, ab2_2233}); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1,, a2, b2, a3, b3} + // b = {a4, b4, a5, b5,, a6, b6, a7, b7} + + // {a0,a2,b0,b2} {a1,a3,b1,b3} + vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1()); + vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1()); + + vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3); + vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3); + + vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1()); + vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1()); + + vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2); + vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2); + + // it could be done with vec_perm ,too + // swap lanes: + // return {a0, a1, a2, a3,, a4, a5, a6, a7} + // {b0, b1, b2, b3,, b4, b5, b6, b7} + + return std::make_pair( + Vectorized{aa0123, aa0123_2}, Vectorized{bb0123, bb0123_2}); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..a6a883e53b39b61b59dda668ca1bfe2972356afc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h @@ -0,0 +1,679 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY 
namespace] +inline namespace CPU_CAPABILITY { +using ComplexDbl = c10::complex; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = ComplexDbl; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + using size_type = int; + static constexpr size_type size() { + return 2; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) + : _vecb0{v1}, _vecb1{v2} {} + + Vectorized(ComplexDbl val) { + double real_value = val.real(); + double imag_value = val.imag(); + _vec0 = vfloat64{real_value, imag_value}; + _vec1 = vfloat64{real_value, imag_value}; + } + Vectorized(ComplexDbl val1, ComplexDbl val2) { + _vec0 = vfloat64{val1.real(), val1.imag()}; + _vec1 = vfloat64{val2.real(), val2.imag()}; + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return a; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return b; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static Vectorized C10_ALWAYS_INLINE + 
el_blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_complex = Vectorized( + vec_splat(mask._vec0, 0), vec_splat(mask._vec1, 0)); + return { + vec_sel(a._vec0, b._vec0, mask_complex._vecb0), + vec_sel(a._vec1, b._vec1, mask_complex._vecb1)}; + } + + static Vectorized C10_ALWAYS_INLINE elwise_blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vectorized arange( + ComplexDbl base = 0., + step_t step = static_cast(1)) { + return Vectorized(base, base + step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + } + return b; + } + + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, 
reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexDbl& operator[](int idx) const = delete; + ComplexDbl& operator[](int idx) = delete; + + Vectorized map(ComplexDbl (*const f)(ComplexDbl)) const { + __at_align__ ComplexDbl tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vectorized map(ComplexDbl (*const f)(const ComplexDbl&)) const { + __at_align__ ComplexDbl tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vectorized el_swapped() const { + vfloat64 v0 = vec_xxpermdi(_vec0, _vec0, 2); + vfloat64 v1 = vec_xxpermdi(_vec1, _vec1, 2); + return {v0, v1}; + } + + Vectorized el_madd( + const Vectorized& multiplier, + const Vectorized& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + Vectorized el_mergeo() const { + vfloat64 v0 = vec_splat(_vec0, 1); + vfloat64 v1 = vec_splat(_vec1, 1); + return {v0, v1}; + } + + Vectorized el_mergee() const { + vfloat64 v0 = vec_splat(_vec0, 0); + vfloat64 v1 = vec_splat(_vec1, 0); + return {v0, v1}; + } + + static Vectorized el_mergee( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergeh(first._vec0, second._vec0), + vec_mergeh(first._vec1, second._vec1)}; + } + + static Vectorized el_mergeo( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergel(first._vec0, second._vec0), + vec_mergel(first._vec1, second._vec1)}; + } + + Vectorized abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a; + } + + Vectorized abs_() const { + auto vi = el_mergeo(); + auto vr = el_mergee(); + return { + Sleef_hypotd2_u05vsx(vr._vec0, vi._vec0), + 
Sleef_hypotd2_u05vsx(vr._vec1, vi._vec1)}; + } + + Vectorized abs() const { + return abs_() & vd_real_mask; + } + + Vectorized angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_pd(values, 0x05); // b a + // return Sleef_atan2d4_u10(values, b_a); // 90-angle angle + Vectorized ret; + ret._vec0[0] = std::atan2(_vec0[1], _vec0[0]); + ret._vec1[0] = std::atan2(_vec1[1], _vec1[0]); + return ret; + } + + Vectorized angle() const { + return angle_() & vd_real_mask; + } + + Vectorized real_() const { + return *this & vd_real_mask; + } + Vectorized real() const { + return *this & vd_real_mask; + } + Vectorized imag_() const { + return *this & vd_imag_mask; + } + Vectorized imag() const { + return imag_().el_swapped(); + } + + Vectorized conj_() const { + return *this ^ vd_isign_mask; + } + Vectorized conj() const { + return *this ^ vd_isign_mask; + } + + Vectorized log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + + Vectorized log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(vd_log2e_inv); + } + Vectorized log10() const { + auto ret = log(); + return ret.elwise_mult(vd_log10e_inv); + } + + Vectorized log1p() const { + return map(std::log1p); + } + + Vectorized asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub(val_2, val_2_swapped); + re = Vectorized(vd_one) - re; + auto root = el_blend<0x0A>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); + } + + Vectorized acos() const { + // acos(x) = pi/2 - asin(x) + return Vectorized(vd_pi_2) - asin(); + } + + Vectorized atan() const { + 
// atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vectorized(vd_imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * vd_imag_half; // i/2*ln() + } + Vectorized atanh() const { + return map(std::atanh); + } + + Vectorized sin() const { + return map(std::sin); + } + Vectorized sinh() const { + return map(std::sinh); + } + Vectorized cos() const { + return map(std::cos); + } + Vectorized cosh() const { + return map(std::cosh); + } + + Vectorized tan() const { + return map(std::tan); + } + Vectorized tanh() const { + return map(std::tanh); + } + Vectorized ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized neg() const { + auto z = Vectorized(vd_zero); + return z - *this; + } + Vectorized round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + + Vectorized trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + Vectorized sqrt() const { + return map(std::sqrt); + } + + Vectorized reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ vd_isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vectorized rsqrt() const { + return sqrt().reciprocal(); + } + + static Vectorized horizontal_add( + Vectorized& first, + Vectorized& second) { + // Operates on individual floats, see _mm_hadd_ps + // {f0+f1, s0+s1, f2+f3, s2+s3, ...} + // i.e. 
it sums the re and im of each value and interleaves first and + // second: {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...} + return el_mergee(first, second) + el_mergeo(first, second); + } + + static Vectorized horizontal_sub( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vectorized inline operator*( + const Vectorized& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ vd_rsign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.elwise_mult(vi) + ret; +#else + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ vd_isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub(ac_bd, ad_bc); +#endif + return ret; + } + + Vectorized inline operator/( + const Vectorized& b) const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() + // auto fabs_cd = Vectorized{ + // vec_andc(b._vec0, vd_sign_mask), + // vec_andc(b._vec1, vd_sign_mask)}; // |c| |d| + // auto fabs_dc = fabs_cd.el_swapped(); // |d| |c| + // auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|) + // auto a2 = elwise_div(scale); // a/sc b/sc + // auto b2 = b.elwise_div(scale); // c/sc d/sc + // auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/sc^2 + // auto dc2 = b2.el_swapped(); // d/sc c/sc + // dc2 = dc2 ^ vd_rsign_mask; // -d/sc c/sc + // auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2 + // 
auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2 + // auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 + // (c^2+d^2)/sc^2 ret = ret.elwise_div(denom2); return ret; + + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + this->store(tmp1); + b.store(tmp2); + + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return loadu(out); + } + + Vectorized exp() const { + return map(std::exp); + } + Vectorized exp2() const { + return map(exp2_impl); + } + Vectorized expm1() const { + return map(std::expm1); + } + + Vectorized pow(const Vectorized& exp) const { + __at_align__ ComplexDbl x_tmp[size()]; + __at_align__ ComplexDbl y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vectorized sgn() const { + return map(at::native::sgn_impl); + } + + Vectorized operator<(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized operator<=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized operator>(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized operator>=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized eq(const Vectorized& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & vd_one; + } + Vectorized ne(const Vectorized& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not 
equal + return (ne.real() | ne.imag()) & vd_one; + } + + DEFINE_MEMBER_OP(operator==, ComplexDbl, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexDbl, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexDbl, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexDbl, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor) + // elementwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple) + DEFINE_MEMBER_OP(elwise_max, ComplexDbl, vec_max) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vectorized::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vectorized::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. 
+ // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + // (a + ib) * (c + id) = (ac - bd) + i(ad + bc) + // Split into real and imaginary parts + auto a_real = a.el_mergee(); // real part of a + auto a_imag = a.el_mergeo(); // imag part of a + auto b_real = b.el_mergee(); // real part of b + auto b_imag = b.el_mergeo(); // imag part of b + + // Compute components + auto ac = a_real.elwise_mult(b_real); // real*real + auto bd = a_imag.elwise_mult(b_imag); // imag*imag + + // Real part: ac - bd + auto real = ac - bd; + + auto ad = a_real.elwise_mult(b_imag); // real*imag + auto bc = a_imag.elwise_mult(b_real); // imag*real + + // Imag = ad + bc + auto imag = ad + bc; + + // Merge real and imaginary parts into vectors + __vector double v0 = vec_mergeh(real.vec0(), imag.vec0()); // [r0, i0] + __vector double v1 = vec_mergeh(real.vec1(), imag.vec1()); // [r1, i1] + + // Create the final result + auto result = 
Vectorized{v0, v1}; + return result; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() + // Take absolute values of real and imaginary parts of b + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return Vectorized::loadu(out); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..9acc79cdeb4c5adfd500e10bd97649e46e4b23b3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h @@ -0,0 +1,771 @@ + +#pragma once +#include +#include +#include +#include +#include + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +using ComplexFlt = c10::complex; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = ComplexFlt; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + using size_type = int; + + static constexpr size_type size() { + return 4; + } + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, 
_vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + + Vectorized(ComplexFlt val) { + float real_value = val.real(); + float imag_value = val.imag(); + _vec0 = vfloat32{real_value, imag_value, real_value, imag_value}; + _vec1 = vfloat32{real_value, imag_value, real_value, imag_value}; + } + + Vectorized( + ComplexFlt val1, + ComplexFlt val2, + ComplexFlt val3, + ComplexFlt val4) { + _vec0 = vfloat32{val1.real(), val1.imag(), val2.real(), val2.imag()}; + _vec1 = vfloat32{val3.real(), val3.imag(), val4.real(), val4.imag()}; + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = 
VsxComplexMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxComplexMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + const vbool32 mask_2nd = VsxComplexMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static Vectorized C10_ALWAYS_INLINE + el_blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_complex = Vectorized( + vec_mergeh(mask._vec0, mask._vec0), vec_mergeh(mask._vec1, mask._vec1)); + return { + vec_sel( + a._vec0, b._vec0, reinterpret_cast(mask_complex._vec0)), + vec_sel( + a._vec1, b._vec1, reinterpret_cast(mask_complex._vec1)), + }; + } + + static Vectorized elwise_blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return { + vec_sel(a._vec0, b._vec0, reinterpret_cast(mask._vec0)), + vec_sel(a._vec1, b._vec1, reinterpret_cast(mask._vec1)), + }; + } + + template + static Vectorized arange( + ComplexFlt base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + ComplexFlt(2) * step, + base + ComplexFlt(3) * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + 
switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexFlt& operator[](int idx) const = delete; + ComplexFlt& operator[](int idx) = delete; + + Vectorized map(ComplexFlt (*const f)(ComplexFlt)) const { + __at_align__ ComplexFlt tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vectorized map(ComplexFlt (*const f)(const ComplexFlt&)) const { + __at_align__ ComplexFlt tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + static Vectorized horizontal_add( + Vectorized& first, + Vectorized& second) { + // Operates on individual floats, see _mm_hadd_ps + // {f0+f1, s0+s1, f2+f3, s2+s3, ...} + // i.e. 
it sums the re and im of each value and interleaves first and + // second: {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...} + return el_mergee(first, second) + el_mergeo(first, second); + } + + static Vectorized horizontal_sub_permD8( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // sum + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vectorized abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a.el_mergee(); + } + + Vectorized abs_() const { + auto vi = el_mergeo(); + auto vr = el_mergee(); + return { + Sleef_hypotf4_u05vsx(vr._vec0, vi._vec0), + Sleef_hypotf4_u05vsx(vr._vec1, vi._vec1)}; + } + + Vectorized abs() const { + return abs_() & real_mask; + } + + Vectorized real_() const { + return *this & real_mask; + } + Vectorized real() const { + return *this & real_mask; + } + Vectorized imag_() const { + return *this & imag_mask; + } + Vectorized imag() const { + // we can use swap_mask or sldwi + auto ret = imag_(); + return { + vec_sldw(ret._vec0, ret._vec0, 3), vec_sldw(ret._vec1, ret._vec1, 3)}; + } + + Vectorized conj_() const { + return *this ^ isign_mask; + } + Vectorized conj() const { + return *this ^ isign_mask; + } + + Vectorized log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. 
+ return map(std::log); + } + + Vectorized log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(log2e_inv); + } + Vectorized log10() const { + auto ret = log(); + return ret.elwise_mult(log10e_inv); + } + + Vectorized log1p() const { + return map(std::log1p); + } + + Vectorized el_swapped() const { + vfloat32 v0 = vec_perm(_vec0, _vec0, swap_mask); + vfloat32 v1 = vec_perm(_vec1, _vec1, swap_mask); + return {v0, v1}; + } + + Vectorized el_mergee() const { + // as mergee phased in , we can use vec_perm with mask + return {vec_mergee(_vecb0, _vecb0), vec_mergee(_vecb1, _vecb1)}; + } + + Vectorized el_mergeo() const { + // as mergeo phased in , we can use vec_perm with mask + return {vec_mergeo(_vecb0, _vecb0), vec_mergeo(_vecb1, _vecb1)}; + } + + Vectorized el_madd( + const Vectorized& multiplier, + const Vectorized& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + static Vectorized el_mergee( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergee(first._vecb0, second._vecb0), + vec_mergee(first._vecb1, second._vecb1)}; + } + + static Vectorized el_mergeo( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergeo(first._vecb0, second._vecb0), + vec_mergeo(first._vecb1, second._vecb1)}; + } + + Vectorized angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_ps(values, 0xB1); // b a + // return Sleef_atan2f8_u10(values, b_a); // 90-angle angle + Vectorized ret; + for (int i = 0; i < 4; i += 2) { + ret._vec0[i] = std::atan2(_vec0[i + 1], _vec0[i]); + ret._vec1[i] = std::atan2(_vec1[i + 1], _vec1[i]); + } + return ret; + } + + Vectorized angle() const { + return angle_() & real_mask; + } + + Vectorized sin() const { + return map(std::sin); + } + Vectorized sinh() const { + return map(std::sinh); + } + Vectorized cos() const { + return map(std::cos); + } + Vectorized cosh() const { + return 
map(std::cosh); + } + Vectorized ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized neg() const { + auto z = Vectorized(zero); + return z - *this; + } + Vectorized round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vectorized tan() const { + return map(std::tan); + } + Vectorized tanh() const { + return map(std::tanh); + } + Vectorized trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + Vectorized sqrt() const { + return map(std::sqrt); + } + + Vectorized reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vectorized rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized pow(const Vectorized& exp) const { + __at_align__ ComplexFlt x_tmp[size()]; + __at_align__ ComplexFlt y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vectorized atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vectorized(imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * imag_half; // i/2*ln() + } + Vectorized atanh() const { + return map(std::atanh); + } + + Vectorized acos() const { + // acos(x) = pi/2 - asin(x) + return Vectorized(pi_2) - asin(); + } + + Vectorized inline operator*( + const Vectorized& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ rsign_mask; + auto ret = elwise_mult(vr); + auto 
vx_swapped = el_swapped(); + ret = vx_swapped.elwise_mult(vi) + ret; + return ret; + +#else + + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub_permD8(ac_bd, ad_bc); + return ret; +#endif + } + + Vectorized inline operator/( + const Vectorized& b) const { +#if 1 + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + this->store(tmp1); + b.store(tmp2); + + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return loadu(out); +#else + auto fabs_cd = Vectorized{ + vec_andc(b._vec0, sign_mask), vec_andc(b._vec1, sign_mask)}; // |c| |d| + auto fabs_dc = fabs_cd.el_swapped(); // |d| |c| + auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|) + auto a2 = elwise_div(scale); // a/sc b/sc + auto b2 = b.elwise_div(scale); // c/sc d/sc + auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/s + auto dc2 = b2.el_swapped(); // d/sc c/sc + dc2 = dc2 ^ rsign_mask; // -d/sc c/sc + auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2 + auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2 + auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 + ret = ret.elwise_div(denom2); + return ret; +#endif + } + + Vectorized asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + +#if 1 + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub_permD8(val_2, val_2_swapped); + re = Vectorized(one) - re; + auto root = el_blend<0xAA>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); +#else + return 
map(std::asin); +#endif + } + + Vectorized exp() const { + return map(std::exp); + } + Vectorized exp2() const { + return map(exp2_impl); + } + Vectorized expm1() const { + return map(std::expm1); + } + + Vectorized eq(const Vectorized& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & one; + } + Vectorized ne(const Vectorized& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & one; + } + + Vectorized sgn() const { + return map(at::native::sgn_impl); + } + + Vectorized operator<(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized operator<=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized operator>(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized operator>=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + DEFINE_MEMBER_OP(operator==, ComplexFlt, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexFlt, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexFlt, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexFlt, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexFlt, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexFlt, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexFlt, vec_xor) + // elementwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexFlt, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexFlt, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexFlt, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexFlt, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexFlt, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexFlt, vec_cmple) + 
DEFINE_MEMBER_OP(elwise_max, ComplexFlt, vec_max) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vectorized::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vectorized::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +template <> 
+Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + // (a + ib) * (c + id) = (ac - bd) + i(ad + bc) + // Split into real and imaginary parts + auto a_real = a.el_mergee(); // real part of a + auto a_imag = a.el_mergeo(); // imag part of a + auto b_real = b.el_mergee(); // real part of b + auto b_imag = b.el_mergeo(); // imag part of b + + auto b_imag_neg = b_imag ^ rsign_mask; + // Compute components + auto ac = a_real.elwise_mult(b_real); // real * real + auto bd = a_imag.elwise_mult(b_imag_neg); // imag * imag + auto ad = a_real.elwise_mult(b_imag); // real * imag + auto bc = a_imag.elwise_mult(b_real); // imag * real + + // Real = ac - bd (fix the negative bd part) + auto real = ac + bd; // Real part calculation + auto imag = ad + bc; // Imaginary part calculation + + // Step 1: Extract from real and imag + __vector float r0 = real.vec0(); // {r0, r1, r2, r3} + __vector float i0 = imag.vec0(); // {i0, i1, i2, i3} + + __vector float r1 = real.vec1(); // imag[0..3] + __vector float i1 = imag.vec1(); // imag[4..7] + + __vector unsigned char perm_lo = { + 0, + 1, + 2, + 3, // r0 + 16, + 17, + 18, + 19, // + 8, + 9, + 10, + 11, // r1 + 24, + 25, + 26, + 27}; + __vector float v0 = + vec_perm(r0, i0, perm_lo); // Interleave r0 and i0, r1 and i1 + __vector float v1 = vec_perm(r1, i1, perm_lo); + Vectorized result(v0, v1); + return result; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + // Take absolute values of real and imaginary parts of b + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : + c10::irange(Vectorized>:: + size())) { //{Vectorized>::size())) + //{ + out[i] = tmp1[i] / tmp2[i]; + } + return Vectorized::loadu(out); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff 
--git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..57e28d133e8f39bb440e145d8adba15d0118f5a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h @@ -0,0 +1,512 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace at { +namespace vec { + +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = double; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(double scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + double scalar1, + double scalar2, + double scalar3, + double scalar4) + : _vec0{vfloat64{scalar1, scalar2}}, _vec1{vfloat64{scalar3, scalar4}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + int zero_mask() const { + auto cmp = (*this == vd_zero); + return (cmp._vecb0[0] & 1) | (cmp._vecb0[1] & 2) | (cmp._vecb1[0] & 4) | + (cmp._vecb1[1] & 8); + } + + 
template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return {(vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return {(vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return {a._vec0, (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return {b._vec0, (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + 
+ return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + + static Vectorized C10_ALWAYS_INLINE + set(const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + Vectorized map(double (*const f)(double)) const { + Vectorized ret; + for (const auto i : c10::irange(size() / 2)) { + ret._vec0[i] = f(_vec0[i]); + } + for (const auto i : c10::irange(size() / 2)) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vectorized mapbi( + double (*const f)(double, double), + const Vectorized& other) const { + Vectorized ret; + for (const auto i : c10::irange(size() / 2)) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + 
for (const auto i : c10::irange(size() / 2)) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE acos() const { + return {Sleef_acosd2_u10(_vec0), Sleef_acosd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE acosh() const { + return {Sleef_acoshd2_u10(_vec0), Sleef_acoshd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asin() const { + return {Sleef_asind2_u10(_vec0), Sleef_asind2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asinh() const { + return {Sleef_asinhd2_u10(_vec0), Sleef_asinhd2_u10(_vec1)}; + } + Vectorized atan() const { + return {Sleef_atand2_u10(_vec0), Sleef_atand2_u10(_vec1)}; + } + Vectorized atanh() const { + return {Sleef_atanhd2_u10(_vec0), Sleef_atanhd2_u10(_vec1)}; + } + Vectorized atan2(const Vectorized& b) const { + return { + Sleef_atan2d2_u10(_vec0, b._vec0), Sleef_atan2d2_u10(_vec1, b._vec1)}; + } + Vectorized copysign(const Vectorized& sign) const { + return { + Sleef_copysignd2(_vec0, sign._vec0), + Sleef_copysignd2(_vec1, sign._vec1)}; + } + Vectorized erf() const { + return {Sleef_erfd2_u10(_vec0), Sleef_erfd2_u10(_vec1)}; + } + Vectorized erfc() const { + return {Sleef_erfcd2_u15(_vec0), Sleef_erfcd2_u15(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp() const { + return {Sleef_expd2_u10(_vec0), Sleef_expd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp2() const { + return {Sleef_exp2d2_u10(_vec0), Sleef_exp2d2_u10(_vec1)}; + } + Vectorized expm1() const { + return {Sleef_expm1d2_u10(_vec0), Sleef_expm1d2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp_u20() const { + return exp(); + } + + Vectorized lgamma() const __ubsan_ignore_undefined__ { + return {Sleef_lgammad2_u10(_vec0), Sleef_lgammad2_u10(_vec1)}; + } + + Vectorized erfinv() const { + return map(calc_erfinv); + } + + Vectorized angle() const { + auto tmp = blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < 
Vectorized(0)); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE log() const { + return {Sleef_logd2_u10(_vec0), Sleef_logd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log10() const { + return {Sleef_log10d2_u10(_vec0), Sleef_log10d2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log1p() const { + return {Sleef_log1pd2_u10(_vec0), Sleef_log1pd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log2() const { + return {Sleef_log2d2_u10(_vec0), Sleef_log2d2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cos() const { + return {Sleef_cosd2_u10(_vec0), Sleef_cosd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cosh() const { + return {Sleef_coshd2_u10(_vec0), Sleef_coshd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sin() const { + return {Sleef_sind2_u10(_vec0), Sleef_sind2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sinh() const { + return {Sleef_sinhd2_u10(_vec0), Sleef_sinhd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tan() const { + return {Sleef_tand2_u10(_vec0), Sleef_tand2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tanh() const { + return {Sleef_tanhd2_u10(_vec0), Sleef_tanhd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vectorized C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vectorized 
C10_ALWAYS_INLINE reciprocal() const { + return { + vec_div(vd_one, _vec0), // vec_re(_vec0) is estimated one. + vec_div(vd_one, _vec1)}; + } + Vectorized C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized C10_ALWAYS_INLINE pow(const Vectorized& b) const { + return {Sleef_powd2_u10(_vec0, b._vec0), Sleef_powd2_u10(_vec1, b._vec1)}; + } + Vectorized C10_ALWAYS_INLINE fmod(const Vectorized& b) const { + return {Sleef_fmodd2(_vec0, b._vec0), Sleef_fmodd2(_vec1, b._vec1)}; + } + + Vectorized hypot(const Vectorized& b) const { + return { + Sleef_hypotd2_u05(_vec0, b._vec0), Sleef_hypotd2_u05(_vec1, b._vec1)}; + } + + Vectorized nextafter(const Vectorized& b) const { + return { + Sleef_nextafterd2(_vec0, b._vec0), Sleef_nextafterd2(_vec1, b._vec1)}; + } + + Vectorized igamma(const Vectorized& x) const { + return mapbi(calc_igamma, x); + } + + Vectorized igammac(const Vectorized& x) const { + return mapbi(calc_igammac, x); + } + + Vectorized i0() const { + return map(calc_i0); + } + + Vectorized i0e() const { + return map(calc_i0e); + } + + Vectorized digamma() const { + return map(calc_digamma); + } + + Vectorized _nor() const { + return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)}; + } + + Vectorized isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._nor(); + } + bool has_inf_nan() const { + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec0[i]) || _isinf(_vec0[i])) { + return true; + } + } + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec1[i]) || _isinf(_vec1[i])) { + return true; + } + } + return false; + } + + DEFINE_MEMBER_OP(operator==, double, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, double, vec_cmpne) + DEFINE_MEMBER_OP(operator<, double, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, double, vec_cmple) + DEFINE_MEMBER_OP(operator>, double, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, double, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, double, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, 
double, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, double, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, double, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, double, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, double, vec_cmpge) + DEFINE_MEMBER_OP(operator+, double, vec_add) + DEFINE_MEMBER_OP(operator-, double, vec_sub) + DEFINE_MEMBER_OP(operator*, double, vec_mul) + DEFINE_MEMBER_OP(operator/, double, vec_div) + DEFINE_MEMBER_OP(maximum, double, vec_max_nan2) + DEFINE_MEMBER_OP(minimum, double, vec_min_nan2) + DEFINE_MEMBER_OP(operator&, double, vec_and) + DEFINE_MEMBER_OP(operator|, double, vec_or) + DEFINE_MEMBER_OP(operator^, double, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, double, vec_madd) +}; +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), 
b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..1a2f2b5ab7e850a60ae628bea71fa4110d89c1ff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -0,0 +1,545 @@ +#pragma once + +#include +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] + +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = float; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + using size_type = int; + + static constexpr size_type size() { + return 8; + } + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(float scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + float scalar1, + float scalar2, + float scalar3, + float scalar4, + float scalar5, + float scalar6, + float scalar7, + float scalar8) + : 
_vec0{vfloat32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vfloat32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, 
b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const float& 
operator[](int idx) const = delete; + float& operator[](int idx) = delete; + + Vectorized map(float (*const f)(float)) const { + Vectorized ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vectorized mapbi( + float (*const f)(float, float), + const Vectorized& other) const { + Vectorized ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + + Vectorized _nor() const { + return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)}; + } + + Vectorized isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._nor(); + } + + bool has_inf_nan() const { + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec0[i]) || _isinf(_vec0[i])) { + return true; + } + } + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec1[i]) || _isinf(_vec1[i])) { + return true; + } + } + return false; + } + + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + //__m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ); + auto cmp = (*this == zero); + // return _mm256_movemask_ps(cmp); + // possible simulation //mask= lvsl ( 0 ) vbpermq( vec, mask <<5) + vuint64 result0 = vec_vbpermq((vuint8)cmp._vecb0, mask_zero_bits); + vuint64 result1 = vec_vbpermq((vuint8)cmp._vecb1, mask_zero_bits); + return (result0[1] >> 12 | (result1[1] >> 8)); + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE acos() const { + return {Sleef_acosf4_u10(_vec0), Sleef_acosf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE acosh() const { + return {Sleef_acoshf4_u10(_vec0), Sleef_acoshf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE 
asin() const { + return {Sleef_asinf4_u10(_vec0), Sleef_asinf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asinh() const { + return {Sleef_asinhf4_u10(_vec0), Sleef_asinhf4_u10(_vec1)}; + } + Vectorized atan() const { + return {Sleef_atanf4_u10(_vec0), Sleef_atanf4_u10(_vec1)}; + } + Vectorized atanh() const { + return {Sleef_atanhf4_u10(_vec0), Sleef_atanhf4_u10(_vec1)}; + } + Vectorized atan2(const Vectorized& b) const { + return { + Sleef_atan2f4_u10(_vec0, b._vec0), Sleef_atan2f4_u10(_vec1, b._vec1)}; + } + Vectorized copysign(const Vectorized& sign) const { + return { + Sleef_copysignf4(_vec0, sign._vec0), + Sleef_copysignf4(_vec1, sign._vec1)}; + } + Vectorized lgamma() const { + return {Sleef_lgammaf4_u10(_vec0), Sleef_lgammaf4_u10(_vec1)}; + } + Vectorized erf() const { + return {Sleef_erff4_u10(_vec0), Sleef_erff4_u10(_vec1)}; + } + + Vectorized erfc() const { + return {Sleef_erfcf4_u15(_vec0), Sleef_erfcf4_u15(_vec1)}; + } + + Vectorized erfinv() const { + return map(calc_erfinv); + } + + Vectorized angle() const { + auto tmp = blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE exp() const { + return {Sleef_expf4_u10(_vec0), Sleef_expf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp2() const { + return {Sleef_exp2f4_u10(_vec0), Sleef_exp2f4_u10(_vec1)}; + } + Vectorized expm1() const { + return {Sleef_expm1f4_u10(_vec0), Sleef_expm1f4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp_u20() const { + return exp(); + } + + Vectorized C10_ALWAYS_INLINE log() const { + return {Sleef_logf4_u10(_vec0), Sleef_logf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log10() const { + return {Sleef_log10f4_u10(_vec0), Sleef_log10f4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log1p() const { + return 
{Sleef_log1pf4_u10(_vec0), Sleef_log1pf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log2() const { + return {Sleef_log2f4_u10(_vec0), Sleef_log2f4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cos() const { + return {Sleef_cosf4_u10(_vec0), Sleef_cosf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cosh() const { + return {Sleef_coshf4_u10(_vec0), Sleef_coshf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sin() const { + return {Sleef_sinf4_u10(_vec0), Sleef_sinf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sinh() const { + return {Sleef_sinhf4_u10(_vec0), Sleef_sinhf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tan() const { + return {Sleef_tanf4_u10(_vec0), Sleef_tanf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tanh() const { + return {Sleef_tanhf4_u10(_vec0), Sleef_tanhf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vectorized C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE reciprocal() const { + return Vectorized(one) / (*this); + } + Vectorized C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized C10_ALWAYS_INLINE pow(const Vectorized& exp) const { + return { + Sleef_powf4_u10(_vec0, exp._vec0), Sleef_powf4_u10(_vec1, exp._vec1)}; + } + + Vectorized fmod(const Vectorized& b) const { + return {Sleef_fmodf4(_vec0, b._vec0), Sleef_fmodf4(_vec1, b._vec1)}; + } + + Vectorized hypot(const Vectorized& b) const { + return { + 
Sleef_hypotf4_u05(_vec0, b._vec0), Sleef_hypotf4_u05(_vec1, b._vec1)}; + } + + Vectorized nextafter(const Vectorized& b) const { + return { + Sleef_nextafterf4(_vec0, b._vec0), Sleef_nextafterf4(_vec1, b._vec1)}; + } + + Vectorized igamma(const Vectorized& x) const { + return mapbi(calc_igamma, x); + } + + Vectorized igammac(const Vectorized& x) const { + return mapbi(calc_igammac, x); + } + + Vectorized i0() const { + return map(calc_i0); + } + + Vectorized i0e() const { + return map(calc_i0e); + } + + Vectorized digamma() const { + return map(calc_digamma); + } + + DEFINE_MEMBER_OP(operator==, float, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, float, vec_cmpne) + DEFINE_MEMBER_OP(operator<, float, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, float, vec_cmple) + DEFINE_MEMBER_OP(operator>, float, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, float, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, float, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, float, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, float, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, float, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, float, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, float, vec_cmpge) + DEFINE_MEMBER_OP(operator+, float, vec_add) + DEFINE_MEMBER_OP(operator-, float, vec_sub) + DEFINE_MEMBER_OP(operator*, float, vec_mul) + DEFINE_MEMBER_OP(operator/, float, vec_div) + DEFINE_MEMBER_OP(maximum, float, vec_max_nan2) + DEFINE_MEMBER_OP(minimum, float, vec_min_nan2) + DEFINE_MEMBER_OP(operator&, float, vec_and) + DEFINE_MEMBER_OP(operator|, float, vec_or) + DEFINE_MEMBER_OP(operator^, float, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, float, vec_madd) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + 
vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..fe25c979906c4062a8b93a86836d705c879fb6d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -0,0 +1,435 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vint16 _vec0; + vint16 
_vec1; + }; + struct { + vbool16 _vecb0; + vbool16 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int16_t; + using vec_internal_type = vint16; + using vec_internal_mask_type = vbool16; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vint16 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool16 v1, vbool16 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(int16_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + + C10_ALWAYS_INLINE Vectorized( + int16_t scalar1, + int16_t scalar2, + int16_t scalar3, + int16_t scalar4, + int16_t scalar5, + int16_t scalar6, + int16_t scalar7, + int16_t scalar8, + int16_t scalar9, + int16_t scalar10, + int16_t scalar11, + int16_t scalar12, + int16_t scalar13, + int16_t scalar14, + int16_t scalar15, + int16_t scalar16) + : _vec0{vint16{ + scalar1, + scalar2, + scalar3, + scalar4, + scalar5, + scalar6, + scalar7, + scalar8}}, + _vec1{vint16{ + scalar9, + scalar10, + scalar11, + scalar12, + scalar13, + scalar14, + scalar15, + scalar16}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t<(mask & 65535) == 65535, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 255), Vectorized> + C10_ALWAYS_INLINE + 
blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + + return {(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 255 && (mask & 65535) != 65535 && ((mask & 255) == 255)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {b._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) == 0)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 
0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {a, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) != 0) && + ((mask & 255) != 255)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return { + (vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), + (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel 
style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + template + static Vectorized arange( + int16_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, 
offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + + Vectorized angle() const { + return blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int16_t, vec_not) + DEFINE_MEMBER_OP(operator==, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int16_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int16_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int16_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int16_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int16_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int16_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int16_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int16_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int16_t, vec_add) + DEFINE_MEMBER_OP(operator-, int16_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int16_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int16_t, /) + DEFINE_MEMBER_OP(maximum, int16_t, vec_max) + DEFINE_MEMBER_OP(minimum, int16_t, vec_min) + DEFINE_MEMBER_OP(operator&, int16_t, vec_and) + DEFINE_MEMBER_OP(operator|, int16_t, vec_or) + DEFINE_MEMBER_OP(operator^, int16_t, vec_xor) +}; + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + vuint16 shift_vec0 = reinterpret_cast(b.vec0()); + vuint16 shift_vec1 = 
reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + vuint16 shift_vec0 = reinterpret_cast(b.vec0()); + vuint16 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..c47b7ea66d74d272055715a349aa28ea3f6b14d7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -0,0 +1,365 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int32_t; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(int32_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + int32_t scalar1, + int32_t scalar2, + int32_t scalar3, + int32_t scalar4, + int32_t scalar5, + int32_t scalar6, + int32_t scalar7, + int32_t scalar8) + : _vec0{vint32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vint32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + 
blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t<(mask & 255) == 255, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 15), Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + + return {(vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 15 && (mask & 255) != 255 && ((mask & 15) == 15)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {b._vec0, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) == 0)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const 
vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {a, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) != 0) && + ((mask & 15) != 15)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return { + (vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), + (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + template + static Vectorized arange( + int32_t base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return 
blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + + Vectorized angle() const { + return blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int32_t, vec_not) + DEFINE_MEMBER_OP(operator==, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int32_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int32_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int32_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, 
int32_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int32_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int32_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int32_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int32_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int32_t, vec_add) + DEFINE_MEMBER_OP(operator-, int32_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int32_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int32_t, /) + DEFINE_MEMBER_OP(maximum, int32_t, vec_max) + DEFINE_MEMBER_OP(minimum, int32_t, vec_min) + DEFINE_MEMBER_OP(operator&, int32_t, vec_and) + DEFINE_MEMBER_OP(operator|, int32_t, vec_or) + DEFINE_MEMBER_OP(operator^, int32_t, vec_xor) +}; + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + vuint32 shift_vec0 = reinterpret_cast(b.vec0()); + vuint32 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + vuint32 shift_vec0 = reinterpret_cast(b.vec0()); + vuint32 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const 
Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..2238f41aef300b3d2d3904e2db365118bbab9894 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -0,0 +1,319 @@ +#pragma once + +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vint64 _vec0; + vint64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int64_t; + using vec_internal_type = vint64; + using vec_internal_mask_type = vbool64; + using size_type = int; + using ElementType = signed long long; + static constexpr size_type size() { + 
return 4; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vint64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint64 v1, vint64 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(int64_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + int64_t scalar1, + int64_t scalar2, + int64_t scalar3, + int64_t scalar4) + : _vec0{vint64{scalar1, scalar2}}, _vec1{vint64{scalar3, scalar4}} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask & 15) == 15, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t<(mask > 0 && mask < 3), Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + const vbool64 mask_1st = (vbool64){g0, g1}; + return {(vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), a._vec1}; + } + + template + static std::enable_if_t<(mask > 3) && (mask & 3) == 0, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return {a._vec0, (vint64)vec_sel(a._vec1, 
b._vec1, (vbool64)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 3) && (mask & 3) != 0 && (mask & 15) != 15, + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_1st = (vbool64){g0, g1}; + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return { + (vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), + (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparision of vec256 + + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vectorized arange( + int64_t base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + + static Vectorized C10_ALWAYS_INLINE + set(const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + static_assert(sizeof(double) == sizeof(value_type)); + const double* dptr = reinterpret_cast(ptr); + return {// treat it as double load + (vint64)vec_vsx_ld(offset0, dptr), + (vint64)vec_vsx_ld(offset16, dptr)}; + } + + __at_align__ double tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + (vint64)vec_vsx_ld(offset0, tmp_values), + (vint64)vec_vsx_ld(offset16, tmp_values)}; 
+ } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + double* dptr = reinterpret_cast(ptr); + vec_vsx_st((vfloat64)_vec0, offset0, dptr); + vec_vsx_st((vfloat64)_vec1, offset16, dptr); + } else if (count > 0) { + __at_align__ double tmp_values[size()]; + vec_vsx_st((vfloat64)_vec0, offset0, tmp_values); + vec_vsx_st((vfloat64)_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + + Vectorized angle() const { + return blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int64_t, vec_not) + DEFINE_MEMBER_OP(operator==, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int64_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int64_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int64_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int64_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int64_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int64_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int64_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int64_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int64_t, vec_add) + DEFINE_MEMBER_OP(operator-, int64_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int64_t, vec_mul) + DEFINE_MEMBER_OP(operator/, int64_t, vec_div) + DEFINE_MEMBER_OP(maximum, int64_t, vec_max) + DEFINE_MEMBER_OP(minimum, int64_t, vec_min) + 
DEFINE_MEMBER_OP(operator&, int64_t, vec_and) + DEFINE_MEMBER_OP(operator|, int64_t, vec_or) + DEFINE_MEMBER_OP(operator^, int64_t, vec_xor) +}; + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + vuint64 shift_vec0 = reinterpret_cast(b.vec0()); + vuint64 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + vuint64 shift_vec0 = reinterpret_cast(b.vec0()); + vuint64 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + 
vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..ad895bf54d95a5b0b496f65ca7665dd2da2a15ba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h @@ -0,0 +1,301 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. 
+ +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; +template <> +struct Vectorized { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vectorized() {} + + using size_type = int; + static constexpr size_type size() { + return 8; + } + + static constexpr size_t float_num_vecs() { + return 1; + } + static constexpr int int_num_vecs() { + return 1; + } + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + + Vectorized(const c10::qint32& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + 
std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + vfloat32 float_vals0 = vec_float(_vec0); + vfloat32 float_vals1 = vec_float(_vec1); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point_vec0 = zero_point.vec0(); + vfloat32 zero_point_vec1 = zero_point.vec1(); + + vfloat32 vec_sub_zero_point_0 = vec_sub(float_vals0, zero_point_vec0); + vfloat32 vec_sub_zero_point_1 = vec_sub(float_vals1, zero_point_vec1); + Vectorized vf0 = { + vec_mul(scale_vec0, vec_sub_zero_point_0), + vec_mul(scale_vec1, vec_sub_zero_point_1)}; + return {vf0}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + vfloat32 float_vals0 = vec_float(_vec0); + vfloat32 float_vals1 = vec_float(_vec1); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return {Vectorized{ + (float_vals0 - zero_point0) * scale_vec0, + (float_vals1 - zero_point1) * scale_vec1}}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vectorized retval; + + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)(zero_point)); + Vectorized vf0 = rhs[0]; + + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = 
vec_add(vec_rint(vecf1), vec_zero_point); + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + Vectorized relu(Vectorized zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) const { + vint32 max0 = vec_max(_vec0, zero_point._vec0); + vint32 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {*this - b}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 vec_mult = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + Vectorized vi = inp[0]; + vfloat32 vecf0 = vec_float(vi.vec0()); + vfloat32 vecf1 = vec_float(vi.vec1()); + + vecf0 = vec_mul(vecf0, vec_mult); + vecf1 = vec_mul(vecf1, vec_mult); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + + vint32 veci0 = vec_add(vec_signed(vecf0), vec_zero_point); + vint32 veci1 = vec_add(vec_signed(vecf1), vec_zero_point); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint32, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint32, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint32, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint32, vec_add) + 
DEFINE_MEMBER_OP(operator-, c10::qint32, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint32, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint32, /) + DEFINE_MEMBER_OP(maximum, c10::qint32, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint32, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint32, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint32, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint32, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace 
at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..a707155aad787b7a1712163dac02fe0af3c6e0e4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h @@ -0,0 +1,512 @@ +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. 
+ +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; +template <> +struct Vectorized { + private: + union { + struct { + vint8 _vec0; + vint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vectorized() {} + using size_type = int; + static constexpr size_type size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::qint8::underlying; + using vec_internal_type = vint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vectorized(const c10::qint8& val) + : _vec0{vec_splats(val.val_)}, _vec1{vec_splats(val.val_)} {} + + C10_ALWAYS_INLINE Vectorized(const Vectorized& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vectorized(vint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint8 v1, vint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vectorized loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const 
{ + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + + vfloat32 zero_point_vec0 = zero_point.vec0(); + vfloat32 zero_point_vec1 = zero_point.vec1(); + + vfloat32 vec_substract_src_zp0_0 = vec_sub(vecf0_0, zero_point_vec0); + vfloat32 vec_substract_src_zp1_0 = vec_sub(vecf1_0, zero_point_vec1); + Vectorized vf0_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_0), + vec_mul(scale_vec1, vec_substract_src_zp1_0)}; + + vfloat32 vec_substract_src_zp0_1 = vec_sub(vecf0_1, zero_point_vec0); + vfloat32 vec_substract_src_zp1_1 = vec_sub(vecf1_1, zero_point_vec1); + Vectorized vf1_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_1), + 
vec_mul(scale_vec1, vec_substract_src_zp1_1)}; + + vfloat32 vec_substract_src_zp0_2 = vec_sub(vecf0_2, zero_point_vec0); + vfloat32 vec_substract_src_zp1_2 = vec_sub(vecf1_2, zero_point_vec1); + Vectorized vf2_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_2), + vec_mul(scale_vec1, vec_substract_src_zp1_2)}; + + vfloat32 vec_substract_src_zp0_3 = vec_sub(vecf0_3, zero_point_vec0); + vfloat32 vec_substract_src_zp1_3 = vec_sub(vecf1_3, zero_point_vec1); + Vectorized vf3_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_3), + vec_mul(scale_vec1, vec_substract_src_zp1_3)}; + + return {vf0_zp, vf1_zp, vf2_zp, vf3_zp}; + } + + float_vec_return_type C10_ALWAYS_INLINE + dequantize(Vectorized scale, Vectorized zero_point) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return { + Vectorized{ + (vecf0_0 - zero_point0) * scale_vec0, + (vecf1_0 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_1 - zero_point0) * scale_vec0, + (vecf1_1 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_2 - zero_point0) * scale_vec0, + (vecf1_2 - zero_point1) * scale_vec1}, + 
Vectorized{ + (vecf0_3 - zero_point0) * scale_vec0, + (vecf1_3 - zero_point1) * scale_vec1}}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vint32 vmin = vec_splats(min_val); + // vint32 vmax = vec_splats(max_val); + + Vectorized vf0 = rhs[0]; + Vectorized vf1 = rhs[1]; + Vectorized vf2 = rhs[2]; + Vectorized vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf2 = vec_mul(vecf2, inverse_scale_v); + vecf3 = vec_mul(vecf3, inverse_scale_v); + + vecf4 = vec_mul(vecf4, inverse_scale_v); + vecf5 = vec_mul(vecf5, inverse_scale_v); + vecf6 = vec_mul(vecf6, inverse_scale_v); + vecf7 = vec_mul(vecf7, inverse_scale_v); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + // veci0 = vec_min(vmax, 
vec_max( vmin, vecf0)) ; + // veci1 = vec_min(vmax, vec_max( vmin, vecf1)) ; + // veci2 = vec_min(vmax, vec_max( vmin, vecf2)) ; + // veci3 = vec_min(vmax, vec_max( vmin, vecf3)) ; + + // veci4 = vec_min(vmax, vec_max( vmin, vecf4)) ; + // veci5 = vec_min(vmax, vec_max( vmin, vecf5)) ; + // veci6 = vec_min(vmax, vec_max( vmin, vecf6)) ; + // veci7 = vec_min(vmax, vec_max( vmin, vecf7)) ; + // vec_packs CLAMP already + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vectorized C10_ALWAYS_INLINE + relu(Vectorized zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vectorized C10_ALWAYS_INLINE + relu6(Vectorized zero_point, Vectorized q_six) const { + vint8 max0 = vec_max(_vec0, zero_point._vec0); + vint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecBshi0 = vec_unpackh(b._vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + vint16 vecBshi1 = vec_unpackl(b._vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecBshi2 = vec_unpackh(b._vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + vint16 vecBshi3 = vec_unpackl(b._vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = 
vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vectorized(veci0 - vecBi0, veci1 - vecBi1), + Vectorized(veci2 - vecBi2, veci3 - vecBi3), + Vectorized(veci4 - vecBi4, veci5 - vecBi5), + Vectorized(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vectorized vi0 = inp[0]; + Vectorized vi1 = inp[1]; + Vectorized vi2 = inp[2]; + Vectorized vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + 
veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::qint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint8, /) + DEFINE_MEMBER_OP(maximum, c10::qint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint8, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE 
+operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..5863df6bd667c91ea51992e6d5e747ffc36b885f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h @@ -0,0 +1,533 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. 
The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +const vint16 mask_unsigned = vec_splats((short int)0xFF); +template <> +struct Vectorized { + private: + union { + struct { + vuint8 _vec0; + vuint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vectorized() {} + using size_type = int; + static constexpr size_type size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::quint8::underlying; + using vec_internal_type = vuint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vectorized(const c10::quint8& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + C10_ALWAYS_INLINE Vectorized(const Vectorized& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vectorized(vuint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vuint8 v1, vuint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vectorized loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + 
return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + // unpacking unsigned as signed + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + + // signed -> unsigned + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 
scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + + vfloat32 zero_point_vec0 = zero_point.vec0(); + vfloat32 zero_point_vec1 = zero_point.vec1(); + + vfloat32 vec_substract_src_zp0_0 = vec_sub(vecf0_0, zero_point_vec0); + vfloat32 vec_substract_src_zp1_0 = vec_sub(vecf1_0, zero_point_vec1); + Vectorized vf0_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_0), + vec_mul(scale_vec1, vec_substract_src_zp1_0)}; + + vfloat32 vec_substract_src_zp0_1 = vec_sub(vecf0_1, zero_point_vec0); + vfloat32 vec_substract_src_zp1_1 = vec_sub(vecf1_1, zero_point_vec1); + Vectorized vf1_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_1), + vec_mul(scale_vec1, vec_substract_src_zp1_1)}; + + vfloat32 vec_substract_src_zp0_2 = vec_sub(vecf0_2, zero_point_vec0); + vfloat32 vec_substract_src_zp1_2 = vec_sub(vecf1_2, zero_point_vec1); + Vectorized vf2_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_2), + vec_mul(scale_vec1, vec_substract_src_zp1_2)}; + + vfloat32 vec_substract_src_zp0_3 = vec_sub(vecf0_3, zero_point_vec0); + vfloat32 vec_substract_src_zp1_3 = vec_sub(vecf1_3, zero_point_vec1); + Vectorized vf3_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_3), + vec_mul(scale_vec1, vec_substract_src_zp1_3)}; + + return {vf0_zp, vf1_zp, vf2_zp, vf3_zp}; + } + + float_vec_return_type C10_ALWAYS_INLINE + dequantize(Vectorized scale, Vectorized zero_point) const { + // unpacking unsigned as signed + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + + // signed -> unsigned + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = 
vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return { + Vectorized{ + (vecf0_0 - zero_point0) * scale_vec0, + (vecf1_0 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_1 - zero_point0) * scale_vec0, + (vecf1_1 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_2 - zero_point0) * scale_vec0, + (vecf1_2 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_3 - zero_point0) * scale_vec0, + (vecf1_3 - zero_point1) * scale_vec1}}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 vec_inverse = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vuint32 vmin = vec_splats(min_val); + // vuint32 vmax = vec_splats(max_val); + Vectorized vf0 = rhs[0]; + Vectorized vf1 = rhs[1]; + Vectorized vf2 = rhs[2]; + Vectorized vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, vec_inverse); + vecf1 = vec_mul(vecf1, vec_inverse); + vecf2 = vec_mul(vecf2, vec_inverse); + vecf3 = vec_mul(vecf3, vec_inverse); + + vecf4 = 
vec_mul(vecf4, vec_inverse); + vecf5 = vec_mul(vecf5, vec_inverse); + vecf6 = vec_mul(vecf6, vec_inverse); + vecf7 = vec_mul(vecf7, vec_inverse); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vectorized C10_ALWAYS_INLINE + relu(Vectorized zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vectorized C10_ALWAYS_INLINE relu6( + Vectorized zero_point, + Vectorized q_six) const { + vuint8 max0 = vec_max(_vec0, zero_point._vec0); + vuint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecBshi0 = vec_unpackh((vint8)b._vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + vint16 vecBshi1 = vec_unpackl((vint8)b._vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecBshi2 = vec_unpackh((vint8)b._vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + vint16 vecBshi3 = 
vec_unpackl((vint8)b._vec1); + + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecBshi0 = vec_and(vecBshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + vecBshi1 = vec_and(vecBshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecBshi2 = vec_and(vecBshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + vecBshi3 = vec_and(vecBshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vectorized(veci0 - vecBi0, veci1 - vecBi1), + Vectorized(veci2 - vecBi2, veci3 - vecBi3), + Vectorized(veci4 - vecBi4, veci5 - vecBi5), + Vectorized(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vectorized vi0 = inp[0]; + Vectorized vi1 = inp[1]; + Vectorized vi2 = inp[2]; + Vectorized vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, 
vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + DEFINE_MEMBER_OP(operator==, c10::quint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::quint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::quint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::quint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::quint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::quint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::quint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::quint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::quint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::quint8, /) + DEFINE_MEMBER_OP(maximum, 
c10::quint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::quint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::quint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::quint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::quint8, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h 
new file mode 100644 index 0000000000000000000000000000000000000000..136b68911061f5033bb9b4a4a52dd04ff42c5208 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h @@ -0,0 +1,526 @@ +#pragma once +#include +#include +#include + +#if defined(__clang__) +typedef __vector __bool char vbool8; +typedef __vector __bool short vbool16; +typedef __vector __bool int vbool32; +typedef __vector __bool long long vbool64; +using vint8 = __attribute__((vector_size(16))) signed char; +using vint16 = __attribute__((vector_size(16))) signed short; +using vint32 = __attribute__((vector_size(16))) signed int; +using vint64 = __attribute__((vector_size(16))) signed long long; +using vuint8 = __attribute__((vector_size(16))) unsigned char; +using vuint16 = __attribute__((vector_size(16))) unsigned short; +using vuint32 = __attribute__((vector_size(16))) unsigned int; +using vuint64 = __attribute__((vector_size(16))) unsigned long long; +using vfloat32 = __attribute__((vector_size(16))) float; +using vfloat64 = __attribute__((vector_size(16))) double; +#else +using vbool8 = + __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char; +using vbool16 = + __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short; +using vbool32 = + __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) int; +using vbool64 = __attribute__((altivec(vector__))) +__attribute__((altivec(bool__))) long long; +using vint8 = __attribute__((altivec(vector__))) signed char; +using vint16 = __attribute__((altivec(vector__))) signed short; +using vint32 = __attribute__((altivec(vector__))) signed int; +using vint64 = __attribute__((altivec(vector__))) signed long long; +using vuint8 = __attribute__((altivec(vector__))) unsigned char; +using vuint16 = __attribute__((altivec(vector__))) unsigned short; +using vuint32 = __attribute__((altivec(vector__))) unsigned int; +using vuint64 = __attribute__((altivec(vector__))) 
unsigned long long; +using vfloat32 = __attribute__((altivec(vector__))) float; +using vfloat64 = __attribute__((altivec(vector__))) double; +#endif + +#if !defined(vec_float) +C10_ALWAYS_INLINE vfloat32 vec_float(const vint32& vec_in) { + vfloat32 vec_out; + __asm__("xvcvsxwsp %x0,%x1" : "=wf"(vec_out) : "wa"(vec_in)); + return vec_out; +} +#endif + +#if !defined(vec_signed) +C10_ALWAYS_INLINE vint32 vec_signed(const vfloat32& vec_in) { + vint32 vec_out; + __asm__("xvcvspsxws %x0,%x1" : "=wa"(vec_out) : "wf"(vec_in)); + return vec_out; +} + +C10_ALWAYS_INLINE vint64 vec_signed(const vfloat64& vec_in) { + vint64 vec_out; + __asm__("xvcvdpsxds %x0,%x1" : "=wa"(vec_out) : "wd"(vec_in)); + return vec_out; +} +#endif + +#if !defined(vec_neg) +C10_ALWAYS_INLINE vfloat32 vec_neg(const vfloat32& vec_in) { + vfloat32 vec_out; + __asm__("xvnegsp %x0,%x1" : "=wf"(vec_out) : "wf"(vec_in)); + return vec_out; +} + +C10_ALWAYS_INLINE vfloat64 vec_neg(const vfloat64& vec_in) { + vfloat64 vec_out; + __asm__("xvnegdp %x0,%x1" : "=wd"(vec_out) : "wd"(vec_in)); + return vec_out; +} + +C10_ALWAYS_INLINE vint16 vec_neg(const vint16& vec_in) { + vint16 vint0 = {0, 0, 0, 0, 0, 0, 0, 0}; + return vec_vsubuhm(vint0, vec_in); +} + +C10_ALWAYS_INLINE vint32 vec_neg(const vint32& vec_in) { + vint32 vint0 = {0, 0, 0, 0}; + return vec_vsubuwm(vint0, vec_in); +} + +C10_ALWAYS_INLINE vint64 vec_neg(const vint64& vec_in) { + return -vec_in; +} +#endif + +#if !defined(vec_sldw) +template +C10_ALWAYS_INLINE vfloat32 +vec_sldw_aux(const vfloat32& vec_in0, const vfloat32& vec_in1) { + vfloat32 vec_out; + __asm("xxsldwi %x0, %x1, %x2, %3 " + : "=wa"(vec_out) + : "wa"(vec_in0), "wa"(vec_in1), "I"(C)); + return vec_out; +} + +#define vec_sldw(a, b, c) vec_sldw_aux(a, b) +#endif + +#define vec_not(a) vec_nor(a, a) +#if defined(__clang__) && !defined(vec_splats) +C10_ALWAYS_INLINE vint64 vec_splats(const int64_t& a) { + return vec_splats(a); +} +#endif +// Vectorized min/max which return a if any operand 
is nan +template +C10_ALWAYS_INLINE T vec_min_nan(const T& a, const T& b) { + return vec_min(a, b); +} +template +C10_ALWAYS_INLINE T vec_max_nan(const T& a, const T& b) { + return vec_max(a, b); +} + +// Specializations for float/double taken from Eigen +template <> +C10_ALWAYS_INLINE vfloat32 +vec_min_nan(const vfloat32& a, const vfloat32& b) { + // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE + // regarding NaN + vfloat32 ret; + __asm__("xvcmpgesp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} +// Specializations for float/double taken from Eigen +template <> +C10_ALWAYS_INLINE vfloat32 +vec_max_nan(const vfloat32& a, const vfloat32& b) { + // NOTE: about 10% slower than vec_max, but consistent with std::min and SSE + // regarding NaN + vfloat32 ret; + __asm__("xvcmpgtsp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} + +template <> +C10_ALWAYS_INLINE vfloat64 +vec_min_nan(const vfloat64& a, const vfloat64& b) { + // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE + // regarding NaN + vfloat64 ret; + __asm__("xvcmpgedp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} +template <> +C10_ALWAYS_INLINE vfloat64 +vec_max_nan(const vfloat64& a, const vfloat64& b) { + // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE + // regarding NaN + vfloat64 ret; + __asm__("xvcmpgtdp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} + +// Vectorizes min/max function which returns nan if any side is nan +#define C10_VSX_VEC_NAN_PROPAG(name, type, btype, func) \ + C10_ALWAYS_INLINE type name(const type& a, const type& b) { \ + type tmp = func(a, b); \ + btype nan_a = vec_cmpne(a, a); \ + btype nan_b = vec_cmpne(b, b); \ + tmp = vec_sel(tmp, a, nan_a); \ + return vec_sel(tmp, b, nan_b); \ + } + 
+C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat32, vbool32, vec_min) +C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat32, vbool32, vec_max) +C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat64, vbool64, vec_min) +C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat64, vbool64, vec_max) + +#undef C10_VSX_VEC_NAN_PROPAG + +#define DEFINE_MEMBER_UNARY_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op() const { \ + return Vectorized{func(_vec0), func(_vec1)}; \ + } + +#define DEFINE_MEMBER_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& other) \ + const { \ + return Vectorized{ \ + func(_vec0, other._vec0), func(_vec1, other._vec1)}; \ + } + +#define DEFINE_MEMBER_BITWISE_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& other) \ + const { \ + return Vectorized{ \ + func(_vecb0, other._vecb0), func(_vecb1, other._vecb1)}; \ + } + +#define DEFINE_MEMBER_TERNARY_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op( \ + const Vectorized& b, const Vectorized& c) const { \ + return Vectorized{ \ + func(_vec0, b._vec0, c._vec0), func(_vec1, b._vec1, c._vec1)}; \ + } + +#define DEFINE_MEMBER_EMULATE_BINARY_OP(op, op_type, binary_op) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& b) \ + const { \ + Vectorized::vec_internal_type ret_0; \ + Vectorized::vec_internal_type ret_1; \ + for (int i = 0; i < Vectorized::size() / 2; i++) { \ + ret_0[i] = _vec0[i] binary_op b._vec0[i]; \ + ret_1[i] = _vec1[i] binary_op b._vec1[i]; \ + } \ + return Vectorized{ret_0, ret_1}; \ + } + +#define DEFINE_MEMBER_OP_AND_ONE(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& other) \ + const { \ + using vvtype = Vectorized::vec_internal_type; \ + const vvtype v_one = vec_splats(static_cast(1.0)); \ + vvtype ret0 = (vvtype)func(_vec0, other._vec0); \ + vvtype ret1 = (vvtype)func(_vec1, other._vec1); \ + return Vectorized{vec_and(ret0, v_one), vec_and(ret1, v_one)}; \ + } + +#define DEFINE_CLAMP_FUNCS(operand_type) \ + template 
<> \ + Vectorized C10_ALWAYS_INLINE clamp( \ + const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return Vectorized{ \ + vec_min_nan(vec_max_nan(a.vec0(), min.vec0()), max.vec0()), \ + vec_min_nan(vec_max_nan(a.vec1(), min.vec1()), max.vec1())}; \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_min( \ + const Vectorized& a, \ + const Vectorized& min) { \ + return Vectorized{ \ + vec_max_nan(a.vec0(), min.vec0()), vec_max_nan(a.vec1(), min.vec1())}; \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_max( \ + const Vectorized& a, \ + const Vectorized& max) { \ + return Vectorized{ \ + vec_min_nan(a.vec0(), max.vec0()), vec_min_nan(a.vec1(), max.vec1())}; \ + } + +#define DEFINE_REINTERPRET_CAST_FUNCS( \ + first_type, cast_type, cast_inner_vector_type) \ + template <> \ + C10_ALWAYS_INLINE Vectorized cast( \ + const Vectorized& src) { \ + return Vectorized{ \ + (cast_inner_vector_type)src.vec0(), \ + (cast_inner_vector_type)src.vec1()}; \ + } + +#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(first_type) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, double, vfloat64) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, float, vfloat32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int64_t, vint64) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int32_t, vint32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int16_t, vint16) + +// it can be used to emulate blend faster +constexpr int blendChoice( + uint32_t mask, + uint32_t half1 = 0xF, + uint32_t half2 = 0xF0) { + uint32_t none = 0; + uint32_t both = half1 | half2; + // clamp it between 0 and both + mask = mask & both; + // return (a._vec0, a._vec1) + if (mask == none) + return 0; + // return (b._vec0,b._vec1) + else if (mask == both) + return 1; + // return (b._vec0,a._vec1) + else if (mask == half1) + return 2; + // return (a._vec0,b._vec1) + else if (mask == half2) + return 3; + // return (*_vec0,a._vec1) + else if (mask > 0 && mask < half1) + return 4; + // return 
(*_vec0,b._vec1) + else if ((mask & half2) == half2) + return 5; + // return (a._vec0,*_vec1) + else if ((mask & half1) == 0 && mask > half1) + return 6; + // return (b._vec0,*_vec1) + else if ((mask & half1) == half1 && mask > half1) + return 7; + // return (*_vec0,*_vec1) + return 8; +} + +// it can be used to emulate blend faster +constexpr int blendChoiceDbl(uint32_t mask) { + // clamp it 0 and 0xF + return blendChoice(mask, 0x3, 0xC); +} + +constexpr vbool32 VsxMask1(uint32_t mask) { + uint32_t g0 = (mask & 1) * 0xffffffff; + uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + return (vbool32){g0, g1, g2, g3}; +} + +constexpr vbool32 VsxMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xFF) >> 4; + return VsxMask1(mask2); +} + +constexpr vbool64 VsxDblMask1(uint32_t mask) { + uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + return (vbool64){g0, g1}; +} + +constexpr vbool64 VsxDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +constexpr int maskForComplex(uint32_t mask) { + mask = mask & 0xF; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + if (mask & 4) + complex_mask |= (3 << 4); + if (mask & 8) + complex_mask |= (3 << 6); + return complex_mask; +} + +constexpr int maskForComplexDbl(uint32_t mask) { + mask = mask & 0x3; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + return complex_mask; +} + +constexpr int blendChoiceComplex(uint32_t mask) { + return blendChoice(maskForComplex(mask)); +} + +constexpr int blendChoiceComplexDbl(uint32_t mask) { + return blendChoiceDbl(maskForComplexDbl(mask)); +} + +constexpr vbool32 VsxComplexMask1(uint32_t mask) { + return VsxMask1(maskForComplex(mask)); +} + +constexpr vbool32 VsxComplexMask2(uint32_t mask) { + 
uint32_t mask2 = (mask & 0xF) >> 2; + return VsxMask1(maskForComplex(mask2)); +} + +constexpr vbool64 VsxComplexDblMask1(uint32_t mask) { + return VsxDblMask1(mask); +} + +constexpr vbool64 VsxComplexDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +// constants +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +// +constexpr int offset0 = 0; +constexpr int offset16 = 16; + +// #Constants +const vuint8 mask_zero_bits = vuint8{ + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 96, + 64, + 32, + 0}; + +const vuint8 swap_mask = + vuint8{4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11}; + +const vint32 v0x7f = vec_splats(0x7f); +const vint32 vi_0 = vec_splats((int)(0)); +const vint32 vi_1 = vec_splats((int)1); +const vint32 vi_2 = vec_splats((int)2); +const vint32 vi_4 = vec_splats((int)4); +const vint32 vi_inv1 = vec_splats((int)~1); +const vuint32 vu_29 = vec_splats(29u); +const vuint32 vu_23 = vec_splats(23u); + +const vbool32 inv_mant_mask = (vbool32)vec_splats((unsigned int)~0xff800000); +const vbool32 sign_mask = (vbool32)vec_splats((int)0x80000000); +const vbool32 real_mask = vbool32{0xFFFFFFFF, 0x0, 0xFFFFFFFF, 0x0}; +const vbool32 imag_mask = vbool32{0x0, 0xFFFFFFFF, 0x0, 0xFFFFFFFF}; +const vbool32 isign_mask = vbool32{0x0, 0x80000000, 0x0, 0x80000000}; +const vbool32 rsign_mask = vbool32{0x80000000, 0x0, 0x80000000, 0x0}; + +const vbool64 vd_sign_mask = vbool64{0x8000000000000000, 0x8000000000000000}; +const vbool64 vd_imag_mask = vbool64{0x0, 0xFFFFFFFFFFFFFFFF}; +const vbool64 vd_real_mask = vbool64{0xFFFFFFFFFFFFFFFF, 0x0}; +const vbool64 vd_isign_mask = vbool64{0x0, 0x8000000000000000}; +const vbool64 vd_rsign_mask = vbool64{0x8000000000000000, 0x0}; + +const vfloat32 zero = vec_splats(0.f); +const vfloat32 half = vec_splats(0.5f); +const vfloat32 one = vec_splats(1.f); +const vfloat32 two = vec_splats(2.0f); 
+const vfloat32 _4div_pi = vec_splats(1.27323954473516f); +const vfloat32 v_inf = (vfloat32)vec_splats(0x7f800000u); +const vfloat32 v_minus_inf = + vfloat32{0xff800000u, 0xff800000u, 0xff800000u, 0xff800000u}; +const vfloat32 v_nan = (vfloat32)vec_splats(0x7fffffff); +const vfloat32 log10e_inv = vec_splats(0.43429448190325176f); +const vfloat32 log2e_inv = vec_splats(1.4426950408889634f); +const vfloat32 log2eB_inv = vec_splats(1.442695036924675f); +const vfloat32 cephes_SQRTHF = vec_splats(0.707106781186547524f); +const vfloat32 coscof_p0 = vec_splats(2.443315711809948E-005f); +const vfloat32 coscof_p1 = vec_splats(-1.388731625493765E-003f); +const vfloat32 coscof_p2 = vec_splats(4.166664568298827E-002f); +const vfloat32 exp_hi = vec_splats(104.f); +const vfloat32 exp_lo = vec_splats(-104.f); +const vfloat32 exp_p0 = vec_splats(0.000198527617612853646278381f); +const vfloat32 exp_p1 = vec_splats((0.00139304355252534151077271f)); +const vfloat32 exp_p2 = vec_splats(0.00833336077630519866943359f); +const vfloat32 exp_p3 = vec_splats(0.0416664853692054748535156f); +const vfloat32 exp_p4 = vec_splats(0.166666671633720397949219f); +const vfloat32 exp_p5 = vec_splats(0.5f); +const vfloat32 log_p0 = vec_splats(7.0376836292E-2f); +const vfloat32 log_p1 = vec_splats(-1.1514610310E-1f); +const vfloat32 log_p2 = vec_splats(1.1676998740E-1f); +const vfloat32 log_p3 = vec_splats(-1.2420140846E-1f); +const vfloat32 log_p4 = vec_splats(+1.4249322787E-1f); +const vfloat32 log_p5 = vec_splats(-1.6668057665E-1f); +const vfloat32 log_p6 = vec_splats(+2.0000714765E-1f); +const vfloat32 log_p7 = vec_splats(-2.4999993993E-1f); +const vfloat32 log_p8 = vec_splats(+3.3333331174E-1f); +const vfloat32 log_q1 = vec_splats(-2.12194440e-4f); +const vfloat32 log_q2 = vec_splats(0.693359375f); +const vfloat32 max_logf = vec_splats(88.02969187150841f); +const vfloat32 max_numf = + vec_splats(1.7014117331926442990585209174225846272e38f); +const vfloat32 min_inf = 
(vfloat32)vec_splats(0xff800000u); +const vfloat32 min_norm_pos = (vfloat32)vec_splats(0x0800000u); +const vfloat32 minus_cephes_dp1 = vec_splats(-0.78515625f); +const vfloat32 minus_cephes_dp2 = vec_splats(-2.4187564849853515625e-4f); +const vfloat32 minus_cephes_dp3 = vec_splats(-3.77489497744594108e-8f); +const vfloat32 negln2f_hi = vec_splats(-0.693145751953125f); +const vfloat32 negln2f_lo = vec_splats(-1.428606765330187045e-06f); +const vfloat32 p0 = vec_splats(2.03721912945E-4f); +const vfloat32 p1 = vec_splats(8.33028376239E-3f); +const vfloat32 p2 = vec_splats(1.66667160211E-1f); +const vfloat32 sincof_p0 = vec_splats(-1.9515295891E-4f); +const vfloat32 sincof_p1 = vec_splats(8.3321608736E-3f); +const vfloat32 sincof_p2 = vec_splats(-1.6666654611E-1f); +const vfloat32 tanh_0p625 = vec_splats(0.625f); +const vfloat32 tanh_half_max = vec_splats(44.014845935754205f); +const vfloat32 tanh_p0 = vec_splats(-5.70498872745E-3f); +const vfloat32 tanh_p1 = vec_splats(2.06390887954E-2f); +const vfloat32 tanh_p2 = vec_splats(-5.37397155531E-2f); +const vfloat32 tanh_p3 = vec_splats(1.33314422036E-1f); +const vfloat32 tanh_p4 = vec_splats(-3.33332819422E-1f); +const vfloat32 vcheck = vec_splats((float)(1LL << 24)); +const vfloat32 imag_one = vfloat32{0.f, 1.f, 0.f, 1.f}; +const vfloat32 imag_half = vfloat32{0.f, 0.5f, 0.f, 0.5f}; +const vfloat32 sqrt2_2 = vfloat32{ + 0.70710676908493042f, + 0.70710676908493042, + 0.70710676908493042, + 0.70710676908493042}; +const vfloat32 pi_2 = vfloat32{M_PI / 2, 0.0, M_PI / 2, 0.0}; +const vfloat32 vf_89 = vfloat32{89.f, 89.f, 89.f, 89.f}; +const vfloat64 vd_one = vec_splats(1.0); +const vfloat64 vd_zero = vec_splats(0.0); +const vfloat64 vd_log10e_inv = vec_splats(0.43429448190325176); +const vfloat64 vd_log2e_inv = vec_splats(1.4426950408889634); +const vfloat64 vd_imag_one = vfloat64{0.0, 1.0}; +const vfloat64 vd_imag_half = vfloat64{0.0, 0.5}; +const vfloat64 vd_sqrt2_2 = vfloat64{0.70710678118654757, 0.70710678118654757}; 
+const vfloat64 vd_pi_2 = vfloat64{M_PI / 2.0, 0.0}; + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h new file mode 100644 index 0000000000000000000000000000000000000000..121f2cb3eb224fa6f64a724c654828f3e7188df5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h @@ -0,0 +1,2970 @@ +#include +#include +#include +#include +#include +#if defined(__clang__) +#include +#elif defined(__GNUC__) || defined(__GNUG__) +#include +#include +#endif +#include +#include +#include + +namespace at { +namespace vec { + +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template +constexpr bool is_zarch_implemented() { + return ( + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v); +} + +template +constexpr bool is_zarch_implemented_quant() { + return ( + std::is_same_v || std::is_same_v || + std::is_same_v); +} + +template +constexpr bool is_zarch_implemented_complex() { + return std::is_same_v> || + std::is_same_v>; +} + +constexpr int offset0 = 0; +constexpr int offset16 = 16; + +template +struct VecBinaryType { + using type __attribute__((vector_size(16))) = uintmax_t; +}; + +template <> +struct VecBinaryType<8> { + using type = __attribute__((vector_size(16))) unsigned long long; +}; + +template <> +struct VecBinaryType<4> { + using type = __attribute__((vector_size(16))) unsigned int; +}; + +template <> +struct VecBinaryType<2> { + using type = __attribute__((vector_size(16))) unsigned short; +}; + +template <> +struct VecBinaryType<1> { + using type = __attribute__((vector_size(16))) unsigned char; +}; + +template +struct VecInnerType { + using Type __attribute__((vector_size(16))) 
= T; + using BinaryType = typename VecBinaryType::type; + using ElementType = T; + static constexpr int size = 16 / sizeof(T); +}; + +// define for int64_t properly for load +template <> +struct VecInnerType { + using Type = __attribute__((vector_size(16))) signed long long; + using ElementType = signed long long; + using BinaryType = typename VecBinaryType::type; + static constexpr int size = 16 / sizeof(signed long long); +}; + +template +using ZSimdVect = typename VecInnerType::Type; +template +using ZSimdVectBinary = typename VecInnerType::BinaryType; +template +using ZSimdVectElement = typename VecInnerType::ElementType; + +constexpr int blendChoiceInner( + const uint64_t mask, + const uint64_t half1 = 0xF, + const uint64_t half2 = 0xF0) { + uint64_t none = 0; + uint64_t both = half1 | half2; + // clamp it between 0 and both + auto res_mask = mask & both; + // return (a._vec0, a._vec1) + if (res_mask == none) + return 0; + // return (b._vec0,b._vec1) + else if (res_mask == both) + return 1; + // return (b._vec0, a._vec1) + else if (res_mask == half1) + return 2; + // return (a._vec0,b._vec1) + else if (res_mask == half2) + return 3; + // return (*_vec0,a._vec1) + else if (res_mask > 0 && res_mask < half1) + return 4; + // return (*_vec0,b._vec1) + else if ((res_mask & half2) == half2) + return 5; + // return (a._vec0,*_vec1) + else if ((res_mask & half1) == 0 && res_mask > half1) + return 6; + // return (b._vec0,*_vec1) + else if ((res_mask & half1) == half1 && res_mask > half1) + return 7; + // return (*_vec0,*_vec1) + return 8; +} + +// it can be used to emulate blend faster +template +constexpr int blendChoice(const uint64_t mask) { + static_assert(Z < 1 || Z > 8, "not implemented"); + return blendChoiceInner(mask); +} + +template <> +constexpr int blendChoice<1>(const uint64_t mask) { + return blendChoiceInner(mask, 0x0000FFFF, 0xFFFF0000); +} + +template <> +constexpr int blendChoice<2>(const uint64_t mask) { + return blendChoiceInner(mask, 0x00FF, 
0xFF00); +} + +template <> +constexpr int blendChoice<4>(const uint64_t mask) { + return blendChoiceInner(mask, 0xF, 0xF0); +} + +template <> +constexpr int blendChoice<8>(const uint64_t mask) { + // clamp it 0 and 0xF + return blendChoiceInner(mask, 0x3, 0xC); +} + +template +constexpr auto GetMask1(const uint64_t mask) { + return typename VecBinaryType::type{}; +} + +template +constexpr auto GetMask2(const uint64_t mask) { + return typename VecBinaryType::type{}; +} + +template <> +constexpr auto GetMask1<1>(const uint64_t mask) { + constexpr uint8_t t = (int)0xFF; + uint8_t g0 = (mask & 1) * t; + uint8_t g1 = ((mask & 2) >> 1) * t; + uint8_t g2 = ((mask & 4) >> 2) * t; + uint8_t g3 = ((mask & 8) >> 3) * t; + uint8_t g4 = ((mask & 16) >> 4) * t; + uint8_t g5 = ((mask & 32) >> 5) * t; + uint8_t g6 = ((mask & 64) >> 6) * t; + uint8_t g7 = ((mask & 128) >> 7) * t; + uint8_t g8 = ((mask & 256) >> 8) * t; + uint8_t g9 = ((mask & 512) >> 9) * t; + uint8_t g10 = ((mask & 1024) >> 10) * t; + uint8_t g11 = ((mask & 2048) >> 11) * t; + uint8_t g12 = ((mask & 4096) >> 12) * t; + uint8_t g13 = ((mask & 8192) >> 13) * t; + uint8_t g14 = ((mask & 16384) >> 14) * t; + uint8_t g15 = ((mask & 32768) >> 15) * t; + return (typename VecBinaryType<1>::type){ + g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15}; +} + +template <> +constexpr auto GetMask2<1>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xFFFFFFFF) >> 16; + return GetMask1<1>(mask2); +} + +template <> +constexpr auto GetMask1<2>(const uint64_t mask) { + constexpr uint16_t t = (int)0xFFFF; + uint16_t g0 = (mask & 1) * t; + uint16_t g1 = ((mask & 2) >> 1) * t; + uint16_t g2 = ((mask & 4) >> 2) * t; + uint16_t g3 = ((mask & 8) >> 3) * t; + uint16_t g4 = ((mask & 16) >> 4) * t; + uint16_t g5 = ((mask & 32) >> 5) * t; + uint16_t g6 = ((mask & 64) >> 6) * t; + uint16_t g7 = ((mask & 128) >> 7) * t; + return (typename VecBinaryType<2>::type){g0, g1, g2, g3, g4, g5, g6, g7}; +} + +template <> +constexpr 
auto GetMask2<2>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xFFFF) >> 8; + return GetMask1<2>(mask2); +} + +template <> +constexpr auto GetMask1<4>(const uint64_t mask) { + uint32_t g0 = (mask & 1) * 0xffffffff; + uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + return (typename VecBinaryType<4>::type){g0, g1, g2, g3}; +} + +template <> +constexpr auto GetMask2<4>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xFF) >> 4; + return GetMask1<4>(mask2); +} + +template <> +constexpr auto GetMask1<8>(const uint64_t mask) { + uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + return (typename VecBinaryType<8>::type){g0, g1}; +} + +template <> +constexpr auto GetMask2<8>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xF) >> 2; + return GetMask1<8>(mask2); +} + +template +constexpr int maskForComplex(uint32_t mask) { + return 0; +} + +template <> +constexpr int maskForComplex<8>(uint32_t mask) { + mask = mask & 0xF; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + if (mask & 4) + complex_mask |= (3 << 4); + if (mask & 8) + complex_mask |= (3 << 6); + return complex_mask; +} + +template <> +constexpr int maskForComplex<16>(uint32_t mask) { + mask = mask & 0x3; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + return complex_mask; +} + +template > +constexpr int blend_choice() { + return 0xAA; +} + +template <> +constexpr int blend_choice>() { + return 0x0A; +} + +constexpr int64_t allbitset(int16_t x) { + int64_t onex = 1; + return (onex << x) - onex; +} + +namespace { /* unnamed namespace */ + +ZSimdVect vec_mergee(ZSimdVect x, ZSimdVect y) { + constexpr ZSimdVectBinary mergee_mask{ + 0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27}; + return vec_perm(x, y, mergee_mask); +} + +ZSimdVect 
vec_mergee(ZSimdVect x, ZSimdVect y) { + return vec_mergeh(x, y); +} + +ZSimdVect vec_mergeo(ZSimdVect x, ZSimdVect y) { + constexpr ZSimdVectBinary mergeo_mask{ + 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31}; + return vec_perm(x, y, mergeo_mask); +} + +ZSimdVect vec_mergeo(ZSimdVect x, ZSimdVect y) { + return vec_mergel(x, y); +} + +} /* unnamed namespace */ + +// +template +constexpr auto GetBpermZeroMask() { + return ZSimdVectBinary{ + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 96, + 64, + 32, + 0}; +} + +template <> +constexpr auto GetBpermZeroMask() { + return ZSimdVectBinary{ + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 64, + 0}; +} + +constexpr auto GetSwapMaskFloat() { + return ZSimdVectBinary{ + 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11}; +} + +template +struct is_vec_specialized_for()>> + : std::bool_constant {}; + +template +struct Vectorized()>> { + public: + using value_type = T; + using vtype = ZSimdVect; + using vmaskType = ZSimdVectBinary; + using size_type = int; + // because of gcc inconsistency for int64_t we are obliged to use this, not + // value_type + using ElementType = ZSimdVectElement; + using vinner_data = std::pair; + + private: + vtype _vec0; + vtype _vec1; + + public: + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(ElementType); + } + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(vtype v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(const vinner_data& v) + : _vec0{v.first}, _vec1{v.second} {} + C10_ALWAYS_INLINE Vectorized(vtype v1, vtype v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(T s) + : _vec0{vec_splats((ElementType)s)}, _vec1{vec_splats((ElementType)s)} {} + + template + struct LoaduHelper { + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + __at_align__ ElementType tmp_values[size()] = {}; + std::memcpy( + tmp_values, ptr, 
std::min(count, size()) * sizeof(ElementType)); + + return { + vec_xl(offset0, &(tmp_values[0])), + vec_xl(offset16, &(tmp_values[0]))}; + } + }; + + template + struct LoaduHelper { + static Vectorized C10_ALWAYS_INLINE + loadu(const ElementType* ptr, int count = size()) { + if (count == size()) { + return {vec_xl(offset0, ptr), vec_xl(offset16, ptr)}; + } + + __at_align__ ElementType tmp_values[size()] = {}; + std::memcpy( + tmp_values, ptr, std::min(count, size()) * sizeof(ElementType)); + + return { + vec_xl(offset0, &(tmp_values[0])), + vec_xl(offset16, &(tmp_values[0]))}; + } + }; + + template + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + return LoaduHelper::loadu(ptr, count); + } + + template + static Vectorized C10_ALWAYS_INLINE loadu_one_fourth(const U* ptr) { + // load only first 8 bytes + // only intended to be used with uint8_t + return loadu(ptr, 8 / sizeof(ElementType)); + } + + template + struct StoreHelper { + static void C10_ALWAYS_INLINE + store(const Vectorized& vec, U* ptr, int count = size()) { + if (count > 0) { + __at_align__ ElementType tmp_values[size()]; + vec_xst(vec._vec0, offset0, &(tmp_values[0])); + vec_xst(vec._vec1, offset16, &(tmp_values[0])); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(ElementType)); + } + } + }; + + template + struct StoreHelper { + static void C10_ALWAYS_INLINE + store(const Vectorized& vec, ElementType* ptr, int count = size()) { + if (count == size()) { + vec_xst(vec._vec0, offset0, ptr); + vec_xst(vec._vec1, offset16, ptr); + } else if (count > 0) { + __at_align__ ElementType tmp_values[size()]; + vec_xst(vec._vec0, offset0, &(tmp_values[0])); + vec_xst(vec._vec1, offset16, &(tmp_values[0])); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(ElementType)); + } + } + }; + + template + void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const { + return StoreHelper::store(*this, ptr, count); + } + + C10_ALWAYS_INLINE const vtype& 
vec0() const { + return _vec0; + } + + C10_ALWAYS_INLINE const vtype& vec1() const { + return _vec1; + } + + C10_ALWAYS_INLINE vinner_data data() const { + return std::make_pair<>(_vec0, _vec1); + } + + C10_ALWAYS_INLINE operator vinner_data() const { + return data(); + } + + C10_ALWAYS_INLINE const vmaskType vecb0() const { + return (vmaskType)_vec0; + } + C10_ALWAYS_INLINE const vmaskType vecb1() const { + return (vmaskType)_vec1; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return { + vec_sel(a._vec0, b._vec0, mask.vecb0()), + vec_sel(a._vec1, b._vec1, mask.vecb1())}; + } + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4) + : _vec0{s1, s2}, _vec1{s3, s4} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4, T s5, T s6, T s7, T s8) + : _vec0{s1, s2, s3, s4}, _vec1{s5, s6, s7, s8} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized( + T s1, + T s2, + T s3, + T s4, + T s5, + T s6, + T s7, + T s8, + T s9, + T s10, + T s11, + T s12, + T s13, + T s14, + T s15, + T s16) + : _vec0{s1, s2, s3, s4, s5, s6, s7, s8}, + _vec1{s9, s10, s11, s12, s13, s14, s15, s16} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized( + T s1, + T s2, + T s3, + T s4, + T s5, + T s6, + T s7, + T s8, + T s9, + T s10, + T s11, + T s12, + T s13, + T s14, + T s15, + T s16, + T s17, + T s18, + T s19, + T s20, + T s21, + T s22, + T s23, + T s24, + T s25, + T s26, + T s27, + T s28, + T s29, + T s30, + T s31, + T s32) + : _vec0{s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16}, + _vec1{ + s17, + s18, + s19, + s20, + s21, + s22, + s23, + s24, + s25, + s26, + s27, + s28, + s29, + s30, + s31, + s32} {} + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized(base, base + step, base + 2 * step, base + 3 * step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = 
static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + + // blend section + template + static std::enable_if_t(mask) == 0, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t(mask) == 1, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t(mask) == 2, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t(mask) == 3, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + 
static std::enable_if_t(mask) == 4, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_1st = GetMask1(mask); + return {(vtype)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t(mask) == 5, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_1st = GetMask1(mask); + return {(vtype)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t(mask) == 6, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_2nd = GetMask2(mask); + // generated masks + return {a._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t(mask) == 7, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_2nd = GetMask2(mask); + // generated masks + return {b._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t(mask) == 8, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_1st = GetMask1(mask); + const vmaskType mask_2nd = GetMask2(mask); + return { + (vtype)vec_sel(a._vec0, b._vec0, mask_1st), + (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static inline std::enable_if_t<(Z >= C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + return b; + } + + template + static inline std::enable_if_t<(Z < C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + if (count == Z) + return blend(a, b); + else + return set_inner(a, b, count); + } + + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + if (count == 0) + return a; + return set_inner<1, size()>(a, b, count); + } + + const ElementType& operator[](int idx) const = delete; + ElementType& operator[](int idx) 
= delete; + + Vectorized _not() const { + return {(vtype)vec_nor(vecb0(), vecb0()), (vtype)vec_nor(vecb1(), vecb1())}; + } + + Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const { + return (*this == other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const { + return (*this != other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE gt(const Vectorized& other) const { + return (*this > other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE ge(const Vectorized& other) const { + return (*this >= other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE lt(const Vectorized& other) const { + return (*this < other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE le(const Vectorized& other) const { + return (*this <= other) & Vectorized((T)1.0); + } + + template , int> = 0> + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + template , int> = 0> + Vectorized C10_ALWAYS_INLINE abs() const { + return {_vec0, _vec1}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {-_vec0, -_vec1}; + } + + Vectorized isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._not(); + } + + bool has_inf_nan() const { + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec0[i]) || _isinf(_vec0[i])) { + return true; + } + } + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec1[i]) || _isinf(_vec1[i])) { + return true; + } + } + return false; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized angle() const { + auto tmp = blendv( + Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0)); + return blendv(tmp, *this, isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized angle() const { + return blendv( + Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0)); + } + + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return 
Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + int zero_mask() const { + auto cmp = (*this == Vectorized(0)); + constexpr auto mask_zero_bits = GetBpermZeroMask(); + ZSimdVectBinary result0 = + vec_bperm_u128((ZSimdVectBinary)cmp.vecb0(), mask_zero_bits); + ZSimdVectBinary result1 = + vec_bperm_u128((ZSimdVectBinary)cmp.vecb1(), mask_zero_bits); + return (result0[0] | (result1[0] << (size() / 2))); + } + + Vectorized C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE rint() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vectorized C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE reciprocal() const { + return Vectorized((T)1) / (*this); + } + Vectorized C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + template , int> = 0> + inline Vectorized mapOrdinary(float (*const f)(float)) const { + float a00 = f(_vec0[0]); + float a01 = f(_vec0[1]); + float a02 = f(_vec0[2]); + float a03 = f(_vec0[3]); + float a10 = f(_vec1[0]); + float a11 = f(_vec1[1]); + float a12 = f(_vec1[2]); + float a13 = f(_vec1[3]); + return Vectorized{a00, a01, a02, a03, a10, a11, a12, a13}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapOrdinary(double (*const f)(double)) const { + return Vectorized(f(_vec0[0]), f(_vec0[1]), f(_vec1[0]), f(_vec1[1])); + } + + template , int> = 0> + inline Vectorized mapOrdinary( + float (*const f)(float, 
float), + const Vectorized& b) const { + float a00 = f(_vec0[0], b._vec0[0]); + float a01 = f(_vec0[1], b._vec0[1]); + float a02 = f(_vec0[2], b._vec0[2]); + float a03 = f(_vec0[3], b._vec0[3]); + float a10 = f(_vec1[0], b._vec1[0]); + float a11 = f(_vec1[1], b._vec1[1]); + float a12 = f(_vec1[2], b._vec1[2]); + float a13 = f(_vec1[3], b._vec1[3]); + return Vectorized{a00, a01, a02, a03, a10, a11, a12, a13}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapOrdinary( + double (*const f)(double, double), + const Vectorized& b) const { + return Vectorized( + f(_vec0[0], b._vec0[0]), + f(_vec0[1], b._vec0[1]), + f(_vec1[0], b._vec1[0]), + f(_vec1[1], b._vec1[1])); + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d) const { + vtype a0 = f(_vec0); + vtype a1 = f(_vec1); + return Vectorized{a0, a1}; + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d) const { + return Vectorized(d(_vec0), d(_vec1)); + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d, const Vectorized& b) + const { + vtype a0 = f(_vec0, b._vec0); + vtype a1 = f(_vec1, b._vec1); + return Vectorized{a0, a1}; + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d, const Vectorized& b) + const { + return Vectorized(d(_vec0, b._vec0), d(_vec1, b._vec1)); + } + + Vectorized acos() const { + return mapSleef(Sleef_acosf4_u10, Sleef_acosd2_u10); + } + Vectorized asin() const { + return mapSleef(Sleef_asinf4_u10, Sleef_asind2_u10); + } + Vectorized atan() const { + return mapSleef(Sleef_atanf4_u10, Sleef_atand2_u10); + } + Vectorized atanh() const { + 
return mapSleef(Sleef_atanhf4_u10, Sleef_atanhd2_u10); + } + + Vectorized erf() const { + return mapSleef(Sleef_erff4_u10, Sleef_erfd2_u10); + } + Vectorized erfc() const { + return mapSleef(Sleef_erfcf4_u15, Sleef_erfcd2_u15); + } + + Vectorized exp() const { + return mapSleef(Sleef_expf4_u10, Sleef_expd2_u10); + } + Vectorized exp2() const { + return mapSleef(Sleef_exp2f4_u10, Sleef_exp2d2_u10); + } + Vectorized expm1() const { + return mapSleef(Sleef_expm1f4_u10, Sleef_expm1d2_u10); + } + Vectorized exp_u20() const { + return exp(); + } + + Vectorized log() const { + return mapSleef(Sleef_logf4_u10, Sleef_logd2_u10); + } + Vectorized log2() const { + return mapSleef(Sleef_log2f4_u10, Sleef_log2d2_u10); + } + Vectorized log10() const { + return mapSleef(Sleef_log10f4_u10, Sleef_log10d2_u10); + } + Vectorized log1p() const { + return mapSleef(Sleef_log1pf4_u10, Sleef_log1pd2_u10); + } + + Vectorized sin() const { + return mapSleef(Sleef_sinf4_u10, Sleef_sind2_u10); + } + Vectorized sinh() const { + return mapSleef(Sleef_sinhf4_u10, Sleef_sinhd2_u10); + } + Vectorized cos() const { + return mapSleef(Sleef_cosf4_u10, Sleef_cosd2_u10); + } + Vectorized cosh() const { + return mapSleef(Sleef_coshf4_u10, Sleef_coshd2_u10); + } + + Vectorized tan() const { + return mapSleef(Sleef_tanf4_u10, Sleef_tand2_u10); + } + Vectorized tanh() const { + return mapSleef(Sleef_tanhf4_u10, Sleef_tanhd2_u10); + } + + Vectorized lgamma() const { + return mapSleef(Sleef_lgammaf4_u10, Sleef_lgammad2_u10); + } + + Vectorized atan2(const Vectorized& b) const { + return mapSleef(Sleef_atan2f4_u10, Sleef_atan2d2_u10, b); + } + Vectorized copysign(const Vectorized& sign) const { + return mapSleef(Sleef_copysignf4, Sleef_copysignd2, sign); + } + Vectorized fmod(const Vectorized& q) const { + return mapSleef(Sleef_fmodf4, Sleef_fmodd2, q); + } + + Vectorized hypot(const Vectorized& b) const { + return mapSleef(Sleef_hypotf4_u05, Sleef_hypotd2_u05, b); + } + + Vectorized pow(const Vectorized& b) 
const { + return mapSleef(Sleef_powf4_u10, Sleef_powd2_u10, b); + } + + Vectorized nextafter(const Vectorized& b) const { + return mapSleef(Sleef_nextafterf4, Sleef_nextafterd2, b); + } + + Vectorized erfinv() const { + return mapOrdinary(calc_erfinv); + } + + Vectorized digamma() const { + return mapOrdinary(calc_digamma); + } + + Vectorized igamma(const Vectorized& x) const { + return mapOrdinary(calc_igamma, x); + } + + Vectorized igammac(const Vectorized& x) const { + return mapOrdinary(calc_igammac, x); + } + + Vectorized i0() const { + return mapOrdinary(calc_i0); + } + + Vectorized i0e() const { + return mapOrdinary(calc_i0e); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized minimum(const Vectorized& other) const { + return {vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)}; + } + + /* Propagates NaN if either input is a NaN. */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized minimum(const Vectorized& other) const { + Vectorized tmp = { + vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)}; + tmp = blendv(tmp, *this, isnan()); + return blendv(tmp, other, other.isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized maximum(const Vectorized& other) const { + return {vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)}; + } + + /* Propagates NaN if either input is a NaN. 
*/ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized maximum(const Vectorized& other) const { + Vectorized tmp = { + vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)}; + tmp = blendv(tmp, *this, isnan()); + return blendv(tmp, other, other.isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_min(const Vectorized& min) const { + return {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)}; + } + + /* Keeps NaN if actual value is NaN */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_min(const Vectorized& min) const { + Vectorized tmp = {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)}; + return blendv(tmp, *this, isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_max(const Vectorized& max) const { + return {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)}; + } + + /* Keeps NaN if actual value is NaN */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_max(const Vectorized& max) const { + Vectorized tmp = {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)}; + return blendv(tmp, *this, isnan()); + } + + template , int> = 0> + Vectorized swapped() const { + auto swap_mask = GetSwapMaskFloat(); + vtype v0 = vec_perm(_vec0, _vec0, swap_mask); + vtype v1 = vec_perm(_vec1, _vec1, swap_mask); + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized swapped() const { + vtype v0 = {_vec0[1], _vec0[0]}; + vtype v1 = {_vec1[1], _vec1[0]}; + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + static Vectorized mergee(Vectorized& first, Vectorized& second) { + return { + vec_mergee(first._vec0, second._vec0), + vec_mergee(first._vec1, second._vec1)}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + static Vectorized mergeo(Vectorized& first, Vectorized& second) { + return { + 
vec_mergeo(first._vec0, second._vec0), + vec_mergeo(first._vec1, second._vec1)}; + } + + static Vectorized horizontal_add_perm( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.swapped(); // 2perm + auto second_perm = second.swapped(); // 2perm + // summ + auto first_ret = first + first_perm; // 2add + auto second_ret = second + second_perm; // 2 add + // now lets choose evens + return mergee(first_ret, second_ret); // 2 mergee's + } + + static Vectorized horizontal_sub_perm( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.swapped(); // 2perm + auto second_perm = second.swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return mergee(first_ret, second_ret); // 2 mergee's + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized mergee() const { + return {vec_mergee(_vec0, _vec0), vec_mergee(_vec1, _vec1)}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized mergeo() const { + return {vec_mergeo(_vec0, _vec0), vec_mergeo(_vec1, _vec1)}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized to_vec_float_helper() const { + int32_t values[8] = { + _vec0[0], + _vec0[1], + _vec0[2], + _vec0[3], + _vec0[4], + _vec0[5], + _vec0[6], + _vec0[7], + }; + + return Vectorized{ + values[0], + values[1], + values[2], + values[3], + values[4], + values[5], + values[6], + values[7]}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized to_vec_uint8_helper() const { + // helper function for float to uint8_t conversion + uint8_t values[8] = { + static_cast(_vec0[0]), + 
static_cast(_vec0[1]), + static_cast(_vec0[2]), + static_cast(_vec0[3]), + static_cast(_vec1[0]), + static_cast(_vec1[1]), + static_cast(_vec1[2]), + static_cast(_vec1[3]), + }; + + return Vectorized{ + values[0], values[1], values[2], values[3], values[4], values[5], + values[6], values[7], 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, + }; + } +}; + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() + b.vec0(), a.vec1() + b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() - b.vec0(), a.vec1() - b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator*( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() * b.vec0(), a.vec1() * b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator/( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + (Vectorized::vtype)(a.vecb0() & b.vecb0()), \ + (Vectorized::vtype)(a.vecb1() & b.vecb1())}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + (Vectorized::vtype)(a.vecb0() | b.vecb0()), \ + (Vectorized::vtype)(a.vecb1() | b.vecb1())}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + (Vectorized::vtype)(a.vecb0() ^ b.vecb0()), \ + (Vectorized::vtype)(a.vecb1() ^ b.vecb1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator==( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + 
vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator!=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())} \ + ._not(); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpgt(a.vec0(), b.vec0()), vec_cmpgt(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpge(a.vec0(), b.vec0()), vec_cmpge(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmplt(a.vec0(), b.vec0()), vec_cmplt(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmple(a.vec0(), b.vec0()), vec_cmple(a.vec1(), b.vec1())}; \ + } + +ZVECTOR_OPERATORS(float) +ZVECTOR_OPERATORS(double) +ZVECTOR_OPERATORS(int8_t) +ZVECTOR_OPERATORS(uint8_t) +ZVECTOR_OPERATORS(uint16_t) +ZVECTOR_OPERATORS(int16_t) +ZVECTOR_OPERATORS(int32_t) +ZVECTOR_OPERATORS(int64_t) + +#undef ZVECTOR_OPERATORS + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator<<( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr Vectorized::ElementType max_shift = \ + sizeof(Vectorized::ElementType) * CHAR_BIT; \ + \ + Vectorized::ElementType a_array[Vectorized::size()]; \ + Vectorized::ElementType b_array[Vectorized::size()]; \ + Vectorized::ElementType c_array[Vectorized::size()]; \ + \ + a.store(a_array); \ + b.store(b_array); \ + \ + for (int i = 0; i != Vectorized::size(); i++) { \ + typex shift = b_array[i]; \ + if ((static_cast>(shift) < 0) || \ + (shift >= max_shift)) { \ + c_array[i] = 0; \ + } else { \ + c_array[i] = static_cast>(a_array[i]) \ + << 
shift; \ + } \ + } \ + \ + return Vectorized::loadu(c_array); \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator>>( \ + const Vectorized& a, const Vectorized& b) { \ + /* right shift value to retain sign bit for signed and no bits for \ + * unsigned */ \ + constexpr Vectorized::ElementType max_shift = \ + sizeof(typex) * CHAR_BIT - std::is_signed_v; \ + \ + Vectorized::ElementType a_array[Vectorized::size()]; \ + Vectorized::ElementType b_array[Vectorized::size()]; \ + Vectorized::ElementType c_array[Vectorized::size()]; \ + \ + a.store(a_array); \ + b.store(b_array); \ + \ + for (int i = 0; i != Vectorized::size(); i++) { \ + typex shift = b_array[i]; \ + if ((static_cast>(shift) < 0) || \ + (shift >= max_shift)) { \ + c_array[i] = a_array[i] >> max_shift; \ + } else { \ + c_array[i] = a_array[i] >> shift; \ + } \ + } \ + \ + return Vectorized::loadu(c_array); \ + } \ + \ + template <> \ + inline Vectorized operator~(const Vectorized& a) { \ + return a._not(); \ + } + +ZVECTOR_OPERATORS(int8_t) +ZVECTOR_OPERATORS(uint8_t) +ZVECTOR_OPERATORS(uint16_t) +ZVECTOR_OPERATORS(int16_t) +ZVECTOR_OPERATORS(int32_t) +ZVECTOR_OPERATORS(int64_t) + +#undef ZVECTOR_OPERATORS + +#define DEFINE_MAXMIN_FUNCS(operand_type) \ + template <> \ + Vectorized inline maximum( \ + const Vectorized& a, const Vectorized& b) { \ + return a.maximum(b); \ + } \ + template <> \ + Vectorized inline minimum( \ + const Vectorized& a, const Vectorized& b) { \ + return a.minimum(b); \ + } + +#define DEFINE_CLAMP_MAXMIN_FUNCS(typex) \ + DEFINE_MAXMIN_FUNCS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_min( \ + const Vectorized& a, const Vectorized& min) { \ + return a.clamp_min(min); \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_max( \ + const Vectorized& a, const Vectorized& max) { \ + return a.clamp_max(max); \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp( \ + const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + 
return clamp_max(clamp_min(a, min), max); \ + } + +DEFINE_CLAMP_MAXMIN_FUNCS(int8_t) +DEFINE_CLAMP_MAXMIN_FUNCS(uint8_t) +DEFINE_CLAMP_MAXMIN_FUNCS(int16_t) +DEFINE_CLAMP_MAXMIN_FUNCS(int32_t) +DEFINE_CLAMP_MAXMIN_FUNCS(int64_t) +DEFINE_CLAMP_MAXMIN_FUNCS(float) +DEFINE_CLAMP_MAXMIN_FUNCS(double) + +namespace { /* unnamed namespace */ + +#if !defined(vec_float) || __ARCH__ < 13 +#warning \ + "float->int and int->float conversion is simulated. compile for z15 for improved performance" +inline ZSimdVect vec_int_flt(const ZSimdVect x) { + return ZSimdVect{float(x[0]), float(x[1]), float(x[2]), float(x[3])}; +} +inline ZSimdVect vec_flt_int(const ZSimdVect x) { + return ZSimdVect{int(x[0]), int(x[1]), int(x[2]), int(x[3])}; +} +#else +#define vec_int_flt vec_float +#define vec_flt_int vec_signed +#endif + +Vectorized zvec_convert_to_float(const Vectorized& x) { + return {vec_int_flt(x.vec0()), vec_int_flt(x.vec1())}; +} + +Vectorized zvec_convert_to_int(const Vectorized& x) { + return {vec_flt_int(x.vec0()), vec_flt_int(x.vec1())}; +} + +Vectorized zvec_convert_to_float(const Vectorized& x) { + return {vec_double(x.vec0()), vec_double(x.vec1())}; +} + +Vectorized zvec_convert_to_int(const Vectorized& x) { + return {vec_signed(x.vec0()), vec_signed(x.vec1())}; +} + +} /* unnamed namespace */ + +template +Vectorized cast_zvector(const Vectorized& x) { + using cast_type = typename Vectorized::vtype; + return Vectorized{(cast_type)x.vec0(), (cast_type)x.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + __builtin_s390_vfmasb(a.vec0(), b.vec0(), c.vec0()), + __builtin_s390_vfmasb(a.vec1(), b.vec1(), c.vec1())}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + __builtin_s390_vfmadb(a.vec0(), b.vec0(), c.vec0()), + __builtin_s390_vfmadb(a.vec1(), b.vec1(), c.vec1())}; +} 
+template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return zvec_convert_to_int(src); +} + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return zvec_convert_to_int(src); +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + // int32_t and float have same size + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int32_t* src_a = src + i; + float* dst_a = dst + i; + auto input_vec = Vectorized::loadu(src_a); + auto output_vec = zvec_convert_to_float(input_vec); + output_vec.store(dst_a); + } + + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int64_t* src, double* dst, int64_t n) { + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int64_t* src_a = src + i; + double* dst_a = dst + i; + auto input_vec = Vectorized::loadu(src_a); + auto output_vec = zvec_convert_to_float(input_vec); + output_vec.store(dst_a); + } + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +#define DEFINE_REINTERPRET_CAST_FUNCS(Fst, Cst) \ + template <> \ + C10_ALWAYS_INLINE Vectorized cast( \ + const Vectorized& src) { \ + return cast_zvector(src); \ + } + +#define 
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(Fst) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, double) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, float) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, int64_t) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, int32_t) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, int16_t) + +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t) + +#undef DEFINE_REINTERPRET_CAST_FUNCS + +template +struct unpack_type { + using type = T; +}; +template <> +struct unpack_type { + using type = int16_t; +}; +template <> +struct unpack_type { + using type = int16_t; +}; +template <> +struct unpack_type { + using type = int32_t; +}; + +template +struct pack_type { + using type = T; +}; +template <> +struct pack_type { + using type = int8_t; +}; +template <> +struct pack_type { + using type = int16_t; +}; + +namespace { /* unnamed namespace */ + +template ::type> +std::pair, Vectorized> unpack(const Vectorized& x) { + auto vec0 = vec_unpackh(x.vec0()); + auto vec1 = vec_unpackl(x.vec0()); + auto vec2 = vec_unpackh(x.vec1()); + auto vec3 = vec_unpackl(x.vec1()); + return {Vectorized{vec0, vec1}, Vectorized{vec2, vec3}}; +} + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") +template <> +std::pair, Vectorized> unpack( + const Vectorized& x) { + using typeX = typename Vectorized::vtype; + typeX vec0 = vec_unpackh(x.vec0()); + typeX vec1 = vec_unpackl(x.vec0()); + typeX vec2 = vec_unpackh(x.vec1()); + typeX vec3 = vec_unpackl(x.vec1()); + // auto mask = Vectorized(0xFF); + // vec0 = vec0 & mask; + // vec1 = vec1 & mask; + // vec2 = vec2 & mask; + // vec3 = vec3 & mask; + return { + cast_zvector(Vectorized{vec0, vec1}), + cast_zvector(Vectorized{vec2, vec3})}; +} +C10_DIAGNOSTIC_POP() + +template ::type> +Vectorized pack(const Vectorized& first, const Vectorized& second) { + auto vec0 = 
vec_packs(first.vec0(), first.vec1()); + auto vec1 = vec_packs(second.vec0(), second.vec1()); + return Vectorized{vec0, vec1}; +} + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") +template <> +Vectorized pack( + const Vectorized& first, + const Vectorized& second) { + auto vec0 = vec_packsu(first.vec0(), first.vec1()); + auto vec1 = vec_packsu(second.vec0(), second.vec1()); + return Vectorized{vec0, vec1}; +} +C10_DIAGNOSTIC_POP() + +} /* unnamed namespace */ + +//////////////////////////////////QUANT/////////////////////////////////////////// +template +struct is_vec_specialized_for< + T, + std::enable_if_t()>> + : std::bool_constant {}; + +template +struct Vectorized()>> { + public: + using value_type = typename T::underlying; + using vtype = ZSimdVect; + using vmaskType = ZSimdVectBinary; + using vinner_type = Vectorized; + using size_type = int; + + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(value_type); + } + + static constexpr int float_num_vecs() { + return size() / Vectorized::size(); + } + static constexpr int int_num_vecs() { + return float_num_vecs(); + } + using float_vec_return_type = std::array, float_num_vecs()>; + using int_vec_return_type = + std::array, int_num_vecs()>; + + private: + vinner_type _vec; + + public: + Vectorized() {} + + explicit C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {} + Vectorized(const T& val) : _vec(val.val_) {} + + C10_ALWAYS_INLINE const vinner_type& vec() const { + return _vec; + } + + template + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + return Vectorized{vinner_type::loadu(ptr, count)}; + } + + template + void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const { + _vec.store(ptr, count); + } + + Vectorized relu(Vectorized zero_point) const { + return Vectorized{_vec.maximum(zero_point._vec)}; + } + + Vectorized relu6(Vectorized zero_point, Vectorized q_six) const { + auto ret_max = _vec.maximum(zero_point._vec); + auto 
ret_min = ret_max.minimum(q_six._vec); + return Vectorized{ret_min}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + int_vec_return_type widening_subtract(Vectorized b) const { + return {*this - b}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + auto float_val = zvec_convert_to_float(_vec); + return {fmadd(scale, float_val, scale_zp_premul)}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + auto float_val = zvec_convert_to_float(_vec); + return {(float_val - zero_point) * scale}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vectorized vecf = rhs[0]; + vecf = vecf * Vectorized(inverse_scale); + vecf = vecf.rint() + Vectorized((float)(zero_point)); + auto veci = zvec_convert_to_int(vecf); + + return Vectorized{veci}; + } + + template < + typename U = T, + std::enable_if_t::int_num_vecs() == 1, int> = 0> + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized vi = inp[0]; + auto vecf = zvec_convert_to_float(vi.vec()); + vecf = vecf * Vectorized(multiplier); + vecf = vecf.rint(); + auto veci = zvec_convert_to_int(vecf) + Vectorized(zero_point); + + return Vectorized{veci}; + } + + template < + typename U = T, + std::enable_if_t::int_num_vecs() == 4, int> = 0> + int_vec_return_type widening_subtract(Vectorized b) const { + auto ret16 = unpack(_vec); + auto ret16B = unpack(b.vec()); + auto ret32_0 = unpack(ret16.first); + auto ret32_1 = unpack(ret16.second); + auto ret32B_0 = 
unpack(ret16B.first); + auto ret32B_1 = unpack(ret16B.second); + + return { + Vectorized(ret32_0.first - ret32B_0.first), + Vectorized(ret32_0.second - ret32B_0.second), + Vectorized(ret32_1.first - ret32B_1.first), + Vectorized(ret32_1.second - ret32B_1.second)}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 4, int> = 0> + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + // unpacking unsigned as signed + auto ret16 = unpack(_vec); + auto ret32_0 = unpack(ret16.first); + auto ret32_1 = unpack(ret16.second); + + auto vecf_0 = zvec_convert_to_float(ret32_0.first); + auto vecf_1 = zvec_convert_to_float(ret32_0.second); + + auto vecf_2 = zvec_convert_to_float(ret32_1.first); + auto vecf_3 = zvec_convert_to_float(ret32_1.second); + return { + fmadd(scale, vecf_0, scale_zp_premul), + fmadd(scale, vecf_1, scale_zp_premul), + fmadd(scale, vecf_2, scale_zp_premul), + fmadd(scale, vecf_3, scale_zp_premul)}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 4, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + // unpacking unsigned as signed + auto ret16 = unpack(_vec); + auto ret32_0 = unpack(ret16.first); + auto ret32_1 = unpack(ret16.second); + + auto vecf_0 = zvec_convert_to_float(ret32_0.first); + auto vecf_1 = zvec_convert_to_float(ret32_0.second); + + auto vecf_2 = zvec_convert_to_float(ret32_1.first); + auto vecf_3 = zvec_convert_to_float(ret32_1.second); + + return { + (vecf_0 - zero_point) * scale, + (vecf_1 - zero_point) * scale, + (vecf_2 - zero_point) * scale, + (vecf_3 - zero_point) * scale}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 4, int> = 0> + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto vec_inverse = Vectorized(inverse_scale); + auto 
vec_zero_point = Vectorized((float)zero_point); + + auto vecf0 = rhs[0]; + auto vecf2 = rhs[1]; + auto vecf4 = rhs[2]; + auto vecf6 = rhs[3]; + + vecf0 = vecf0 * vec_inverse; + vecf2 = vecf2 * vec_inverse; + vecf4 = vecf4 * vec_inverse; + vecf6 = vecf6 * vec_inverse; + + vecf0 = vecf0.rint() + vec_zero_point; + vecf2 = vecf2.rint() + vec_zero_point; + vecf4 = vecf4.rint() + vec_zero_point; + vecf6 = vecf6.rint() + vec_zero_point; + + auto veci0 = zvec_convert_to_int(vecf0); + auto veci2 = zvec_convert_to_int(vecf2); + auto veci4 = zvec_convert_to_int(vecf4); + auto veci6 = zvec_convert_to_int(vecf6); + + auto vecshi0 = pack(veci0, veci2); + auto vecshi2 = pack(veci4, veci6); + auto ret = pack(vecshi0, vecshi2); + + return Vectorized{ret}; + } + + template < + typename U = T, + std::enable_if_t::int_num_vecs() == 4, int> = 0> + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized vec_multiplier = Vectorized(multiplier); + Vectorized vec_zero_point = Vectorized(zero_point); + + Vectorized vi0 = inp[0]; + Vectorized vi1 = inp[1]; + Vectorized vi2 = inp[2]; + Vectorized vi3 = inp[3]; + + auto vecf0 = zvec_convert_to_float(vi0.vec()); + auto vecf2 = zvec_convert_to_float(vi1.vec()); + + auto vecf4 = zvec_convert_to_float(vi2.vec()); + auto vecf6 = zvec_convert_to_float(vi3.vec()); + + vecf0 = vecf0 * vec_multiplier; + vecf2 = vecf2 * vec_multiplier; + + vecf4 = vecf4 * vec_multiplier; + vecf6 = vecf6 * vec_multiplier; + + vecf0 = vecf0.rint(); + vecf2 = vecf2.rint(); + vecf4 = vecf4.rint(); + vecf6 = vecf6.rint(); + + auto veci0 = zvec_convert_to_int(vecf0); + auto veci2 = zvec_convert_to_int(vecf2); + auto veci4 = zvec_convert_to_int(vecf4); + auto veci6 = zvec_convert_to_int(vecf6); + + veci0 = veci0 + vec_zero_point; + veci2 = veci2 + vec_zero_point; + + veci4 = veci4 + vec_zero_point; + veci6 = veci6 + vec_zero_point; + + auto vecshi0 = pack(veci0, veci2); + auto vecshi2 = pack(veci4, 
veci6); + + auto ret = pack(vecshi0, vecshi2); + + return Vectorized{ret}; + } + + Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const { + return Vectorized{_vec.eq(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const { + return Vectorized{_vec.ne(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE gt(const Vectorized& other) const { + return Vectorized{_vec.gt(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE ge(const Vectorized& other) const { + return Vectorized{_vec.ge(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE lt(const Vectorized& other) const { + return Vectorized{_vec.lt(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE le(const Vectorized& other) const { + return Vectorized{_vec.le(other._vec)}; + } + + Vectorized clamp_min(const Vectorized& min) const { + return Vectorized{_vec.clamp_min(min._vec)}; + } + + Vectorized clamp_max(const Vectorized& max) const { + return Vectorized{_vec.clamp_max(max._vec)}; + } + + Vectorized minimum(const Vectorized& other) const { + return Vectorized{_vec.minimum(other._vec)}; + } + + Vectorized maximum(const Vectorized& other) const { + return Vectorized{_vec.maximum(other._vec)}; + } +}; + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() + b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() - b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator*( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() * b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator/( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() / b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return 
Vectorized{a.vec() & b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() | b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() ^ b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator==( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() == b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator!=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() != b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() > b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() >= b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() < b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() <= b.vec()}; \ + } + +ZVECTOR_OPERATORS(c10::qint32) +ZVECTOR_OPERATORS(c10::qint8) +ZVECTOR_OPERATORS(c10::quint8) + +#undef ZVECTOR_OPERATORS + +DEFINE_CLAMP_MAXMIN_FUNCS(c10::quint8) +DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint8) +DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint32) + +template +constexpr auto real_mask() { + return (ZSimdVect)ZSimdVectBinary{0xFFFFFFFF, 0, 0xFFFFFFFF, 0}; +} + +template <> +constexpr auto real_mask() { + return (ZSimdVect)ZSimdVectBinary{0xFFFFFFFFFFFFFFFF, 0}; +} + +template +constexpr auto image_mask() { + return (ZSimdVect)ZSimdVectBinary{0, 0xFFFFFFFF, 0, 0xFFFFFFFF}; +} + +template <> +constexpr auto image_mask() { + return (ZSimdVect)ZSimdVectBinary{0, 0xFFFFFFFFFFFFFFFF}; +} + +template +constexpr auto rsign_mask() { + return ZSimdVect{-0.f, 0.f, -0.f, 
0.f}; +} + +template <> +constexpr auto rsign_mask() { + return ZSimdVect{-0.0, 0.f}; +} + +template +constexpr auto isign_mask() { + return ZSimdVect{0.0, -0.f, 0.0, -0.f}; +} + +template <> +constexpr auto isign_mask() { + return ZSimdVect{0.0, -0.0}; +} + +template +constexpr auto image_one() { + return ZSimdVect{0, 1.f, 0, 1.f}; +} + +template <> +constexpr auto image_one() { + return ZSimdVect{0.0, 1.0}; +} + +template +constexpr auto pi_half() { + return ZSimdVect{(float)(M_PI / 2.0), 0.f, (float)(M_PI / 2.0), 0.f}; +} + +template <> +constexpr auto pi_half() { + return ZSimdVect{M_PI / 2.0, 0.0}; +} + +template +constexpr auto image_half() { + return ZSimdVect{0, 0.5f, 0, 0.5f}; +} + +template <> +constexpr auto image_half() { + return ZSimdVect{0.0, 0.5}; +} + +template +constexpr U log2e_inv() { + return static_cast(1.4426950408889634); +} + +template +constexpr U log10e_inv() { + return static_cast(0.43429448190325176); +} + +template +struct is_vec_specialized_for< + T, + std::enable_if_t()>> + : std::bool_constant {}; + +template +struct Vectorized()>> { + public: + using underline_type = decltype(std::declval().imag()); + using value_type = T; + using vtype = ZSimdVect; + using vmaskType = ZSimdVectBinary; + using vinner_type = Vectorized; + using size_type = int; + using vinner_data = typename Vectorized::vinner_data; + + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(value_type); + } + + private: + vinner_type _vec; + + public: + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(const vinner_data& v) + : _vec{v.first, v.second} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2) + : _vec{s1.real(), s1.imag(), s2.real(), s2.imag()} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4) + : _vec{ + s1.real(), + s1.imag(), + s2.real(), + s2.imag(), + s3.real(), + s3.imag(), + s4.real(), + s4.imag()} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s) : Vectorized(s, s) {} + + template = 0> + 
C10_ALWAYS_INLINE Vectorized(T s) : Vectorized(s, s, s, s) {} + + C10_ALWAYS_INLINE operator vinner_type() const { + return _vec; + } + + C10_ALWAYS_INLINE const vinner_type& vec() const { + return _vec; + } + + C10_ALWAYS_INLINE operator vinner_data() const { + return _vec.data(); + } + + C10_ALWAYS_INLINE vinner_data data() const { + return _vec.data(); + } + + template + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + return Vectorized{vinner_type::loadu(ptr, 2 * count)}; + } + + template + void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const { + return _vec.store(ptr, 2 * count); + } + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + vinner_type vmask = mask.vec(); + auto mask_complex = vinner_type( + vec_mergeh(vmask.vec0(), vmask.vec0()), + vec_mergeh(vmask.vec1(), vmask.vec1())); + return Vectorized{vinner_type::blendv(a.vec(), b.vec(), mask_complex)}; + } + + template + static auto C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int mask_complex = maskForComplex(mask); + return Vectorized{ + vinner_type::template blend(a.vec(), b.vec())}; + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized(base, base + step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + value_type(2) * step, + base + value_type(3) * step); + } + + template + static inline std::enable_if_t<(Z >= C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + return b; + } + + template + static inline std::enable_if_t<(Z < C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + if (count == Z) + return blend(a, b); + else + return set_inner(a, b, count); + } + 
+ static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + if (count == 0) + return a; + return set_inner<1, size()>(a, b, count); + } + + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(const T&)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{ + f(T(v0[0], v0[1])), + f(T(v0[2], v0[3])), + f(T(v1[0], v1[1])), + f(T(v1[2], v1[3]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(const T&)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(T)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{ + f(T(v0[0], v0[1])), + f(T(v0[2], v0[3])), + f(T(v1[0], v1[1])), + f(T(v1[2], v1[3]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(T)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + inline Vectorized mapOrdinary( + T (*const f)(const T&, const T&), + const Vectorized& b) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + auto bvec = b.vec(); + auto b0 = bvec.vec0(); + auto b1 = bvec.vec1(); + T a00 = f(T(v0[0], v0[1]), T(b0[0], b0[1])); + T a01 = f(T(v0[2], v0[3]), T(b0[2], b0[3])); + T a02 = f(T(v1[0], v1[1]), T(b1[0], b1[1])); + T a03 = f(T(v1[2], v1[3]), T(b1[2], b1[3])); + return Vectorized{a00, a01, a02, a03}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + inline Vectorized mapOrdinary( + T (*const f)(const T&, const T&), + const 
Vectorized& b) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + auto bvec = b.vec(); + auto b0 = bvec.vec0(); + auto b1 = bvec.vec1(); + U a00 = f(U(v0[0], v0[1]), U(b0[0], b0[1])); + U a01 = f(U(v1[0], v1[1]), U(b1[0], b1[1])); + return Vectorized{a00, a01}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + static typename Vectorized::vinner_type real_neg( + const typename Vectorized::vinner_type& a) { + const auto swap_mask = ZSimdVectBinary{ + 0, 1, 2, 3, 20, 21, 22, 23, 8, 9, 10, 11, 28, 29, 30, 31}; + + auto a_neg = a.neg(); + vtype v0 = vec_perm(a_neg.vec0(), a.vec0(), swap_mask); + vtype v1 = vec_perm(a_neg.vec1(), a.vec1(), swap_mask); + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + static typename Vectorized::vinner_type real_neg( + const typename Vectorized::vinner_type& a) { + auto a_neg = a.neg(); + vtype v0 = {a_neg.vec0()[0], a.vec0()[1]}; + vtype v1 = {a_neg.vec1()[0], a.vec1()[1]}; + return {v0, v1}; + } + + Vectorized angle2_() const { + auto b_a = _vec.swapped(); // b a + return Vectorized{_vec.atan2(b_a).swapped()}; + } + + Vectorized angle() const { + return angle2_().real(); + } + + Vectorized atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vectorized{vinner_type(image_one())}; + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * + Vectorized{vinner_type(image_half())}; // i/2*ln() + } + + Vectorized atanh() const { + return mapOrdinary(std::atanh); + } + + Vectorized asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) +#if 1 + vinner_type cnj = conj().vec(); + vinner_type b_a = cnj.swapped(); + vinner_type ab = cnj * b_a; + vinner_type im = ab + ab; + vinner_type val_2 = _vec * _vec; + vinner_type val_2_swapped = val_2.swapped(); + vinner_type re = 
vinner_type::horizontal_sub_perm(val_2, val_2_swapped); + re = vinner_type(static_cast(1)) - re; + constexpr int blend_mask = + blend_choice(); // 0x0A for complex , 0xAA for complex + vinner_type blendx = vinner_type::template blend(re, im); + auto root = Vectorized(blendx).sqrt(); + auto ln = Vectorized(Vectorized(b_a) + root).log(); + return Vectorized(ln.vec().swapped()).conj(); +#else + return mapOrdinary(std::asin); +#endif + } + + Vectorized acos() const { + // acos(x) = pi/2 - asin(x) + return Vectorized(vinner_type(pi_half())) - asin(); + } + + Vectorized sin() const { + return mapOrdinary(std::sin); + } + Vectorized sinh() const { + return mapOrdinary(std::sinh); + } + Vectorized cos() const { + return mapOrdinary(std::cos); + } + Vectorized cosh() const { + return mapOrdinary(std::cosh); + } + Vectorized ceil() const { + return Vectorized{_vec.ceil()}; + } + Vectorized floor() const { + return Vectorized{_vec.floor()}; + } + Vectorized neg() const { + return Vectorized(_vec.neg()); + } + Vectorized round() const { + return Vectorized{_vec.round()}; + } + Vectorized tan() const { + return mapOrdinary(std::tan); + } + Vectorized tanh() const { + return mapOrdinary(std::tanh); + } + Vectorized trunc() const { + return Vectorized{_vec.trunc()}; + } + + Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const { + auto eq = _vec.eq(other._vec); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + auto real = eq & vinner_type(real_mask()); + auto imag = (eq & vinner_type(image_mask())).swapped(); + return Vectorized{real & imag}; + } + Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const { + auto ne = _vec.ne(other._vec); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + auto real = ne & vinner_type(real_mask()); + auto imag = (ne & vinner_type(image_mask())).swapped(); + 
return Vectorized{real | imag}; + } + + Vectorized real() const { + return Vectorized(_vec & vinner_type(real_mask())); + } + Vectorized imag_() const { + return Vectorized(_vec & vinner_type(image_mask())); + } + Vectorized imag() const { + return Vectorized{ + (_vec & vinner_type(image_mask())).swapped()}; + } + + Vectorized conj() const { + return Vectorized(_vec ^ vinner_type(isign_mask())); + } + + vinner_data abs_2_() const { + auto a = _vec * _vec; + a = a + a.swapped(); + return a.mergee().data(); + } + + static T abs_helper(const T& value) { + return T(std::abs(value)); + } + + Vectorized abs() const { + return mapOrdinary(abs_helper); + } + + Vectorized exp() const { + return mapOrdinary(std::exp); + } + + Vectorized exp2() const { + return mapOrdinary(exp2_impl); + } + + Vectorized expm1() const { + return mapOrdinary(std::expm1); + } + + Vectorized log() const { + return mapOrdinary(std::log); + } + + Vectorized log2() const { + // log2eB_inv + auto ret = log(); + return Vectorized{ret._vec * vinner_type(log2e_inv())}; + } + + Vectorized log10() const { + auto ret = log(); + return Vectorized{ret._vec * vinner_type(log10e_inv())}; + } + + Vectorized log1p() const { + return mapOrdinary(std::log1p); + } + + Vectorized sgn() const { + return mapOrdinary(at::native::sgn_impl); + } + + Vectorized pow(const Vectorized& exp) const { + return mapOrdinary(std::pow, exp); + } + + Vectorized sqrt() const { + return mapOrdinary(std::sqrt); + } + + Vectorized reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + vinner_type c_d = _vec ^ vinner_type(isign_mask()); + vinner_type abs = abs_2_(); + return Vectorized{c_d / abs}; + } + + Vectorized rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized lt(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized le(const Vectorized& other) const { + TORCH_CHECK(false, "not 
supported for complex numbers"); + } + + Vectorized gt(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized ge(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } +}; + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() + b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() - b.vec()}; \ + } \ + \ + template <> \ + Vectorized inline operator*( \ + const Vectorized& a, const Vectorized& b) { \ + /* (a + bi) * (c + di) = (ac - bd) + (ad + bc)i */ \ + Vectorized::vinner_type bv = b.vec(); \ + \ + /* this is more z arch friendly than simulating horizontal from x86 */ \ + Vectorized::vinner_type vi = bv.mergeo(); \ + Vectorized::vinner_type vr = bv.mergee(); \ + vi = vi ^ \ + Vectorized::vinner_type( \ + rsign_mask::underline_type>()); \ + Vectorized::vinner_type ret = a.vec() * vr; \ + Vectorized::vinner_type vx_swapped = a.vec().swapped(); \ + ret = fmadd(vx_swapped, vi, ret); \ + \ + return Vectorized{ret}; \ + } \ + \ + template <> \ + Vectorized inline operator/( \ + const Vectorized& a, const Vectorized& b) { \ + /* Unfortunately, this breaks some tests */ \ + /* Implement it like it's done for avx2 */ \ + auto fabs_cd = b.vec().abs(); /* |c| |d| */ \ + auto fabs_dc = fabs_cd.swapped(); /* |d| |c| */ \ + auto scale = Vectorized::vinner_type{1.0} / \ + maximum(fabs_cd, fabs_dc); /* 1/sc 1/sc */ \ + auto a2 = a.vec() * scale; /* a/sc b/sc */ \ + auto b2 = b.vec() * scale; /* c/sc d/sc */ \ + auto acbd2 = a2 * b2; /* ac/sc^2 bd/sc^2 */ \ + \ + auto dc2 = b2.swapped(); /* d/sc c/sc */ \ + dc2 = Vectorized::real_neg(dc2); /* -d/|c,d| c/sc */ \ + auto adbc2 = a2 * dc2; /* -ad/sc^2 bc/sc^2 */ \ + auto sum1 = acbd2 + acbd2.swapped(); /* (ac+bd)/sc^2 
(ac+bd)/sc^2 */ \ + auto sum2 = adbc2 + adbc2.swapped(); /* (bc-ad)/sc^2 (bc-ad)/sc^2 */ \ + auto res2 = Vectorized::vinner_type::mergee( \ + sum1, sum2); /* (ac+bd)/sc^2 (bc-ad)/sc^2 */ \ + \ + /* get the denominator */ \ + Vectorized::vinner_type denom2 = \ + Vectorized{b2}.abs_2_(); /* (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 */ \ + res2 = res2 / denom2; \ + return Vectorized{res2}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() & b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() | b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() ^ b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator==( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() == b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator!=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() != b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<=( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>=( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } + +ZVECTOR_OPERATORS(c10::complex) +ZVECTOR_OPERATORS(c10::complex) + +#undef ZVECTOR_OPERATORS + +template = 0> +std::pair, Vectorized> inline inner_interleave2( + const Vectorized& a, + 
const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + using vtype = typename Vectorized::vtype; + vtype ab00 = {a.vec0()[0], b.vec0()[0]}; + vtype ab11 = {a.vec0()[1], b.vec0()[1]}; + vtype ab2_00 = {a.vec1()[0], b.vec1()[0]}; + vtype ab2_11 = {a.vec1()[1], b.vec1()[1]}; + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vectorized{ab00, ab11}, Vectorized{ab2_00, ab2_11}); +} + +template = 0> +std::pair, Vectorized> inline inner_deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + using vtype = typename Vectorized::vtype; + vtype aa01 = {a.vec0()[0], a.vec1()[0]}; + vtype aa23 = {b.vec0()[0], b.vec1()[0]}; + + vtype bb_01 = {a.vec0()[1], a.vec1()[1]}; + vtype bb_23 = {b.vec0()[1], b.vec1()[1]}; + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair(Vectorized{aa01, aa23}, Vectorized{bb_01, bb_23}); +} + +template = 0> +std::pair, Vectorized> inline inner_interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3,, a4, a5, a6, a7} + // b = {b0, b1, b2, b3,, b4, b5, b6, b7} + using vtype = typename Vectorized::vtype; + vtype ab0011 = vec_mergeh(a.vec0(), b.vec0()); + vtype ab2233 = vec_mergel(a.vec0(), b.vec0()); + + vtype ab2_0011 = vec_mergeh(a.vec1(), b.vec1()); + vtype ab2_2233 = vec_mergel(a.vec1(), b.vec1()); + // group cols crossing lanes: + // return {a0, b0, a1, b1,, a2, b2, a3, b3} + // {a4, b4, a5, b5,, a6, b6, a7, b7} + + return std::make_pair( + Vectorized{ab0011, ab2233}, Vectorized{ab2_0011, ab2_2233}); +} + +template = 0> +std::pair, Vectorized> inline inner_deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1,, a2, b2, a3, b3} + // b = {a4, b4, a5, b5,, a6, b6, a7, b7} + using vtype = typename Vectorized::vtype; + // {a0,a2,b0,b2} {a1,a3,b1,b3} + vtype a0a2b0b2 = vec_mergeh(a.vec0(), 
a.vec1()); + vtype a1a3b1b3 = vec_mergel(a.vec0(), a.vec1()); + + vtype aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3); + vtype bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3); + + vtype a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1()); + vtype a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1()); + + vtype aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2); + vtype bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2); + + // it could be done with vec_perm ,too + // swap lanes: + // return {a0, a1, a2, a3,, a4, a5, a6, a7} + // {b0, b1, b2, b3,, b4, b5, b6, b7} + + return std::make_pair( + Vectorized{aa0123, aa0123_2}, Vectorized{bb0123, bb0123_2}); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2< + int32_t>(const Vectorized& a, const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2< + int64_t>(const Vectorized& a, const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template +std::enable_if_t< + std::is_same_v, + at::vec::Vectorized< + float>> inline convert_int8_to_float(const Vectorized& src) { + // Note: this function only convert inputs number of 
elements equal to + // at::vec::Vectorized.size() Only handle first 64 bits + auto vec_int = src.to_vec_float_helper(); + + return zvec_convert_to_float(vec_int); +} + +template +std::enable_if_t< + std::is_same_v, + at::vec::Vectorized< + T>> inline convert_float_to_int8(const Vectorized& src) { + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + auto vec_int = clamp( + zvec_convert_to_int(src), + Vectorized(min_val), + Vectorized(max_val)); + + return vec_int.to_vec_uint8_helper(); +} + +#undef DEFINE_CLAMP_MAXMIN_FUNCS +#undef DEFINE_MAXMIN_FUNCS +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h new file mode 100644 index 0000000000000000000000000000000000000000..1ece1de99e6f56fe51a9316443cd6cafc8e3645e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h @@ -0,0 +1,409 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include + +// clang-format off +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// clang-format on + +#include +#include +#include +#include +#include + +namespace at { +namespace vec { + +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { + stream << val.val_; + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { + stream << static_cast(val.val_); + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { + stream << static_cast(val.val_); + return stream; +} + +template +std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { + T buf[Vectorized::size()]; + vec.store(buf); + stream << "vec["; + for (int i = 0; i != Vectorized::size(); i++) { + if (i != 0) { + stream << ", "; + } + stream << buf[i]; + } + stream << "]"; + return stream; +} + +#if defined(CPU_CAPABILITY_AVX512) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castpd_ps(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castps_pd(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castsi512_ps(src); +} + +template <> +inline Vectorized cast( + const Vectorized& src) { + return _mm512_castsi512_pd(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. 
+template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + double>> inline gather(const double* base_addr, const Vectorized& vindex) { + return _mm512_i64gather_pd(vindex, base_addr, scale); +} + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + float>> inline gather(const float* base_addr, const Vectorized& vindex) { + return _mm512_i32gather_ps(vindex, base_addr, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const double* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF)); + auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const float* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF)); + auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return _mm512_cvtpd_epi64(src); +} + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return _mm512_cvttps_epi32(src); +} + +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + return _mm512_cvtepi64_pd(src); +} + +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + return _mm512_cvtepi32_ps(src); +} + +// 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + __m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4); + return std::make_pair( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, + // a15} b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, + // b14, b15} + // + // return: + // {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, + // b15} + __m512i idx1 = + _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + return std::make_pair( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + // The members of indices have been written in binary format for better + // understandability + __m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = 
_mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, + // a15, b15} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, + // a15} + // {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, + // b15} + __m512i idx1 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = + _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_ps(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_pd(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_epi64(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = + _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_epi32(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = _mm512_set_epi16( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, 
+ 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31); + return _mm512_permutexvar_epi16(mask, v); +} + +inline __m512i flip8(const __m512i& v) { + const __m512i mask1 = _mm512_set_epi8( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15); + const __m512i mask2 = _mm512_set_epi64(1, 0, 3, 2, 5, 4, 7, 6); + auto reversed_vec = _mm512_shuffle_epi8(v, mask1); + return _mm512_permutexvar_epi64(mask2, reversed_vec); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m512i* self_ = reinterpret_cast(self.as_bytes()); + const __m512i* other_ = reinterpret_cast(other.as_bytes()); + __m512i out = _mm512_and_si512(*self_, *other_); + Vectorized ret; + // We do not have a constructer that takes __m512i, so we need to memcpy + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + +#endif // defined(CPU_CAPABILITY_AVX512) + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..8ee5ad955662fffe549dbc05000a4d0444c0bc8d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -0,0 +1,1937 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +#ifndef SLEEF_CONST +#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER) +#define SLEEF_CONST const +#else +#define SLEEF_CONST +#endif +#define SLEEF_CONST_OLD SLEEF_CONST +#else +#define SLEEF_CONST_OLD +#endif + +// bfloat16 conversion +static inline void cvtbf16_fp32(const __m256i& a, __m512& o) { + o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); +} + +static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) { + __m256i lo = _mm512_extracti32x8_epi32(a, 0); + __m256i hi = _mm512_extracti32x8_epi32(a, 1); + cvtbf16_fp32(lo, o1); + cvtbf16_fp32(hi, o2); +} + +static inline __m256i cvtfp32_bf16(const __m512& src) { + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +} + +static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + auto 
mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones); + auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_lo = _mm512_add_epi32(t_lo, vec_bias); + t_hi = _mm512_add_epi32(t_hi, vec_bias); + // input += rounding_bias; + t_lo = _mm512_add_epi32(t_lo, lo); + t_hi = _mm512_add_epi32(t_hi, hi); + // input = input >> 16; + t_lo = _mm512_srli_epi32(t_lo, 16); + t_hi = _mm512_srli_epi32(t_hi, 16); + // Check NaN before converting back to bf16 + t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo); + t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi); + + t_lo = _mm512_packus_epi32( + t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4] + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, t_lo); +} + +static inline __m512i merge_compare_result(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + lo = _mm512_srli_epi32(lo, 16); + hi = _mm512_srli_epi32(hi, 16); + auto out = _mm512_packus_epi32(lo, hi); + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, out); +} + +// float16 conversion +static inline void cvtfp16_fp32(const __m256i& a, __m512& o) { + o = _mm512_cvtph_ps(a); +} + +static inline void cvtfp16_fp32(const __m512i& a, __m512& o1, __m512& o2) { + __m256i lo = _mm512_extracti32x8_epi32(a, 0); + __m256i hi = _mm512_extracti32x8_epi32(a, 1); + cvtfp16_fp32(lo, o1); + cvtfp16_fp32(hi, o2); +} + +static inline __m256i cvtfp32_fp16(const __m512& src) { + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +static inline __m512i cvtfp32_fp16(const __m512& a, const __m512& b) { + __m256i lo = + _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | 
_MM_FROUND_NO_EXC)); + __m256i hi = + _mm512_cvtps_ph(b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m512 t_lo = _mm512_castsi512_ps(_mm512_castsi256_si512(lo)); + __m256 t_hi = _mm256_castsi256_ps(hi); + return _mm512_castps_si512(_mm512_insertf32x8(t_lo, t_hi, 1)); +} + +// dtype conversion between float16/bfloat16 and float32 +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m256i& a, __m512& o); +template <> +inline void cvt_to_fp32(const __m256i& a, __m512& o) { + cvtbf16_fp32(a, o); +} +template <> +inline void cvt_to_fp32(const __m256i& a, __m512& o) { + cvtfp16_fp32(a, o); +} + +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2); +template <> +inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2) { + cvtbf16_fp32(a, o1, o2); +} +template <> +inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2) { + cvtfp16_fp32(a, o1, o2); +} + +template < + typename T, + bool is_compare_op = false, + typename std::enable_if_t, int> = 0> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b); +template <> +inline __m512i cvt_from_fp32( + const __m512& a, + const __m512& b) { + return cvtfp32_bf16(a, b); +} +template <> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) { + return merge_compare_result(a, b); +} +template <> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) { + return cvtfp32_fp16(a, b); +} +template <> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) { + return cvtfp32_fp16(a, b); +} + +template +class Vectorized16 { + static_assert( + is_reduced_floating_point_v, + "Support only float16 and bfloat16."); + + private: + __m512i values; + + public: + using value_type = uint16_t; + using size_type = int; + static constexpr size_type size() { + return 32; + } + Vectorized16() {} + Vectorized16(__m512i v) : values(v) {} + Vectorized16(T val) { + 
value_type uw = val.x; + values = _mm512_set1_epi16(uw); + } + Vectorized16( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16, + T val17, + T val18, + T val19, + T val20, + T val21, + T val22, + T val23, + T val24, + T val25, + T val26, + T val27, + T val28, + T val29, + T val30, + T val31, + T val32) { + values = _mm512_set_epi16( + val32.x, + val31.x, + val30.x, + val29.x, + val28.x, + val27.x, + val26.x, + val25.x, + val24.x, + val23.x, + val22.x, + val21.x, + val20.x, + val19.x, + val18.x, + val17.x, + val16.x, + val15.x, + val14.x, + val13.x, + val12.x, + val11.x, + val10.x, + val9.x, + val8.x, + val7.x, + val6.x, + val5.x, + val4.x, + val3.x, + val2.x, + val1.x); + } + operator __m512i() const { + return values; + } + T& operator[](int idx) = delete; + const T& operator[](int idx) const = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0)); + } + static Vectorized loadu(const void* ptr, int16_t count = size()) { + if (count == size()) + return _mm512_loadu_si512(reinterpret_cast(ptr)); + + __mmask32 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi16(mask, ptr); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask32 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi16(ptr, mask, values); + } + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + return _mm512_mask_blend_epi16(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, 
_MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange( + T base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + case 16: + return blend<65535>(a, b); + case 17: + return blend<131071>(a, b); + case 18: + return blend<262143>(a, b); + case 19: + return blend<524287>(a, b); + case 20: + return blend<1048575>(a, b); + case 21: + return blend<2097151>(a, b); + case 22: + return blend<4194303>(a, b); + case 23: + return blend<8388607>(a, b); + case 24: + return blend<16777215>(a, b); + case 25: + return blend<33554431>(a, b); + case 26: + return blend<67108863>(a, b); + case 27: + return 
blend<134217727>(a, b); + case 28: + return blend<268435455>(a, b); + case 29: + return blend<536870911>(a, b); + case 30: + return blend<1073741823>(a, b); + case 31: + return blend<2147483647>(a, b); + } + return b; + } +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wignored-qualifiers" + + Vectorized map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + const auto o1 = vop(lo); + const auto o2 = vop(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized isnan() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + __mmask16 lo_mask, hi_mask; + __m512 zero = _mm512_set1_ps(0.0); + __m512i zeroi = _mm512_castps_si512(zero); + lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q); + lo = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF)); + hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q); + hi = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF)); + return merge_compare_result(lo, hi); + } +#pragma clang diagnostic pop + Vectorized abs() const { + return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values); + } + Vectorized angle() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto angle_lambda = [](__m512 values) { + const auto zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto non_nan_mask_vec = _mm512_mask_set1_epi32( + _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask( + _mm512_castsi512_ps(non_nan_mask_vec), zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + }; + auto o1 = angle_lambda(lo); + auto o2 = angle_lambda(hi); + return 
cvt_from_fp32(o1, o2); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return map(Sleef_acosf16_u10); + } + Vectorized acosh() const { + return map(Sleef_acoshf16_u10); + } + Vectorized asin() const { + return map(Sleef_asinf16_u10); + } + Vectorized asinh() const { + return map(Sleef_asinhf16_u10); + } + Vectorized atan() const { + return map(Sleef_atanf16_u10); + } + Vectorized atanh() const { + return map(Sleef_atanhf16_u10); + } + Vectorized atan2(const Vectorized& b) const { + __m512 lo, hi; + __m512 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_atan2f16_u10(lo, b1); + auto o2 = Sleef_atan2f16_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized copysign(const Vectorized& sign) const { + // copy sign bit (0x8000) from sign and remaining bits from values + __m512i mask_value = _mm512_set1_epi32(~0x80008000); + __m512i mask_signbit = _mm512_set1_epi32(0x80008000); + return Vectorized(_mm512_or_si512( + _mm512_and_si512(values, mask_value), + _mm512_and_si512(sign, mask_signbit))); + } + Vectorized erf() const { + return map(Sleef_erff16_u10); + } + Vectorized erfc() const { + return map(Sleef_erfcf16_u15); + } + Vectorized erfinv() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_erfinv(tmp1[i]); + tmp2[i] = calc_erfinv(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized exp() const { + return map(Sleef_expf16_u10); + } + Vectorized exp2() const { + return map(Sleef_exp2f16_u10); + } + Vectorized expm1() const { + return map(Sleef_expm1f16_u10); + } + Vectorized 
exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + __m512 x_lo, x_hi; + cvt_to_fp32(values, x_lo, x_hi); + __m512 q_lo, q_hi; + cvtbf16_fp32(q.values, q_lo, q_hi); + auto o1 = Sleef_fmodf16(x_lo, q_lo); + auto o2 = Sleef_fmodf16(x_hi, q_hi); + return cvt_from_fp32(o1, o2); + } + Vectorized hypot(const Vectorized& b) const { + __m512 lo, hi; + __m512 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_hypotf16_u05(lo, b1); + auto o2 = Sleef_hypotf16_u05(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_i0(tmp1[i]); + tmp2[i] = calc_i0(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0e() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_i0e(tmp1[i]); + tmp2[i] = calc_i0e(tmp2[i]); + } + const auto o1 = _mm512_loadu_ps(tmp1); + const auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized digamma() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_digamma(tmp1[i]); + tmp2[i] = calc_digamma(tmp2[i]); + } + const auto o1 = _mm512_loadu_ps(tmp1); + const auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, 
o2); + } + Vectorized igamma(const Vectorized& x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + + Vectorized igammac(const Vectorized& x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized log() const { + return map(Sleef_logf16_u10); + } + Vectorized log2() const { + return map(Sleef_log2f16_u10); + } + Vectorized log10() const { + return map(Sleef_log10f16_u10); + } + Vectorized log1p() const { + return map(Sleef_log1pf16_u10); + } + Vectorized sin() const { + return map(Sleef_sinf16_u10); + } + Vectorized sinh() const { + return map(Sleef_sinhf16_u10); + } + Vectorized cos() const { + return map(Sleef_cosf16_u10); + } + Vectorized cosh() const { + return map(Sleef_coshf16_u10); + } + Vectorized ceil() const { + __m512 lo, hi; + 
cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_ceil_ps(lo); + auto o2 = _mm512_ceil_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized floor() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_floor_ps(lo); + auto o2 = _mm512_floor_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized neg() const { + return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000)); + } + Vectorized round() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_roundscale_ps( + lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + auto o2 = _mm512_roundscale_ps( + hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized tan() const { + return map(Sleef_tanf16_u10); + } + Vectorized tanh() const { + return map(Sleef_tanhf16_u10); + } + Vectorized trunc() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = + _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + auto o2 = + _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized lgamma() const { + return map(Sleef_lgammaf16_u10); + } + Vectorized sqrt() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_sqrt_ps(lo); + auto o2 = _mm512_sqrt_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized reciprocal() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, lo); + auto o2 = _mm512_div_ps(ones, hi); + return cvt_from_fp32(o1, o2); + } + Vectorized rsqrt() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo)); + auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi)); + return cvt_from_fp32(o1, o2); + } + Vectorized pow(const Vectorized& b) const { + __m512 lo, hi; + __m512 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_powf16_u10(lo, 
b1); + auto o2 = Sleef_powf16_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + + private: + template + Vectorized inline binary_compare(const VectorizedType& b, Op op) const { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvt_to_fp32(values, a_lo, a_hi); + cvt_to_fp32(b.values, b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); + } + + public: + Vectorized inline operator>(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator<(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator>=(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator<=(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator==(const Vectorized16& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator!=(const Vectorized16& other) const { + return 
binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } +}; + +template +static inline Vectorized binary_op_as_fp32( + const Vectorized& a, + const Vectorized& b, + Op op) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvt_to_fp32(__m512i(a), a_lo, a_hi); + cvt_to_fp32(__m512i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = BFloat16; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} +Vectorized inline operator|( + 
const Vectorized& a, + const Vectorized& b) { + return _mm512_or_si512(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto max_lo = _mm512_max_ps(a_lo, b_lo); + auto max_hi = _mm512_max_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask)); + auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask)); + // Exploit the fact that all-ones is a NaN. 
+ auto o1 = _mm512_or_ps(max_lo, nan_lo); + auto o2 = _mm512_or_ps(max_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512i zero_vec = _mm512_set1_epi32(0); + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto min_lo = _mm512_min_ps(a_lo, b_lo); + auto min_hi = _mm512_min_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF)); + auto nan_hi = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(min_lo, nan_lo); + auto o2 = _mm512_or_ps(min_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo)); + auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi)); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, a_lo); + auto o2 = _mm512_min_ps(max_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m512 a_lo, 
a_hi; + __m512 min_lo, min_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + auto o1 = _mm512_max_ps(min_lo, a_lo); + auto o2 = _mm512_max_ps(min_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); + _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, BFloat16* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = _mm512_loadu_ps(&src[i]); + __m512 b = _mm512_loadu_ps(&src[i + 16]); + + __m512i bf = cvtfp32_bf16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, BFloat16* dst, int64_t n) { + auto load_float = [](const double* src) -> __m512 { + // Load one float vector from an array of doubles + __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src)); + __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8)); + return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = load_float(&src[i]); + __m512 b = load_float(&src[i + 16]); + + __m512i bf = cvtfp32_bf16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512 c_lo, c_hi; + 
cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + cvtbf16_fp32(__m512i(c), c_lo, c_hi); + auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_bf16(o1, o2); +} + +static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) { + __m512i r[8]; + // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15 e0e1 e2e3 e4e5 e6e7 e8e9 + // e10e11 e12e13 e14e15 b0-b15 f0-f15 c0-c15 g0-g15 d0-d15 h0-h15 i0-i15 + // m0-m15 j0-j15 n0-n15 k0-k15 o0-o15 l0-l15 p0-p15 +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 4; i++) { + r[i] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i]), t[i + 4], 0x01); + r[i + 4] = + _mm512_inserti64x4(_mm512_castsi256_si512(t[i + 8]), t[i + 12], 0x01); + } + + // u0: a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11 e0e1 f0f1 e2e3 f2f3 e8e9 + // f8f9 e10e11 f10f11 u1: a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15 e4e5 + // f4f5 e6e7 f6f7 e12e13 f12f13 e14e15 f14f15 u2: c0c1 d0d1 c2c3 d2d3 c8c9 + // d8d9 c10c11 d10d11 g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 g10g11 h10h11 u3: c4c5 + // d4b5 c6c7 d6b7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7 g12g13 + // h12h13 g14g15 h14h15 i j m n k l o p +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 8; i += 2) { + u[i] = _mm512_unpacklo_epi32(r[i], r[i + 1]); + u[i + 1] = _mm512_unpackhi_epi32(r[i], r[i + 1]); + } + + // r0: a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9 e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 + // g8g9 h8h9 r1: a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11 e2e3 f2f3 + // g2g3 h2h3 e10e11 f10f11 g10g11 h10h11 r2: a4a5 b4b5 c4c5 d4b5 a12a13 b12b13 + // c12c13 d12d13 r3: a6a7 b6b7 c6c7 d6b7 a14a15 b14b15 c14c15 d14d15 r4: i j k + // l m n o p + r[0] = _mm512_unpacklo_epi64(u[0], u[2]); + r[1] = _mm512_unpackhi_epi64(u[0], u[2]); + r[2] = _mm512_unpacklo_epi64(u[1], u[3]); + r[3] = _mm512_unpackhi_epi64(u[1], u[3]); + r[4] = _mm512_unpacklo_epi64(u[4], u[6]); + r[5] = _mm512_unpackhi_epi64(u[4], 
u[6]); + r[6] = _mm512_unpacklo_epi64(u[5], u[7]); + r[7] = _mm512_unpackhi_epi64(u[5], u[7]); + + __m512i const1 = _mm512_set_epi32( + 0x00370035, + 0x00330031, + 0x00270025, + 0x00230021, + 0x00170015, + 0x00130011, + 0x00070005, + 0x00030001, + 0x00360034, + 0x00320030, + 0x00260024, + 0x00220020, + 0x00160014, + 0x00120010, + 0x00060004, + 0x00020000); + __m512i const2 = _mm512_set_epi32( + 0x003f003d, + 0x003b0039, + 0x002f002d, + 0x002b0029, + 0x001f001d, + 0x001b0019, + 0x000f000d, + 0x000b0009, + 0x003e003c, + 0x003a0038, + 0x002e002c, + 0x002a0028, + 0x001e001c, + 0x001a0018, + 0x000e000c, + 0x000a0008); + // merge values from two regs + // 0-- 1-- + // 8-- 9-- + // 2-- 3-- + // 10-- 11-- + // 4-- 5-- + // 12-- 13-- + // 6-- 7-- + // 14-- 15-- +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 4; i++) { + u[i] = _mm512_permutex2var_epi16(r[i], const1, r[i + 4]); + u[i + 4] = _mm512_permutex2var_epi16(r[i], const2, r[i + 4]); + } +} + +// TODO(Leslie): Add the AVX2 Version of transpose_mxn for BFloat16 and Float16 +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607 +template <> +inline void transpose_mxn( + const BFloat16* src, + int64_t ld_src, + BFloat16* dst, + int64_t ld_dst) { + __m256i t[16]; + // load from src to registers + // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 + // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 + // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 + // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 + // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15 + // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15 + // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15 + // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15 + // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15 + // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15 + // k: k0 
k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15 + // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15 + // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15 + // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15 + // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15 + // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15 +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; i++) { + t[i] = + _mm256_loadu_si256(reinterpret_cast(src + i * ld_src)); + } + + __m512i u[8]; + _transpose_mxn_half_16_16(t, u); + +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; i++) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x0)); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x01)); + } +} + +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607 +template <> +inline void transpose_mxn( + const Half* src, + int64_t ld_src, + Half* dst, + int64_t ld_dst) { + __m256i t[16]; + // load from src to registers + // Same matrix indices as above transpose_mxn +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; i++) { + t[i] = + _mm256_loadu_si256(reinterpret_cast(src + i * ld_src)); + } + + __m512i u[8]; + _transpose_mxn_half_16_16(t, u); + +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; i++) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x0)); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x01)); + } +} + +static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { + // t[0]: 0 32 1 33 2 34 3 35 8 40 9 41 10 42 11 43 16 ... 59 + // t[1]: 4 36 5 37 6 38 7 39 12 44 13 45 14 46 15 47 20 ... 
63 + // t[2]: 64 96 65 97 66 98 67 99 72 104 73 105 74 106 75 ... 123 + // t[3]: 68 100 69 101 70 102 71 103 76 108 77 109 78 110 79 111 84 ... 127 + // t[4]: 128 160 129 161 130 162 131 163 136 168 137 169 138 170 139 171 144 + // ... 187 t[5]: 132 164 133 165 134 166 135 167 140 172 141 173 142 174 143 + // 175 148 ... 191 t[6]: 192 224 193 225 194 226 195 227 200 232 201 233 202 + // 234 203 235 208 ... 251 t[7]: 196 228 197 229 198 230 199 231 204 236 205 + // 237 206 238 207 239 212 ... 255 t[8]: 256 288 257 289 258 290 259 291 264 + // 296 265 297 266 298 267 299 272 ... 315 t[9]: 260 292 261 293 262 294 263 + // 295 268 300 269 301 270 302 271 303 276 ... 319 t[10]: 320 352 321 353 322 + // 354 323 355 328 360 329 361 330 362 331 363 336 ... 379 t[11]: 324 356 325 + // 357 326 358 327 359 332 364 333 365 334 366 335 367 340 ... 383 t[12]: 384 + // 416 385 417 386 418 387 419 392 424 393 425 394 426 395 427 400 ... 443 + // t[13]: 388 420 389 421 390 422 391 423 396 428 397 429 398 430 399 431 404 + // ... 447 t[14]: 448 480 449 481 450 482 451 483 456 488 457 489 458 490 459 + // 491 464 ... 507 t[15]: 452 484 453 485 454 486 455 487 460 492 461 493 462 + // 494 463 495 468 ... 511 t[16]: 512 544 513 545 514 546 515 547 520 552 521 + // 553 522 554 523 555 528 ... 571 + // ... + // t[31]: 964 996 965 997 966 998 967 999 972 1004 973 1005 974 1006 975 1007 + // 980 ... 1023 +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; ++i) { + d[i * 2] = _mm512_unpacklo_epi16(r[i * 2], r[i * 2 + 1]); + d[i * 2 + 1] = _mm512_unpackhi_epi16(r[i * 2], r[i * 2 + 1]); + } + + // t[0]: 0 32 64 96 1 33 65 97 8 40 72 104 9 41 73 105 16 ... 121 + // t[1]: 2 34 66 98 3 35 67 99 10 42 74 106 11 43 75 107 18 ... 123 + // t[2]: 4 36 68 100 5 37 69 101 12 44 76 108 13 45 77 109 20 ... 125 + // t[3]: 6 38 70 102 7 39 71 103 14 46 78 110 15 47 79 111 22 ... 127 + // t[4]: 128 160 192 224 129 161 193 225 136 168 200 232 137 169 201 233 144 + // ... 
249 t[5]: 130 162 194 226 131 163 195 227 138 170 202 234 139 171 203 + // 235 146 ... 251 t[6]: 132 164 196 228 133 165 197 229 140 172 204 236 141 + // 173 205 237 148 ... 253 t[7]: 134 166 198 230 135 167 199 231 142 174 206 + // 238 143 175 207 239 150 ... 255 t[8]: 256 288 320 352 257 289 321 353 264 + // 296 328 360 265 297 329 361 272 ... 377 t[9]: 258 290 322 354 259 291 323 + // 355 266 298 330 362 267 299 331 363 274 ... 379 t[10]: 260 292 324 356 261 + // 293 325 357 268 300 332 364 269 301 333 365 276 ... 381 t[11]: 262 294 326 + // 358 263 295 327 359 270 302 334 366 271 303 335 367 278 ... 383 t[12]: 384 + // 416 448 480 385 417 449 481 392 424 456 488 393 425 457 489 400 ... 505 + // t[13]: 386 418 450 482 387 419 451 483 394 426 458 490 395 427 459 491 402 + // ... 507 t[14]: 388 420 452 484 389 421 453 485 396 428 460 492 397 429 461 + // 493 404 ... 509 t[15]: 390 422 454 486 391 423 455 487 398 430 462 494 399 + // 431 463 495 406 ... 511 t[16]: 512 544 576 608 513 545 577 609 520 552 584 + // 616 521 553 585 617 528 ... 633 + // ... + // t[31]: 902 934 966 998 903 935 967 999 910 942 974 1006 911 943 975 1007 + // 918 ... 1023 +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; ++i) { + r[i * 4] = _mm512_unpacklo_epi32(d[i * 4], d[i * 4 + 2]); + r[i * 4 + 1] = _mm512_unpackhi_epi32(d[i * 4], d[i * 4 + 2]); + r[i * 4 + 2] = _mm512_unpacklo_epi32(d[i * 4 + 1], d[i * 4 + 3]); + r[i * 4 + 3] = _mm512_unpackhi_epi32(d[i * 4 + 1], d[i * 4 + 3]); + } + + // t[0]: 0 32 64 96 128 160 192 224 8 40 72 104 136 168 200 232 16 ... 248 + // t[1]: 1 33 65 97 129 161 193 225 9 41 73 105 137 169 201 233 17 ... 249 + // t[2]: 2 34 66 98 130 162 194 226 10 42 74 106 138 170 202 234 18 ... 250 + // t[3]: 3 35 67 99 131 163 195 227 11 43 75 107 139 171 203 235 19 ... 251 + // t[4]: 4 36 68 100 132 164 196 228 12 44 76 108 140 172 204 236 20 ... 252 + // t[5]: 5 37 69 101 133 165 197 229 13 45 77 109 141 173 205 237 21 ... 
253 + // t[6]: 6 38 70 102 134 166 198 230 14 46 78 110 142 174 206 238 22 ... 254 + // t[7]: 7 39 71 103 135 167 199 231 15 47 79 111 143 175 207 239 23 ... 255 + // t[8]: 256 288 320 352 384 416 448 480 264 296 328 360 392 424 456 488 272 + // ... 504 t[9]: 257 289 321 353 385 417 449 481 265 297 329 361 393 425 457 + // 489 273 ... 505 t[10]: 258 290 322 354 386 418 450 482 266 298 330 362 394 + // 426 458 490 274 ... 506 t[11]: 259 291 323 355 387 419 451 483 267 299 331 + // 363 395 427 459 491 275 ... 507 t[12]: 260 292 324 356 388 420 452 484 268 + // 300 332 364 396 428 460 492 276 ... 508 t[13]: 261 293 325 357 389 421 453 + // 485 269 301 333 365 397 429 461 493 277 ... 509 t[14]: 262 294 326 358 390 + // 422 454 486 270 302 334 366 398 430 462 494 278 ... 510 t[15]: 263 295 327 + // 359 391 423 455 487 271 303 335 367 399 431 463 495 279 ... 511 t[16]: 512 + // 544 576 608 640 672 704 736 520 552 584 616 648 680 712 744 528 ... 760 + // ... + // t[31]: 775 807 839 871 903 935 967 999 783 815 847 879 911 943 975 1007 791 + // ... 1023 +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 4; ++i) { + d[i * 8] = _mm512_unpacklo_epi64(r[i * 8], r[i * 8 + 4]); + d[i * 8 + 1] = _mm512_unpackhi_epi64(r[i * 8], r[i * 8 + 4]); + d[i * 8 + 2] = _mm512_unpacklo_epi64(r[i * 8 + 1], r[i * 8 + 5]); + d[i * 8 + 3] = _mm512_unpackhi_epi64(r[i * 8 + 1], r[i * 8 + 5]); + d[i * 8 + 4] = _mm512_unpacklo_epi64(r[i * 8 + 2], r[i * 8 + 6]); + d[i * 8 + 5] = _mm512_unpackhi_epi64(r[i * 8 + 2], r[i * 8 + 6]); + d[i * 8 + 6] = _mm512_unpacklo_epi64(r[i * 8 + 3], r[i * 8 + 7]); + d[i * 8 + 7] = _mm512_unpackhi_epi64(r[i * 8 + 3], r[i * 8 + 7]); + } + + // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 16 ... 496 + // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 17 ... 497 + // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 18 ... 498 + // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 19 ... 
499 + // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 20 ... + // 500 t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 21 + // ... 501 t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 + // 22 ... 502 t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 + // 487 23 ... 503 t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 + // 456 488 24 ... 504 t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 + // 425 457 489 25 ... 505 t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 + // 394 426 458 490 26 ... 506 t[11]: 11 43 75 107 139 171 203 235 267 299 331 + // 363 395 427 459 491 27 ... 507 t[12]: 12 44 76 108 140 172 204 236 268 300 + // 332 364 396 428 460 492 28 ... 508 t[13]: 13 45 77 109 141 173 205 237 269 + // 301 333 365 397 429 461 493 29 ... 509 t[14]: 14 46 78 110 142 174 206 238 + // 270 302 334 366 398 430 462 494 30 ... 510 t[15]: 15 47 79 111 143 175 207 + // 239 271 303 335 367 399 431 463 495 31 ... 511 t[16]: 512 544 576 608 640 + // 672 704 736 768 800 832 864 896 928 960 992 528 ... 1008 + // ... + // t[31]: 527 559 591 623 655 687 719 751 783 815 847 879 911 943 975 1007 543 + // ... 
1023 + __m512i const1 = _mm512_set_epi64( + 0x000000000000000d, + 0x000000000000000c, + 0x0000000000000005, + 0x0000000000000004, + 0x0000000000000009, + 0x0000000000000008, + 0x0000000000000001, + 0x0000000000000000); + __m512i const2 = _mm512_set_epi64( + 0x000000000000000f, + 0x000000000000000e, + 0x0000000000000007, + 0x0000000000000006, + 0x000000000000000b, + 0x000000000000000a, + 0x0000000000000003, + 0x0000000000000002); +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; ++i) { + r[i] = _mm512_permutex2var_epi64(d[i], /*idx*/ const1, d[i + 8]); + r[i + 8] = _mm512_permutex2var_epi64(d[i], /*idx*/ const2, d[i + 8]); + r[i + 16] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/ const1, d[i + 24]); + r[i + 24] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/ const2, d[i + 24]); + } + + // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 512 544 + // ... 992 t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 + // 513 545 ... 993 t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 + // 450 482 514 546 ... 994 t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 + // 387 419 451 483 515 547 ... 995 t[4]: 4 36 68 100 132 164 196 228 260 292 + // 324 356 388 420 452 484 516 548 ... 996 t[5]: 5 37 69 101 133 165 197 229 + // 261 293 325 357 389 421 453 485 517 549 ... 997 t[6]: 6 38 70 102 134 166 + // 198 230 262 294 326 358 390 422 454 486 518 550 ... 998 t[7]: 7 39 71 103 + // 135 167 199 231 263 295 327 359 391 423 455 487 519 551 ... 999 t[8]: 8 40 + // 72 104 136 168 200 232 264 296 328 360 392 424 456 488 520 552 ... 1000 + // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 521 553 + // ... 1001 t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 + // 490 522 554 ... 1002 t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 + // 395 427 459 491 523 555 ... 1003 t[12]: 12 44 76 108 140 172 204 236 268 + // 300 332 364 396 428 460 492 524 556 ... 
1004 t[13]: 13 45 77 109 141 173 + // 205 237 269 301 333 365 397 429 461 493 525 557 ... 1005 t[14]: 14 46 78 + // 110 142 174 206 238 270 302 334 366 398 430 462 494 526 558 ... 1006 t[15]: + // 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 527 559 ... + // 1007 t[16]: 16 48 80 112 144 176 208 240 272 304 336 368 400 432 464 496 + // 528 560 ... 1008 + // ... + // t[31]: 31 63 95 127 159 191 223 255 287 319 351 383 415 447 479 511 543 575 + // ... 1023 + __m512i const3 = _mm512_set_epi64( + 0x000000000000000b, + 0x000000000000000a, + 0x0000000000000009, + 0x0000000000000008, + 0x0000000000000003, + 0x0000000000000002, + 0x0000000000000001, + 0x0000000000000000); + __m512i const4 = _mm512_set_epi64( + 0x000000000000000f, + 0x000000000000000e, + 0x000000000000000d, + 0x000000000000000c, + 0x0000000000000007, + 0x0000000000000006, + 0x0000000000000005, + 0x0000000000000004); +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; ++i) { + d[i] = _mm512_permutex2var_epi64(r[i], /*idx*/ const3, r[i + 16]); + d[i + 16] = _mm512_permutex2var_epi64(r[i], /*idx*/ const4, r[i + 16]); + } +} + +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6 +template <> +inline void transpose_mxn( + const BFloat16* src, + int64_t ld_src, + BFloat16* dst, + int64_t ld_dst, + int M, + int N) { + // load from src + TORCH_CHECK( + M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); + __m512i r[32]; + int i; + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } + } + for (; i < 32; ++i) { + r[i] = _mm512_setzero_si512(); + } + + __m512i d[32]; + _transpose_mxn_half_32_32(r, d); + + // store to dst + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * 
ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } + } +} + +template < + typename T, + int M, + int N, + typename std::enable_if_t< + std::is_same_v && + ((M <= 32 && M != 16) || (N <= 32 && N != 16)), + int> = 0> +inline void transpose_mxn( + const BFloat16* src, + int64_t ld_src, + BFloat16* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +template <> +inline void transpose_mxn( + const Half* src, + int64_t ld_src, + Half* dst, + int64_t ld_dst, + int M, + int N) { + TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); + // load from src + __m512i r[32]; + int i; + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } + } + for (; i < 32; ++i) { + r[i] = _mm512_setzero_si512(); + } + + __m512i d[32]; + _transpose_mxn_half_32_32(r, d); + + // store to dst + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } + } +} + +template < + typename T, + int M, + int N, + typename std::enable_if_t< + std::is_same_v && + ((M <= 32 && M != 16) || (N <= 32 && N != 16)), + int> = 0> +inline void transpose_mxn( + const Half* src, + int64_t ld_src, + Half* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = Half; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + 
Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} + +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_si512(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & 
Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(b), b_lo, b_hi); + auto max_lo = _mm512_max_ps(a_lo, b_lo); + auto max_hi = _mm512_max_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask)); + auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(max_lo, nan_lo); + auto o2 = _mm512_or_ps(max_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512i zero_vec = _mm512_set1_epi32(0); + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(b), b_lo, b_hi); + auto min_lo = _mm512_min_ps(a_lo, b_lo); + auto min_hi = _mm512_min_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF)); + auto nan_hi = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. 
+ auto o1 = _mm512_or_ps(min_lo, nan_lo); + auto o2 = _mm512_or_ps(min_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + __m512 max_lo, max_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(min), min_lo, min_hi); + cvtfp16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo)); + auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi)); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 max_lo, max_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, a_lo); + auto o2 = _mm512_min_ps(max_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(min), min_lo, min_hi); + auto o1 = _mm512_max_ps(min_lo, a_lo); + auto o2 = _mm512_max_ps(min_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +inline void convert(const Half* src, Half* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); + _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, Half* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = _mm512_loadu_ps(&src[i]); + __m512 b = _mm512_loadu_ps(&src[i + 16]); + + 
__m512i bf = cvtfp32_fp16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, Half* dst, int64_t n) { + auto load_float = [](const double* src) -> __m512 { + // Load one float vector from an array of doubles + __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src)); + __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8)); + return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = load_float(&src[i]); + __m512 b = load_float(&src[i + 16]); + + __m512i bf = cvtfp32_fp16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512 c_lo, c_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(b), b_lo, b_hi); + cvtfp16_fp32(__m512i(c), c_lo, c_hi); + auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_fp16(o1, o2); +} + +#define CONVERT_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + __m512 o1, o2; \ + cvt_to_fp32(__m512i(a), o1, o2); \ + return std::make_tuple(o1, o2); \ + } \ + \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + return cvt_from_fp32(__m512(a), __m512(b)); \ + } +CONVERT_VECTORIZED_INIT(BFloat16, bfloat16) +CONVERT_VECTORIZED_INIT(Half, half) + +#else // defined(CPU_CAPABILITY_AVX512) + +#define CONVERT_NON_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + 
__at_align__ type arr2[K]; \ + a.store(arr2); \ + for (const auto k : c10::irange(K)) { \ + arr[k] = c10::convert(arr2[k]); \ + } \ + return std::make_tuple( \ + Vectorized::loadu(arr), \ + Vectorized::loadu(arr + Vectorized::size())); \ + } \ + \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr); \ + b.store(arr + Vectorized::size()); \ + for (const auto k : c10::irange(K)) { \ + arr2[k] = c10::convert(arr[k]); \ + } \ + return Vectorized::loadu(arr2); \ + } +CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16) +CONVERT_NON_VECTORIZED_INIT(Half, half) + +#endif // defined(CPU_CAPABILITY_AVX512) + +#if defined(CPU_CAPABILITY_AVX512) +#define LOAD_FP32_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + auto values = _mm256_loadu_si256(reinterpret_cast(data)); \ + __m512 out_values; \ + cvt_to_fp32(values, out_values); \ + out = out_values; \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + auto vec = Vectorized::loadu(data); \ + __m512 out1_values, out2_values; \ + cvt_to_fp32(vec, out1_values, out2_values); \ + out1 = out1_values; \ + out2 = out2_values; \ + } +LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16) +LOAD_FP32_VECTORIZED_INIT(Half, fp16) + +#else // defined(CPU_CAPABILITY_AVX512) +#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + __at_align__ float values[Vectorized::size()]; \ + for (const auto k : c10::irange(Vectorized::size())) { \ + values[k] = data[k]; \ + } \ + out = Vectorized::loadu(values); \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + load_fp32_from_##name(data, out1); \ + data += Vectorized::size(); \ + 
load_fp32_from_##name(data, out2); \ + } +LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16) +LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16) + +#endif +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h new file mode 100644 index 0000000000000000000000000000000000000000..90c2c0f5bd7b707c2bd1ad4d4d93fe7006c1d4c5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h @@ -0,0 +1,654 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m512d values; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(__m512d v) : values(v) {} + Vectorized(c10::complex val) { + double real_value = val.real(); + double imag_value = val.imag(); + values = _mm512_setr_pd( + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value); + } + Vectorized( + c10::complex val1, + c10::complex val2, + c10::complex val3, + c10::complex val4) { + values = _mm512_setr_pd( + val1.real(), + val1.imag(), + val2.real(), + val2.imag(), + val3.real(), + val3.imag(), + val4.real(), + val4.imag()); + } + operator __m512d() const { + return values; + } + template + static Vectorized> blend( + const 
Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + // NOLINTNEXTLINE(clang-diagnostic-warning) + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_pd( + 0x03, a.values, b.values); // b0000 0001 = b0000 0011 + case 2: + return _mm512_mask_blend_pd( + 0x0C, a.values, b.values); // b0000 0010 = b0000 1100 + case 3: + return _mm512_mask_blend_pd( + 0x0F, a.values, b.values); // b0000 0011 = b0000 1111 + case 4: + return _mm512_mask_blend_pd( + 0x30, a.values, b.values); // b0000 0100 = b0011 0000 + case 5: + return _mm512_mask_blend_pd( + 0x33, a.values, b.values); // b0000 0101 = b0011 0011 + case 6: + return _mm512_mask_blend_pd( + 0x3C, a.values, b.values); // b0000 0110 = b0011 1100 + case 7: + return _mm512_mask_blend_pd( + 0x3F, a.values, b.values); // b0000 0111 = b0011 1111 + case 8: + return _mm512_mask_blend_pd( + 0xC0, a.values, b.values); // b0000 1000 = b1100 0000 + case 9: + return _mm512_mask_blend_pd( + 0xC3, a.values, b.values); // b0000 1001 = b1100 0011 + case 10: + return _mm512_mask_blend_pd( + 0xCC, a.values, b.values); // b0000 1010 = b1100 1100 + case 11: + return _mm512_mask_blend_pd( + 0xCF, a.values, b.values); // b0000 1011 = b1100 1111 + case 12: + return _mm512_mask_blend_pd( + 0xF0, a.values, b.values); // b0000 1100 = b1111 0000 + case 13: + return _mm512_mask_blend_pd( + 0xF3, a.values, b.values); // b0000 1101 = b1111 0011 + case 14: + return _mm512_mask_blend_pd( + 0xFC, a.values, b.values); // b0000 1110 = b1111 1100 + case 15: + return _mm512_mask_blend_pd( + 0xFF, a.values, b.values); // b0000 1111 = b1111 1111 + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values); + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = 
_mm512_cmp_epi64_mask( + _mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>( + base, + base + c10::complex(1) * step, + base + c10::complex(2) * step, + base + c10::complex(3) * step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[2 * size()]; + _mm512_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. + static inline __m512d hadd_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_add_pd( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + static inline __m512d hsub_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_sub_pd( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + __m512d abs_2_() const { + auto val_2 = _mm512_mul_pd(values, values); // a*a b*b + return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b + } + __m512d abs_() const { + auto real = _mm512_movedup_pd(values); // real real + // movehdup_pd does not exist... 
+ auto imag = _mm512_permute_pd(values, 0xff); // imag imag + return Sleef_hypotd8_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm512_and_pd(abs_(), real_mask); // abs 0 + } + __m512d angle_() const { + // angle = atan2(b/a) + auto b_a = _mm512_permute_pd(values, 0x55); // b a + return Sleef_atan2d8_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + auto angle = _mm512_permute_pd(angle_(), 0x55); // angle 90-angle + return _mm512_and_pd(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_pd(); + auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ); + auto div = _mm512_div_pd(values, abs); + return _mm512_mask_blend_pd(mask, div, zero); + } + __m512d real_() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm512_and_pd(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512d imag_() const { + const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF)); + return _mm512_and_pd(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_pd(imag_(), 
0x55); // b a + } + __m512d conj_() const { + const __m512d sign_mask = + _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm512_xor_pd(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + Vectorized> log2() const { + const __m512d log2_ = _mm512_set1_pd(std::log(2)); + return _mm512_div_pd(log(), log2_); + } + Vectorized> log10() const { + const __m512d log10_ = _mm512_set1_pd(std::log(10)); + return _mm512_div_pd(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m512d one = _mm512_set1_pd(1); + + // auto conj = conj_(); + // auto b_a = _mm512_permute_pd(conj, 0x55); //-b a + // auto ab = _mm512_mul_pd(conj, b_a); //-ab + // -ab auto im = _mm512_add_pd(ab, ab); //-2ab -2ab + + // auto val_2 = _mm512_mul_pd(values, values); // a*a + // b*b auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55)); // a*a-b*b + // b*b-a*a re = _mm512_sub_pd(one, re); + + // auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt(); + // //sqrt(re + i*im) auto ln = Vectorized(_mm512_add_pd(b_a, root)).log(); + // //ln(iz + sqrt()) return Vectorized(_mm512_permute_pd(ln.values, + // 0x55)).conj(); //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + // acos(x) = pi/2 - asin(x) + constexpr auto pi_2d = c10::pi / 2; + const __m512d pi_2 = + _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0); + return _mm512_sub_pd(pi_2, asin()); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return 
map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expd8_u10(values); //exp(a) exp(b) exp = + // _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55)); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosd8_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm512_mask_blend_pd(0xAA, + // _mm512_permute_pd(sin_cos.y, 0x55), + // sin_cos.x); //cos(b) + // sin(b) + // return _mm512_mul_pd(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m512d ln_2 = _mm512_set1_pd(c10::ln_2); + Vectorized> scaled_values = + _mm512_mul_pd(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized> floor() const { + return _mm512_floor_pd(values); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_pd(); + return _mm512_sub_pd(zero, values); + } + Vectorized> round() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ 
c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator!=( + const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator<( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_add_pd(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_sub_pd(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512d sign_mask = + _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto 
ac_bd = _mm512_mul_pd(a, b); // ac bd + + auto d_c = _mm512_permute_pd(b, 0x55); // d c + d_c = _mm512_xor_pd(sign_mask, d_c); // d -c + auto ad_bc = _mm512_mul_pd(a, d_c); // ad -bc + + auto ret = Vectorized>::hsub_pd( + ac_bd, ad_bc); // ac - bd ad + bc + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm512_set1_pd(-0.f); + // auto fabs_cd = _mm512_andnot_pd(mask, b); // |c| |d| + // auto fabs_dc = _mm512_permute_pd(fabs_cd, 0x55); // |d| |c| + // auto scale = _mm512_rcp14_pd(_mm512_max_pd(fabs_cd, fabs_dc)); // 1/sc + // 1/sc auto a2 = _mm512_mul_pd(a, scale); // a/sc b/sc auto b2 = + // _mm512_mul_pd(b, scale); // c/sc d/sc auto acbd2 = + // _mm512_mul_pd(a2, b2); + + // const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0); auto dc2 = _mm512_permute_pd(b2, 0x55); // d/sc c/sc + // dc2 = _mm512_xor_pd(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm512_mul_pd(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = Vectorized>::hadd_pd(acbd2, adbc2); + // //(ac+bd)/sc^2 (bc-ad)/sc^2 + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm512_div_pd(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm512_loadu_pd(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. 
+inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0); auto c_d = _mm512_xor_pd(sign_mask, values); //c -d + // return _mm512_div_pd(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5); + + // auto sum = Vectorized(_mm512_add_pd(i, values)); // a + // 1+b auto sub = Vectorized(_mm512_sub_pd(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. 
+ auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(max, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(min, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_pd(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm512_set1_pd(1.0)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm512_set1_pd(1.0)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h new file mode 100644 index 0000000000000000000000000000000000000000..604a89535821f604672b1e8ec55ee6fba2838bdb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h @@ -0,0 +1,1222 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m512 values; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m512 v) : values(v) {} + Vectorized(c10::complex val) { + float real_value = val.real(); + float imag_value = val.imag(); + values = _mm512_setr_ps( + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value); + } + Vectorized( + c10::complex val1, + c10::complex val2, + c10::complex val3, + c10::complex val4, + c10::complex val5, + c10::complex val6, + c10::complex val7, + c10::complex val8) { + values = _mm512_setr_ps( + val1.real(), + val1.imag(), + val2.real(), + val2.imag(), + val3.real(), + val3.imag(), + val4.real(), + val4.imag(), + val5.real(), + val5.imag(), + val6.real(), + val6.imag(), + val7.real(), + val7.imag(), + val8.real(), + val8.imag()); + } + operator __m512() const { + return values; + } + template + static Vectorized> blend( + const 
Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + static_assert(mask > -1 && mask < 256, "Unexpected mask value"); + // The compiler would hopefully convert this switch condition + // into a jump table + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_ps(0x03, a.values, b.values); + case 2: + return _mm512_mask_blend_ps(0x0C, a.values, b.values); + case 3: + return _mm512_mask_blend_ps(0x0F, a.values, b.values); + case 4: + return _mm512_mask_blend_ps(0x30, a.values, b.values); + case 5: + return _mm512_mask_blend_ps(0x33, a.values, b.values); + case 6: + return _mm512_mask_blend_ps(0x3C, a.values, b.values); + case 7: + return _mm512_mask_blend_ps(0x3F, a.values, b.values); + case 8: + return _mm512_mask_blend_ps(0xC0, a.values, b.values); + case 9: + return _mm512_mask_blend_ps(0xC3, a.values, b.values); + case 10: + return _mm512_mask_blend_ps(0xCC, a.values, b.values); + case 11: + return _mm512_mask_blend_ps(0xCF, a.values, b.values); + case 12: + return _mm512_mask_blend_ps(0xF0, a.values, b.values); + case 13: + return _mm512_mask_blend_ps(0xF3, a.values, b.values); + case 14: + return _mm512_mask_blend_ps(0xFC, a.values, b.values); + case 15: + return _mm512_mask_blend_ps(0xFF, a.values, b.values); + case 16: + return _mm512_mask_blend_ps(0x300, a.values, b.values); + case 17: + return _mm512_mask_blend_ps(0x303, a.values, b.values); + case 18: + return _mm512_mask_blend_ps(0x30C, a.values, b.values); + case 19: + return _mm512_mask_blend_ps(0x30F, a.values, b.values); + case 20: + return _mm512_mask_blend_ps(0x330, a.values, b.values); + case 21: + return _mm512_mask_blend_ps(0x333, a.values, b.values); + case 22: + return _mm512_mask_blend_ps(0x33C, a.values, b.values); + case 23: + return _mm512_mask_blend_ps(0x33F, a.values, b.values); + case 24: + return _mm512_mask_blend_ps(0x3C0, a.values, b.values); + case 25: + return _mm512_mask_blend_ps(0x3C3, a.values, 
b.values); + case 26: + return _mm512_mask_blend_ps(0x3CC, a.values, b.values); + case 27: + return _mm512_mask_blend_ps(0x3CF, a.values, b.values); + case 28: + return _mm512_mask_blend_ps(0x3F0, a.values, b.values); + case 29: + return _mm512_mask_blend_ps(0x3F3, a.values, b.values); + case 30: + return _mm512_mask_blend_ps(0x3FC, a.values, b.values); + case 31: + return _mm512_mask_blend_ps(0x3FF, a.values, b.values); + case 32: + return _mm512_mask_blend_ps(0xC00, a.values, b.values); + case 33: + return _mm512_mask_blend_ps(0xC03, a.values, b.values); + case 34: + return _mm512_mask_blend_ps(0xC0C, a.values, b.values); + case 35: + return _mm512_mask_blend_ps(0xC0F, a.values, b.values); + case 36: + return _mm512_mask_blend_ps(0xC30, a.values, b.values); + case 37: + return _mm512_mask_blend_ps(0xC33, a.values, b.values); + case 38: + return _mm512_mask_blend_ps(0xC3C, a.values, b.values); + case 39: + return _mm512_mask_blend_ps(0xC3F, a.values, b.values); + case 40: + return _mm512_mask_blend_ps(0xCC0, a.values, b.values); + case 41: + return _mm512_mask_blend_ps(0xCC3, a.values, b.values); + case 42: + return _mm512_mask_blend_ps(0xCCC, a.values, b.values); + case 43: + return _mm512_mask_blend_ps(0xCCF, a.values, b.values); + case 44: + return _mm512_mask_blend_ps(0xCF0, a.values, b.values); + case 45: + return _mm512_mask_blend_ps(0xCF3, a.values, b.values); + case 46: + return _mm512_mask_blend_ps(0xCFC, a.values, b.values); + case 47: + return _mm512_mask_blend_ps(0xCFF, a.values, b.values); + case 48: + return _mm512_mask_blend_ps(0xF00, a.values, b.values); + case 49: + return _mm512_mask_blend_ps(0xF03, a.values, b.values); + case 50: + return _mm512_mask_blend_ps(0xF0C, a.values, b.values); + case 51: + return _mm512_mask_blend_ps(0xF0F, a.values, b.values); + case 52: + return _mm512_mask_blend_ps(0xF30, a.values, b.values); + case 53: + return _mm512_mask_blend_ps(0xF33, a.values, b.values); + case 54: + return _mm512_mask_blend_ps(0xF3C, 
a.values, b.values); + case 55: + return _mm512_mask_blend_ps(0xF3F, a.values, b.values); + case 56: + return _mm512_mask_blend_ps(0xFC0, a.values, b.values); + case 57: + return _mm512_mask_blend_ps(0xFC3, a.values, b.values); + case 58: + return _mm512_mask_blend_ps(0xFCC, a.values, b.values); + case 59: + return _mm512_mask_blend_ps(0xFCF, a.values, b.values); + case 60: + return _mm512_mask_blend_ps(0xFF0, a.values, b.values); + case 61: + return _mm512_mask_blend_ps(0xFF3, a.values, b.values); + case 62: + return _mm512_mask_blend_ps(0xFFC, a.values, b.values); + case 63: + return _mm512_mask_blend_ps(0xFFF, a.values, b.values); + case 64: + return _mm512_mask_blend_ps(0x3000, a.values, b.values); + case 65: + return _mm512_mask_blend_ps(0x3003, a.values, b.values); + case 66: + return _mm512_mask_blend_ps(0x300C, a.values, b.values); + case 67: + return _mm512_mask_blend_ps(0x300F, a.values, b.values); + case 68: + return _mm512_mask_blend_ps(0x3030, a.values, b.values); + case 69: + return _mm512_mask_blend_ps(0x3033, a.values, b.values); + case 70: + return _mm512_mask_blend_ps(0x303C, a.values, b.values); + case 71: + return _mm512_mask_blend_ps(0x303F, a.values, b.values); + case 72: + return _mm512_mask_blend_ps(0x30C0, a.values, b.values); + case 73: + return _mm512_mask_blend_ps(0X30C3, a.values, b.values); + case 74: + return _mm512_mask_blend_ps(0x30CC, a.values, b.values); + case 75: + return _mm512_mask_blend_ps(0x30CF, a.values, b.values); + case 76: + return _mm512_mask_blend_ps(0x30F0, a.values, b.values); + case 77: + return _mm512_mask_blend_ps(0x30F3, a.values, b.values); + case 78: + return _mm512_mask_blend_ps(0x30FC, a.values, b.values); + case 79: + return _mm512_mask_blend_ps(0x30FF, a.values, b.values); + case 80: + return _mm512_mask_blend_ps(0x3300, a.values, b.values); + case 81: + return _mm512_mask_blend_ps(0X3303, a.values, b.values); + case 82: + return _mm512_mask_blend_ps(0x330C, a.values, b.values); + case 83: + return 
_mm512_mask_blend_ps(0x330F, a.values, b.values); + case 84: + return _mm512_mask_blend_ps(0x3330, a.values, b.values); + case 85: + return _mm512_mask_blend_ps(0x3333, a.values, b.values); + case 86: + return _mm512_mask_blend_ps(0x333C, a.values, b.values); + case 87: + return _mm512_mask_blend_ps(0X333F, a.values, b.values); + case 88: + return _mm512_mask_blend_ps(0x33C0, a.values, b.values); + case 89: + return _mm512_mask_blend_ps(0x33C3, a.values, b.values); + case 90: + return _mm512_mask_blend_ps(0x33CC, a.values, b.values); + case 91: + return _mm512_mask_blend_ps(0x33CF, a.values, b.values); + case 92: + return _mm512_mask_blend_ps(0x33F0, a.values, b.values); + case 93: + return _mm512_mask_blend_ps(0x33F3, a.values, b.values); + case 94: + return _mm512_mask_blend_ps(0x33FC, a.values, b.values); + case 95: + return _mm512_mask_blend_ps(0x33FF, a.values, b.values); + case 96: + return _mm512_mask_blend_ps(0X3C00, a.values, b.values); + case 97: + return _mm512_mask_blend_ps(0x3C03, a.values, b.values); + case 98: + return _mm512_mask_blend_ps(0x3C0C, a.values, b.values); + case 99: + return _mm512_mask_blend_ps(0x3C0F, a.values, b.values); + case 100: + return _mm512_mask_blend_ps(0x3C30, a.values, b.values); + case 101: + return _mm512_mask_blend_ps(0x3C33, a.values, b.values); + case 102: + return _mm512_mask_blend_ps(0x3C3C, a.values, b.values); + case 103: + return _mm512_mask_blend_ps(0x3C3F, a.values, b.values); + case 104: + return _mm512_mask_blend_ps(0x3CC0, a.values, b.values); + case 105: + return _mm512_mask_blend_ps(0x3CC3, a.values, b.values); + case 106: + return _mm512_mask_blend_ps(0x3CCC, a.values, b.values); + case 107: + return _mm512_mask_blend_ps(0x3CCF, a.values, b.values); + case 108: + return _mm512_mask_blend_ps(0x3CF0, a.values, b.values); + case 109: + return _mm512_mask_blend_ps(0x3CF3, a.values, b.values); + case 110: + return _mm512_mask_blend_ps(0x3CFC, a.values, b.values); + case 111: + return 
_mm512_mask_blend_ps(0x3CFF, a.values, b.values); + case 112: + return _mm512_mask_blend_ps(0x3F00, a.values, b.values); + case 113: + return _mm512_mask_blend_ps(0x3F03, a.values, b.values); + case 114: + return _mm512_mask_blend_ps(0x3F0C, a.values, b.values); + case 115: + return _mm512_mask_blend_ps(0x3F0F, a.values, b.values); + case 116: + return _mm512_mask_blend_ps(0x3F30, a.values, b.values); + case 117: + return _mm512_mask_blend_ps(0x3F33, a.values, b.values); + case 118: + return _mm512_mask_blend_ps(0x3F3C, a.values, b.values); + case 119: + return _mm512_mask_blend_ps(0x3F3F, a.values, b.values); + case 120: + return _mm512_mask_blend_ps(0x3FC0, a.values, b.values); + case 121: + return _mm512_mask_blend_ps(0x3FC3, a.values, b.values); + case 122: + return _mm512_mask_blend_ps(0x3FCC, a.values, b.values); + case 123: + return _mm512_mask_blend_ps(0x3FCF, a.values, b.values); + case 124: + return _mm512_mask_blend_ps(0x3FF0, a.values, b.values); + case 125: + return _mm512_mask_blend_ps(0x3FF3, a.values, b.values); + case 126: + return _mm512_mask_blend_ps(0x3FFC, a.values, b.values); + case 127: + return _mm512_mask_blend_ps(0x3FFF, a.values, b.values); + case 128: + return _mm512_mask_blend_ps(0xC000, a.values, b.values); + case 129: + return _mm512_mask_blend_ps(0xC003, a.values, b.values); + case 130: + return _mm512_mask_blend_ps(0xC00C, a.values, b.values); + case 131: + return _mm512_mask_blend_ps(0xC00F, a.values, b.values); + case 132: + return _mm512_mask_blend_ps(0xC030, a.values, b.values); + case 133: + return _mm512_mask_blend_ps(0xC033, a.values, b.values); + case 134: + return _mm512_mask_blend_ps(0xC03C, a.values, b.values); + case 135: + return _mm512_mask_blend_ps(0xC03F, a.values, b.values); + case 136: + return _mm512_mask_blend_ps(0xC0C0, a.values, b.values); + case 137: + return _mm512_mask_blend_ps(0xC0C3, a.values, b.values); + case 138: + return _mm512_mask_blend_ps(0xC0CC, a.values, b.values); + case 139: + return 
_mm512_mask_blend_ps(0xC0CF, a.values, b.values); + case 140: + return _mm512_mask_blend_ps(0xC0F0, a.values, b.values); + case 141: + return _mm512_mask_blend_ps(0xC0F3, a.values, b.values); + case 142: + return _mm512_mask_blend_ps(0xC0FC, a.values, b.values); + case 143: + return _mm512_mask_blend_ps(0xC0FF, a.values, b.values); + case 144: + return _mm512_mask_blend_ps(0xC300, a.values, b.values); + case 145: + return _mm512_mask_blend_ps(0xC303, a.values, b.values); + case 146: + return _mm512_mask_blend_ps(0xC30C, a.values, b.values); + case 147: + return _mm512_mask_blend_ps(0xC30F, a.values, b.values); + case 148: + return _mm512_mask_blend_ps(0xC330, a.values, b.values); + case 149: + return _mm512_mask_blend_ps(0xC333, a.values, b.values); + case 150: + return _mm512_mask_blend_ps(0xC33C, a.values, b.values); + case 151: + return _mm512_mask_blend_ps(0xC33F, a.values, b.values); + case 152: + return _mm512_mask_blend_ps(0xC3C0, a.values, b.values); + case 153: + return _mm512_mask_blend_ps(0xC3C3, a.values, b.values); + case 154: + return _mm512_mask_blend_ps(0xC3CC, a.values, b.values); + case 155: + return _mm512_mask_blend_ps(0xC3CF, a.values, b.values); + case 156: + return _mm512_mask_blend_ps(0xC3F0, a.values, b.values); + case 157: + return _mm512_mask_blend_ps(0xC3F3, a.values, b.values); + case 158: + return _mm512_mask_blend_ps(0xC3FC, a.values, b.values); + case 159: + return _mm512_mask_blend_ps(0xC3FF, a.values, b.values); + case 160: + return _mm512_mask_blend_ps(0xCC00, a.values, b.values); + case 161: + return _mm512_mask_blend_ps(0xCC03, a.values, b.values); + case 162: + return _mm512_mask_blend_ps(0xCC0C, a.values, b.values); + case 163: + return _mm512_mask_blend_ps(0xCC0F, a.values, b.values); + case 164: + return _mm512_mask_blend_ps(0xCC30, a.values, b.values); + case 165: + return _mm512_mask_blend_ps(0xCC33, a.values, b.values); + case 166: + return _mm512_mask_blend_ps(0xCC3C, a.values, b.values); + case 167: + return 
_mm512_mask_blend_ps(0xCC3F, a.values, b.values); + case 168: + return _mm512_mask_blend_ps(0xCCC0, a.values, b.values); + case 169: + return _mm512_mask_blend_ps(0xCCC3, a.values, b.values); + case 170: + return _mm512_mask_blend_ps(0xCCCC, a.values, b.values); + case 171: + return _mm512_mask_blend_ps(0xCCCF, a.values, b.values); + case 172: + return _mm512_mask_blend_ps(0xCCF0, a.values, b.values); + case 173: + return _mm512_mask_blend_ps(0xCCF3, a.values, b.values); + case 174: + return _mm512_mask_blend_ps(0xCCFC, a.values, b.values); + case 175: + return _mm512_mask_blend_ps(0xCCFF, a.values, b.values); + case 176: + return _mm512_mask_blend_ps(0xCF00, a.values, b.values); + case 177: + return _mm512_mask_blend_ps(0xCF03, a.values, b.values); + case 178: + return _mm512_mask_blend_ps(0xCF0C, a.values, b.values); + case 179: + return _mm512_mask_blend_ps(0xCF0F, a.values, b.values); + case 180: + return _mm512_mask_blend_ps(0xCF30, a.values, b.values); + case 181: + return _mm512_mask_blend_ps(0xCF33, a.values, b.values); + case 182: + return _mm512_mask_blend_ps(0xCF3C, a.values, b.values); + case 183: + return _mm512_mask_blend_ps(0xCF3F, a.values, b.values); + case 184: + return _mm512_mask_blend_ps(0xCFC0, a.values, b.values); + case 185: + return _mm512_mask_blend_ps(0xCFC3, a.values, b.values); + case 186: + return _mm512_mask_blend_ps(0xCFCC, a.values, b.values); + case 187: + return _mm512_mask_blend_ps(0xCFCF, a.values, b.values); + case 188: + return _mm512_mask_blend_ps(0xCFF0, a.values, b.values); + case 189: + return _mm512_mask_blend_ps(0xCFF3, a.values, b.values); + case 190: + return _mm512_mask_blend_ps(0xCFFC, a.values, b.values); + case 191: + return _mm512_mask_blend_ps(0xCFFF, a.values, b.values); + case 192: + return _mm512_mask_blend_ps(0xF000, a.values, b.values); + case 193: + return _mm512_mask_blend_ps(0xF003, a.values, b.values); + case 194: + return _mm512_mask_blend_ps(0xF00C, a.values, b.values); + case 195: + return 
_mm512_mask_blend_ps(0xF00F, a.values, b.values); + case 196: + return _mm512_mask_blend_ps(0xF030, a.values, b.values); + case 197: + return _mm512_mask_blend_ps(0xF033, a.values, b.values); + case 198: + return _mm512_mask_blend_ps(0xF03C, a.values, b.values); + case 199: + return _mm512_mask_blend_ps(0xF03F, a.values, b.values); + case 200: + return _mm512_mask_blend_ps(0XF0C0, a.values, b.values); + case 201: + return _mm512_mask_blend_ps(0xF0C3, a.values, b.values); + case 202: + return _mm512_mask_blend_ps(0xF0CC, a.values, b.values); + case 203: + return _mm512_mask_blend_ps(0xF0CF, a.values, b.values); + case 204: + return _mm512_mask_blend_ps(0xF0F0, a.values, b.values); + case 205: + return _mm512_mask_blend_ps(0xF0F3, a.values, b.values); + case 206: + return _mm512_mask_blend_ps(0xF0FC, a.values, b.values); + case 207: + return _mm512_mask_blend_ps(0xF0FF, a.values, b.values); + case 208: + return _mm512_mask_blend_ps(0XF300, a.values, b.values); + case 209: + return _mm512_mask_blend_ps(0xF303, a.values, b.values); + case 210: + return _mm512_mask_blend_ps(0xF30C, a.values, b.values); + case 211: + return _mm512_mask_blend_ps(0xF30F, a.values, b.values); + case 212: + return _mm512_mask_blend_ps(0xF330, a.values, b.values); + case 213: + return _mm512_mask_blend_ps(0xF333, a.values, b.values); + case 214: + return _mm512_mask_blend_ps(0XF33C, a.values, b.values); + case 215: + return _mm512_mask_blend_ps(0xF33F, a.values, b.values); + case 216: + return _mm512_mask_blend_ps(0xF3C0, a.values, b.values); + case 217: + return _mm512_mask_blend_ps(0xF3C3, a.values, b.values); + case 218: + return _mm512_mask_blend_ps(0xF3CC, a.values, b.values); + case 219: + return _mm512_mask_blend_ps(0xF3CF, a.values, b.values); + case 220: + return _mm512_mask_blend_ps(0xF3F0, a.values, b.values); + case 221: + return _mm512_mask_blend_ps(0xF3F3, a.values, b.values); + case 222: + return _mm512_mask_blend_ps(0xF3FC, a.values, b.values); + case 223: + return 
_mm512_mask_blend_ps(0XF3FF, a.values, b.values); + case 224: + return _mm512_mask_blend_ps(0xFC00, a.values, b.values); + case 225: + return _mm512_mask_blend_ps(0xFC03, a.values, b.values); + case 226: + return _mm512_mask_blend_ps(0xFC0C, a.values, b.values); + case 227: + return _mm512_mask_blend_ps(0xFC0F, a.values, b.values); + case 228: + return _mm512_mask_blend_ps(0xFC30, a.values, b.values); + case 229: + return _mm512_mask_blend_ps(0xFC33, a.values, b.values); + case 230: + return _mm512_mask_blend_ps(0xFC3C, a.values, b.values); + case 231: + return _mm512_mask_blend_ps(0xFC3F, a.values, b.values); + case 232: + return _mm512_mask_blend_ps(0xFCC0, a.values, b.values); + case 233: + return _mm512_mask_blend_ps(0xFCC3, a.values, b.values); + case 234: + return _mm512_mask_blend_ps(0xFCCC, a.values, b.values); + case 235: + return _mm512_mask_blend_ps(0xFCCF, a.values, b.values); + case 236: + return _mm512_mask_blend_ps(0xFCF0, a.values, b.values); + case 237: + return _mm512_mask_blend_ps(0xFCF3, a.values, b.values); + case 238: + return _mm512_mask_blend_ps(0xFCFC, a.values, b.values); + case 239: + return _mm512_mask_blend_ps(0xFCFF, a.values, b.values); + case 240: + return _mm512_mask_blend_ps(0xFF00, a.values, b.values); + case 241: + return _mm512_mask_blend_ps(0xFF03, a.values, b.values); + case 242: + return _mm512_mask_blend_ps(0xFF0C, a.values, b.values); + case 243: + return _mm512_mask_blend_ps(0xFF0F, a.values, b.values); + case 244: + return _mm512_mask_blend_ps(0xFF30, a.values, b.values); + case 245: + return _mm512_mask_blend_ps(0xFF33, a.values, b.values); + case 246: + return _mm512_mask_blend_ps(0xFF3C, a.values, b.values); + case 247: + return _mm512_mask_blend_ps(0xFF3F, a.values, b.values); + case 248: + return _mm512_mask_blend_ps(0xFFC0, a.values, b.values); + case 249: + return _mm512_mask_blend_ps(0xFFC3, a.values, b.values); + case 250: + return _mm512_mask_blend_ps(0xFFCC, a.values, b.values); + case 251: + return 
_mm512_mask_blend_ps(0xFFCF, a.values, b.values); + case 252: + return _mm512_mask_blend_ps(0xFFF0, a.values, b.values); + case 253: + return _mm512_mask_blend_ps(0xFFF3, a.values, b.values); + case 254: + return _mm512_mask_blend_ps(0xFFFC, a.values, b.values); + default: + break; + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_ps(mask.values, mask.values); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask( + _mm512_castps_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>( + base, + base + step, + base + c10::complex(2) * step, + base + c10::complex(3) * step, + base + c10::complex(4) * step, + base + c10::complex(5) * step, + base + c10::complex(6) * step, + base + c10::complex(7) * step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + + __at_align__ float tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_ps(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[2 * size()]; + _mm512_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. + static inline __m512 hadd_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32( + 30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_add_ps( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + static inline __m512 hsub_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32( + 30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_sub_ps( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m512 abs_2_() const { + auto val_2 = _mm512_mul_ps(values, values); // a*a b*b + auto ret = hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b + return ret; + } + __m512 abs_() const { + auto real = _mm512_moveldup_ps(values); // real real + auto imag = 
_mm512_movehdup_ps(values); // imag imag + return Sleef_hypotf16_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm512_and_ps(abs_(), real_mask); // abs 0 + } + __m512 angle_() const { + // angle = atan2(b/a) + auto b_a = _mm512_permute_ps(values, 0xB1); // b a + return Sleef_atan2f16_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + auto angle = _mm512_permute_ps(angle_(), 0xB1); // angle 90-angle + return _mm512_and_ps(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_ps(); + auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ); + auto div = _mm512_div_ps(values, abs); + return _mm512_mask_blend_ps(mask, div, zero); + } + __m512 real_() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm512_and_ps(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512 imag_() const { + const __m512 imag_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 
0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF)); + return _mm512_and_ps(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_ps(imag_(), 0xB1); // b a + } + __m512 conj_() const { + const __m512 sign_mask = _mm512_setr_ps( + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0); + return _mm512_xor_ps(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + Vectorized> log2() const { + const __m512 log2_ = _mm512_set1_ps(std::log(2)); + return _mm512_div_ps(log(), log2_); + } + Vectorized> log10() const { + const __m512 log10_ = _mm512_set1_ps(std::log(10)); + return _mm512_div_ps(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. 
+ // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m512 one = _mm512_set1_ps(1); + + // auto conj = conj_(); + // auto b_a = _mm512_permute_ps(conj, 0xB1); //-b a + // auto ab = _mm512_mul_ps(conj, b_a); //-ab + // -ab auto im = _mm512_add_ps(ab, ab); //-2ab -2ab + + // auto val_2 = _mm512_mul_ps(values, values); // a*a + // b*b auto re = hsub_ps(val_2, _mm512_permute_ps(val_2, 0xB1)); // a*a-b*b + // b*b-a*a re = _mm512_sub_ps(one, re); + + // auto root = Vectorized(_mm512_mask_blend_ps(0xAAAA, re, im)).sqrt(); + // //sqrt(re + i*im) auto ln = Vectorized(_mm512_add_ps(b_a, root)).log(); + // //ln(iz + sqrt()) return Vectorized(_mm512_permute_ps(ln.values, + // 0xB1)).conj(); //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + return map(std::acos); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. 
+ // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expf16_u10(values); //exp(a) exp(b) exp = + // _mm512_mask_blend_ps(0xAAAA, exp, _mm512_permute_ps(exp, 0xB1)); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosf16_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm512_mask_blend_ps(0xAAAA, + // _mm512_permute_ps(sin_cos.y, 0xB1), + // sin_cos.x); //cos(b) + // sin(b) + // return _mm512_mul_ps(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m512 ln_2 = _mm512_set1_ps(c10::ln_2); + Vectorized> scaled_values = _mm512_mul_ps(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized> floor() const { + return _mm512_floor_ps(values); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_ps(); + return _mm512_sub_ps(zero, values); + } + Vectorized> round() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], 
y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator!=( + const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator<( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_add_ps(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_sub_ps(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512 sign_mask = _mm512_setr_ps( + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0); + auto ac_bd = _mm512_mul_ps(a, b); // ac bd + + auto d_c = _mm512_permute_ps(b, 0xB1); // d c + d_c = _mm512_xor_ps(sign_mask, d_c); 
// d -c + auto ad_bc = _mm512_mul_ps(a, d_c); // ad -bc + + auto ret = Vectorized>::hsub_ps( + ac_bd, ad_bc); // ac - bd ad + bc + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm512_set1_ps(-0.f); + // auto fabs_cd = _mm512_andnot_ps(mask, b); // |c| |d| + // auto fabs_dc = _mm512_permute_ps(fabs_cd, 0xB1); // |d| |c| + // auto scale = _mm512_rcp14_ps(_mm512_max_ps(fabs_cd, fabs_dc)); // 1/sc + // 1/sc auto a2 = _mm512_mul_ps(a, scale); // a/sc b/sc auto b2 = + // _mm512_mul_ps(b, scale); // c/sc d/sc auto acbd2 = + // _mm512_mul_ps(a2, b2); + + // const __m512 sign_mask = _mm512_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0, + // -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0); + // auto dc2 = _mm512_permute_ps(b2, 0xB1); // d/sc c/sc + // dc2 = _mm512_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm512_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = Vectorized>::hadd_ps(acbd2, adbc2); + // //(ac+bd)/sc^2 (bc-ad)/sc^2 + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm512_div_ps(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm512_loadu_ps(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. 
+ // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0, + // 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0); + // auto c_d = _mm512_xor_ps(sign_mask, values); //c -d + // return _mm512_div_ps(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m512 i = _mm512_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, + // 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm512_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5, + // 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5); + + // auto sum = Vectorized(_mm512_add_ps(i, values)); // a + // 1+b auto sub = Vectorized(_mm512_sub_ps(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. 
+ auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(max, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(min, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_ps(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm512_set1_ps(1.0f)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm512_set1_ps(1.0f)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_convert.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..a4adc222fa1887a776ad16f10693a9e3334b481d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_convert.h @@ -0,0 +1,340 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + __m512 value; + cvtbf16_fp32(_mm512_castsi512_si256(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + __m512 value; + cvtfp16_fp32(_mm512_castsi512_si256(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = _mm512_castsi256_si512(cvtfp32_bf16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = _mm512_castsi256_si512(cvtfp32_fp16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> 
+struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low = _mm512_cvtepi64_ps(src[0]); + auto high = _mm512_cvtepi64_ps(src[1]); + return Vectorized( + _mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + at::vec::VectorizedN result; + result[0] = _mm512_cvt_roundps_epi64( + _mm512_castps512_ps256(src[0]), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + result[1] = _mm512_cvt_roundps_epi64( + _mm512_extractf32x8_ps(src[0], 1), + _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low = _mm512_cvtepi64_epi32(src[0]); + auto high = _mm512_cvtepi64_epi32(src[1]); + return Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + at::vec::VectorizedN result; + result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src[0])); + result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src[0], 1)); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm512_castsi512_si128(src[0]); + return Vectorized(_mm512_cvtepi8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm512_castsi512_si128(src[0]); + return Vectorized(_mm512_cvtepu8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm512_cvttps_epi32(src[0])); + } +}; 
+ +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm512_cvtepi32_ps(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src256 = _mm512_castsi512_si256(src[0]); + return Vectorized(_mm512_cvtepu8_epi16(src256)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm512_cvtepi32_epi8(src[0]); + return Vectorized(_mm512_castsi128_si512(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src256 = _mm512_cvtepi16_epi8(src[0]); + return Vectorized(_mm512_castsi256_si512(src256)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + src_t, + 1, + typename std::enable_if_t< + (is_reduced_floating_point_v && is_8bit_integer_v) || + (is_reduced_floating_point_v && is_8bit_integer_v), + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN tmp_fp32 = VecConvert::apply(src); + return VecConvert::apply(tmp_fp32); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 2, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + at::vec::Vectorized vec1 = convert_float_to_int8(src[0]); + at::vec::Vectorized vec2 = convert_float_to_int8(src[1]); + __m128 lane2 = _mm512_castps512_ps128(_mm512_castsi512_ps(vec2)); + __m512 result = _mm512_insertf32x4( + _mm512_castsi512_ps(vec1), + lane2, + 1); // Insert lane2 into the second 128-bit lane + return at::vec::Vectorized(_mm512_castps_si512(result)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_float_to_int8(src[0]); + } +}; + +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline 
VectorizedN apply(const VectorizedN& src) { + __m512i src2 = + _mm512_castsi128_si512(_mm_castps_si128(_mm512_extractf32x4_ps( + _mm512_castsi512_ps(src[0]), 1) // Extract the second 128-bit lane + )); + return VectorizedN( + convert_int8_to_float(src[0]), + convert_int8_to_float(src2)); + } +}; + +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_to_float(src[0]); + } +}; + +template +struct VecConvert< + dst_t, + 1, + int64_t, + 2, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const VectorizedN& src) { + return VecConvert::apply( + VecConvert::apply(src)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + at::vec::Vectorized src = src_n[0]; + __m128i res128 = cvtfp32_fp8e4m3(src); + return at::vec::Vectorized(_mm512_castsi128_si512(res128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + // cvt first 16x8 bits from Float8_e4m3fn to float + at::vec::Vectorized src = src_n[0]; + __m512 result; + cvtfp8e4m3_fp32(_mm512_castsi512_si128(src), result); + return at::vec::Vectorized(result); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + at::vec::Vectorized src = src_n[0]; + __m128i res128 = cvtfp32_fp8e5m2(src); + return at::vec::Vectorized(_mm512_castsi128_si512(res128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + // cvt first 16x8 bits from Float8_e5m2 to float + at::vec::Vectorized src = src_n[0]; + __m512 result; + cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result); + return at::vec::Vectorized(result); + } +}; + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h new file mode 100644 index 0000000000000000000000000000000000000000..ca5970ccb1f9e9b41561bab799c36d851f1e3b21 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h @@ -0,0 +1,545 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if (defined(CPU_CAPABILITY_AVX512)) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + // values needs to be public for compilation with clang + // as vec512.h uses it + __m512d values; + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m512d v) : values(v) {} + Vectorized(double val) { + values = _mm512_set1_pd(val); + } + Vectorized( + double val1, + double val2, + double val3, + double val4, + double val5, + double val6, + double val7, + double val8) { + values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8); + } + operator __m512d() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mask_blend_pd(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = _mm512_cmp_epi64_mask( + _mm512_castpd_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return 
_mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + __mmask8 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_pd(mask, ptr); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + __mmask8 mask = (1ULL << count) - 1; + _mm512_mask_storeu_pd(reinterpret_cast(ptr), mask, values); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto cmp_mask = + _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + bool has_inf_nan() const { + __m512d self_sub = _mm512_sub_pd(values, values); + return (_mm512_movepi8_mask(_mm512_castpd_si512(self_sub)) & + 0x7777777777777777) != 0; + } + Vectorized map(double (*const f)(double)) const { 
+ __at_align__ double tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_pd(-0.f); + return _mm512_andnot_pd(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm512_castsi512_pd(zero_vector); + const auto nan_vec = _mm512_set1_pd(NAN); + const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ); + const auto not_nan = + _mm512_mask_set1_epi64(zero_vector, not_nan_mask, 0xFFFFFFFFFFFFFFFF); + const auto nan_mask = + _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan), zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_pd(c10::pi); + + const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_pd(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosd8_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshd8_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asind8_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhd8_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atand8_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhd8_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2d8_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignd8(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erfd8_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcd8_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() 
const { + return Vectorized(Sleef_expd8_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2d8_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1d8_u10(values)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodd8(values, q)); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotd8_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized log() const { + return Vectorized(Sleef_logd8_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2d8_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10d8_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pd8_u10(values)); + } + Vectorized sin() const { + return Vectorized(Sleef_sind8_u10(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhd8_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosd8_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshd8_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized floor() const { + return _mm512_floor_pd(values); + } + Vectorized frac() const; + Vectorized neg() const { + 
return _mm512_xor_pd(_mm512_set1_pd(-0.), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterd8(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tand8_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhd8_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammad8_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_pd(values); + } + Vectorized reciprocal() const { + return _mm512_div_pd(_mm512_set1_pd(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powd8_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_pd(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& 
b) { + return _mm512_sub_pd(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mul_pd(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm512_div_pd(a, b); +} + +// frac. Implement this here so we can use subtraction. +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized max = _mm512_max_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_pd(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized min = _mm512_min_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. 
+ return _mm512_or_pd(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm512_min_pd(max, _mm512_max_pd(min, a)); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm512_max_pd(min, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm512_min_pd(max, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_pd(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + 
const Vectorized& b, + const Vectorized& c) { + return _mm512_fmadd_pd(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmsub_pd(a, b, c); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h new file mode 100644 index 0000000000000000000000000000000000000000..b192bb566dfd47c5bef791c1f31137230bdc80f1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h @@ -0,0 +1,868 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + static constexpr __m512i zero_vec{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + __m512 values; + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized() {} + Vectorized(__m512 v) : values(v) {} + Vectorized(float val) { + values = _mm512_set1_ps(val); + } + Vectorized( + float val1, + float val2, + float val3, + float val4, + float val5, + float val6, + float val7, + float val8, + float val9, + float val10, + float val11, + float val12, + float val13, + float val14, + float val15, + float val16) { + values = _mm512_setr_ps( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16); + } + Vectorized(const float (&arr)[16]) + : Vectorized( + arr[0], + arr[1], 
+ arr[2], + arr[3], + arr[4], + arr[5], + arr[6], + arr[7], + arr[8], + arr[9], + arr[10], + arr[11], + arr[12], + arr[13], + arr[14], + arr[15]) {} + operator __m512() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mask_blend_ps(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask( + _mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + + __mmask16 mask = (1ULL << count) - 1; + return 
_mm512_maskz_loadu_ps(mask, ptr); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + __mmask16 mask = (1ULL << count) - 1; + _mm512_mask_storeu_ps(reinterpret_cast(ptr), mask, values); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + bool has_inf_nan() const { + __m512 self_sub = _mm512_sub_ps(values, values); + return (_mm512_movepi8_mask(_mm512_castps_si512(self_sub)) & + 0x7777777777777777) != 0; + } + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_ps(-0.f); + return _mm512_andnot_ps(mask, values); + } + Vectorized angle() const { + __m512 zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto not_nan_vec = _mm512_mask_set1_epi32( + _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask( + _mm512_castsi512_ps(not_nan_vec), zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + } + Vectorized 
real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_ps(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosf16_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshf16_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asinf16_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhf16_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atanf16_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhf16_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2f16_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignf16(values, sign)); + } + Vectorized erf() const { + // constants + const auto neg_zero_vec = _mm512_set1_ps(-0.f); + const auto one_vec = _mm512_set1_ps(1.0f); + const auto p = _mm512_set1_ps(0.3275911f); + const auto p1 = _mm512_set1_ps(0.254829592f); + const auto p2 = _mm512_set1_ps(-0.284496736f); + const auto p3 = _mm512_set1_ps(1.421413741f); + const auto p4 = _mm512_set1_ps(-1.453152027f); + const auto p5 = _mm512_set1_ps(1.061405429f); + // sign(x) + auto sign_mask = _mm512_and_ps(neg_zero_vec, values); + auto abs_vec = _mm512_abs_ps(values); + // t = 1 / (p * abs(x) + 1) + auto tmp0 = _mm512_fmadd_ps(p, abs_vec, one_vec); + auto t = _mm512_div_ps(one_vec, tmp0); + // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = _mm512_fmadd_ps(p5, t, p4); + auto tmp2 = _mm512_fmadd_ps(tmp1, t, p3); + auto tmp3 = _mm512_fmadd_ps(tmp2, t, p2); + auto r = _mm512_fmadd_ps(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = _mm512_mul_ps(values, values); + auto neg_pow_2 = _mm512_xor_ps(neg_zero_vec, pow_2); + // auto tmp4 = exp(neg_pow_2); + auto tmp4 = Vectorized(Sleef_expf16_u10(neg_pow_2)); + auto tmp5 = _mm512_xor_ps(neg_zero_vec, tmp4); + // 
erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = _mm512_mul_ps(tmp5, t); + auto tmp7 = _mm512_fmadd_ps(tmp6, r, one_vec); + return _mm512_xor_ps(sign_mask, tmp7); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcf16_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expf16_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2f16_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1f16_u10(values)); + } + Vectorized exp_u20() const { + // A faster version of exp with ULP=20 + const __m512 vec_factorial_1 = + _mm512_set1_ps(0.999999701f); // 1/factorial(1) + const __m512 vec_factorial_2 = + _mm512_set1_ps(0.499991506f); // 1/factorial(2) + const __m512 vec_factorial_3 = + _mm512_set1_ps(0.166676521f); // 1/factorial(3) + const __m512 vec_factorial_4 = + _mm512_set1_ps(0.0418978221f); // 1/factorial(4) + const __m512 vec_factorial_5 = + _mm512_set1_ps(0.00828929059f); // 1/factorial(5) + const __m512 vec_exp_log2ef = + _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) + const __m512 vec_half = _mm512_set1_ps(0.5f); + const __m512 vec_one = _mm512_set1_ps(1.f); + const __m512 vec_zero = _mm512_set1_ps(0.f); + const __m512 vec_two = _mm512_set1_ps(2.f); + const __m512 vec_ln2f = + _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) + const __m512 vec_ln_flt_min = + _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); + const __m512 vec_ln_flt_max = + _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); + const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); + const int n_mantissa_bits = 23; + + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + auto less_ln_flt_min_mask = + _mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); + auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); + vec_src = 
_mm512_max_ps(vec_src, vec_ln_flt_min); + + // fx = floorf(x * log2ef + 0.5) + auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); + auto vec_fx_i = _mm512_cvt_roundps_epi32( + vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); + vec_fx = _mm512_cvtepi32_ps(vec_fx_i); + + // x = x - fx * ln2 + auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); + + // compute polynomial + auto vec_res = + _mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); + + // compute 2^(n-1) + auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); + auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); + auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); + vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); + auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); + vec_two_pow_n = + _mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero); + + // y = y * 2^n + vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); + vec_res = _mm512_mul_ps(vec_res, vec_two); + return vec_res; + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodf16(values, q)); + } + Vectorized log() const { + return Vectorized(Sleef_logf16_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2f16_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10f16_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pf16_u10(values)); + } + Vectorized frac() const; + Vectorized sin() const { + return Vectorized(Sleef_sinf16_u35(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhf16_u10(values)); + } + Vectorized cos() const { + return 
Vectorized(Sleef_cosf16_u35(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshf16_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized floor() const { + return _mm512_floor_ps(values); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotf16_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized neg() const { + return _mm512_xor_ps(_mm512_set1_ps(-0.f), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterf16(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tanf16_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhf16_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammaf16_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_ps(values); + } + Vectorized reciprocal() const { + return _mm512_div_ps(_mm512_set1_ps(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values)); + } + 
Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powf16_u10(values, b)); + } + float reduce_add() const { + return _mm512_reduce_add_ps(values); + } + float reduce_max() const { + return _mm512_reduce_max_ps(values); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline 
operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_ps(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_ps(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mul_ps(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm512_div_ps(a, b); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto max = _mm512_max_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, isnan_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_ps(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto min = _mm512_min_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, isnan_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. 
+ return _mm512_or_ps(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm512_min_ps(max, _mm512_max_ps(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm512_min_ps(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm512_max_ps(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_ps(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + 
const Vectorized& b, + const Vectorized& c) { + return _mm512_fmadd_ps(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmsub_ps(a, b, c); +} + +// TODO: rewrite with ATEN vectorized (need to add unpack and shuffle) +// Used by Inductor CPP codegen for micro gemm +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304 +// kernel for transposing mxn where m, n <= 16 +// (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + N instructions +inline void transpose_block( + at::vec::VectorizedN& input, + int M = 16, + int N = 16) { + TORCH_CHECK(M <= 16 && N <= 16, "transpose_block expects M, N <= 16."); + // unpacking and interleaving 32-bit elements + __m512 temp[16]; + int i; + for (i = 0; i < (M + 1) / 2; ++i) { + temp[2 * i] = _mm512_unpacklo_ps(input[2 * i], input[2 * i + 1]); + temp[2 * i + 1] = _mm512_unpackhi_ps(input[2 * i], input[2 * i + 1]); + } + for (i = i * 2; i < 16; ++i) { + temp[i] = _mm512_setzero_ps(); + } + + // unpacking and interleaving 64-bit elements + for (i = 0; i < (M + 3) / 4; ++i) { + input[4 * i] = _mm512_castpd_ps(_mm512_unpacklo_pd( + _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); + input[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd( + _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); + input[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd( + _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3]))); + input[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd( + _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3]))); + } + + // shuffle 128-bits (composed of 4 32-bit elements) + for (i = 0; i < (M + 7) / 8; ++i) { + temp[8 * i] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0x88); + temp[8 * i + 1] = + _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0x88); + temp[8 * i + 2] = + 
_mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0x88); + temp[8 * i + 3] = + _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0x88); + temp[8 * i + 4] = + _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0xdd); + temp[8 * i + 5] = + _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0xdd); + temp[8 * i + 6] = + _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0xdd); + temp[8 * i + 7] = + _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd); + } + + for (i = 0; i < N; ++i) { + if (i < 8) { + input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88); + } else { + input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd); + } + } +} + +// TODO(jgong5): rewrite with ATEN vectorized (need to add unpack and shuffle) +// Used by Inductor CPP codegen +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304 +// kernel for transposing mxn where m, n <= 16 +// M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions +inline void transpose_mxn_16x16( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst, + int M, + int N) { + TORCH_CHECK(M <= 16 && N <= 16, "transpose_mxn expects M, N <= 16."); + // load from src to registers + at::vec::VectorizedN input; + int i; + if (N == 16) { + for (i = 0; i < M; ++i) { + input[i] = _mm512_loadu_ps(&src[i * ld_src]); + } + } else { + __mmask16 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]); + } + } + for (; i < 16; ++i) { + // Not really needed but to avoid uninitialized variable warning. + // Shouldn't be much overhead because xor can be executed in parallel with + // other instructions. 
+ input[i] = _mm512_setzero_ps(); + } + + transpose_block(input, M, N); + + // store from registers to dst + if (M == 16) { + for (i = 0; i < N; ++i) { + _mm512_storeu_ps(&dst[i * ld_dst], input[i]); + } + } else { + __mmask16 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]); + } + } +} + +template <> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst, + int M, + int N) { + int64_t i = 0; + for (; i < M / 16 * 16; i += 16) { + int64_t j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, 16); + } + // handle remainder j + int nrem = N - j; + if (nrem > 0) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem); + } + } + // handle remainder i + int mrem = M - i; + if (mrem > 0) { + int j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16); + } + // handle remainder j + int nrem = N - j; + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem); + } +} + +template < + typename T, + int M, + int N, + typename std::enable_if_t, int> = 0> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float8.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float8.h new file mode 100644 index 0000000000000000000000000000000000000000..12ee4c460641f42a66bdbbb3f9ccfb0c8a9a6cfa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float8.h @@ -0,0 +1,661 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS 
HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#if (defined(CPU_CAPABILITY_AVX512)) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +static inline void cvtfp8e4m3_fp32(const __m128i& a, __m512& o) { + // Zero Extend + __m512i x = _mm512_cvtepu8_epi32(a); + __m512i val = _mm512_and_epi32( + _mm512_slli_epi32(x, 24), _mm512_set1_epi32(0x7FFFFFFF)); // nonsign_val + __m512i mant = + _mm512_and_si512(x, _mm512_set1_epi32(0x07)); // mantissa = x & 0x07 + __m512i exp = _mm512_and_si512( + _mm512_srli_epi32(x, 3), + _mm512_set1_epi32(0x0F)); // exp = (x >> 3) & 0x0F + __m512i sign = + _mm512_and_si512(x, _mm512_set1_epi32(0x80)); // sign = x & 0x80 + __m512i _zeros = _mm512_setzero_si512(); + + // --- Step 1: Calculate the renorm_shift + __m512i renorm_shift = _zeros; + // Denorm case (exp == 0 && mant != 0) --- + __mmask16 denormal_mask = _mm512_cmpeq_epi32_mask(exp, _zeros) & + _mm512_cmpneq_epi32_mask(mant, _zeros); + if (denormal_mask) { + // An alternative solution is as what scalar did in + // pytorch/c10/util/Float8_e4m3fn.h To count the num of leading zeros, since + // here we know the unsigned denorm value has zero sign and exp which is 5 + // leading zeros, we need to count the leading zero of mant (3bit) which may + // done through table lookup for example: const uint8_t lz_table[8] = {3, 2, + // 1, 1, 0, 0, 0, 0}; num_leading_zero = lz_table[mant] + 5; + + __m512i _ones = _mm512_set1_epi32(1); + __m512i _twos = _mm512_set1_epi32(2); + __m512i _threes = _mm512_set1_epi32(3); + + // Default leading zero number for denorm value is 1 = 5 - 4 + __m512i denorm_renorm_shift = _ones; + // For mant 001, leading zero number is 3 = 7 -4 + __mmask16 leading_Zero_mask = _mm512_cmpeq_epi32_mask(mant, _ones); + denorm_renorm_shift = + _mm512_mask_mov_epi32(denorm_renorm_shift, 
leading_Zero_mask, _threes); + // For mant 010 and 011, leading zero number is 2 = 6 -4 + leading_Zero_mask = _mm512_cmpeq_epi32_mask(mant, _twos); + denorm_renorm_shift = + _mm512_mask_mov_epi32(denorm_renorm_shift, leading_Zero_mask, _twos); + leading_Zero_mask = _mm512_cmpeq_epi32_mask(mant, _threes); + denorm_renorm_shift = + _mm512_mask_mov_epi32(denorm_renorm_shift, leading_Zero_mask, _twos); + + renorm_shift = + _mm512_mask_mov_epi32(renorm_shift, denormal_mask, denorm_renorm_shift); + } + + // --- Step 2: calculate norm and denorm --- + __m512i norm_shifted = + _mm512_srli_epi32(_mm512_sllv_epi32(val, renorm_shift), 4); + // exponent bias adjustment: (0x78 - renorm_shift) << 23 + __m512i exp_bias = _mm512_slli_epi32( + _mm512_sub_epi32(_mm512_set1_epi32(0x78), renorm_shift), 23); + val = _mm512_add_epi32(norm_shifted, exp_bias); + + // --- Step 3: Nan case (exp == 0xF && mant == 0x07) --- + __mmask16 nan_mask = _mm512_cmpeq_epi32_mask(exp, _mm512_set1_epi32(0xF)) & + _mm512_cmpeq_epi32_mask(mant, _mm512_set1_epi32(0x07)); + if (nan_mask) { + const __m512i nan_values = _mm512_set1_epi32(0x7FC00000); + val = _mm512_mask_mov_epi32(val, nan_mask, nan_values); + } + + // --- Step 4: Zero case (exp == 0x00 && mant == 0x00) --- + __mmask16 zero_mask = _mm512_cmpeq_epi32_mask(exp, _zeros) & + _mm512_cmpeq_epi32_mask(mant, _zeros); + if (zero_mask) { + val = _mm512_mask_mov_epi32(val, zero_mask, _zeros); + } + + // --- Step 5: OR with sign (sign bit << 24 to get to bit 31) --- + val = _mm512_or_si512(val, _mm512_slli_epi32(sign, 24)); + + o = _mm512_castsi512_ps(val); +} + +static inline __m128i cvtfp32_fp8e4m3(const __m512& src) { + // cvt 16x32 from fp32 to fp8 e4m3 + const __m512i sign_mask = _mm512_set1_epi32(0x80000000); + const __m512i fp8_max = _mm512_set1_epi32(UINT32_C(1087) << 20); + const __m512i denorm_thresh = _mm512_set1_epi32(UINT32_C(121) << 23); + const __m512i denorm_mask = _mm512_set1_epi32(UINT32_C(141) << 23); + const __m512i bias_part1 = 
_mm512_set1_epi32((uint32_t)(7 - 127) << 23); + const __m512i rounding_bias = _mm512_set1_epi32(0x7FFFF); + __m512i f_bits = _mm512_castps_si512(src); + // Extract and save sign + __m512i sign = _mm512_and_epi32(f_bits, sign_mask); + f_bits = _mm512_xor_epi32(f_bits, sign); + + // Prepare result containers + __m512i result = _mm512_setzero_si512(); + + // Step 1: Handle case of overflow + // (f_bits >= fp8_max): set result = 0x7f + __mmask16 overflow_mask = _mm512_cmpge_epu32_mask(f_bits, fp8_max); + if (overflow_mask) { + result = _mm512_mask_set1_epi32(result, overflow_mask, 0x7f); + } + + // Step 2: Handle small numbers (denormals) + // Small numbers (f_bits < denorm_thresh) + __mmask16 denorm_thresh_mask = _mm512_cmplt_epu32_mask(f_bits, denorm_thresh); + + if (denorm_thresh_mask) { + __m512 small_input = _mm512_castsi512_ps(f_bits); + __m512 small_denorm = + _mm512_add_ps(small_input, _mm512_castsi512_ps(denorm_mask)); + __m512i small_denorm_bits = _mm512_castps_si512(small_denorm); + __m512i small_result = _mm512_sub_epi32(small_denorm_bits, denorm_mask); + result = _mm512_mask_mov_epi32(result, denorm_thresh_mask, small_result); + } + + // Step 3: Handle normal numbers + __mmask16 normal_mask = ~(overflow_mask | denorm_thresh_mask); + + if (normal_mask) { + // mant_odd = (f_bits >> 20) & 1 + __m512i mant_odd = + _mm512_and_epi32(_mm512_srli_epi32(f_bits, 20), _mm512_set1_epi32(1)); + // f_bits += bias_part1 + rounding_bias + __m512i rounded = _mm512_add_epi32(f_bits, bias_part1); + rounded = _mm512_add_epi32(rounded, rounding_bias); + // Add mant_odd + rounded = _mm512_add_epi32(rounded, mant_odd); + // Shift right by 20 bits + __m512i normal_result = _mm512_srli_epi32(rounded, 20); + result = _mm512_mask_mov_epi32(result, normal_mask, normal_result); + } + + // Merge back the sign + __m512i sign_shifted = _mm512_srli_epi32(sign, 24); + result = _mm512_or_epi32(result, sign_shifted); + + // Now result is 16 x 32-bit integers, but we only need 8-bit for each 
+ __m512i packed = _mm512_and_si512(result, _mm512_set1_epi32(0xFF)); + + // Narrow 32-bit integers to 8-bit + return _mm512_cvtepi32_epi8(packed); +} + +static inline float fp8e4m3_to_fp32_scalar(uint8_t val) { + __m512i v = _mm512_set1_epi8(val); + __m128i v_128 = _mm512_castsi512_si128(v); + __m512 o; + cvtfp8e4m3_fp32(v_128, o); + return _mm512_cvtss_f32(o); +} + +static inline uint8_t fp32_to_fp8e4m3_scalar(float val) { + __m512 v = _mm512_set1_ps(val); + __m128i o = cvtfp32_fp8e4m3(v); + return static_cast(_mm_cvtsi128_si32(o)); +} + +static inline void cvtfp8e5m2_fp32(const __m128i& a, __m512& o) { + __m256i a_256 = _mm256_castsi128_si256(a); + __m512i a_512 = _mm512_cvtepu8_epi16(a_256); + a_512 = _mm512_slli_epi16(a_512, 8); + a_256 = _mm512_castsi512_si256(a_512); + cvtfp16_fp32(a_256, o); +} + +static inline __m128i cvtfp32_fp8e5m2(const __m512& src) { + constexpr uint32_t fp32_inf = UINT32_C(255) << 23; + constexpr uint32_t fp8_max = UINT32_C(143) << 23; + constexpr uint32_t denorm_mask = UINT32_C(134) << 23; + + // Cvt to bits + __m512i input_bits = _mm512_castps_si512(src); + __m512i result = _mm512_setzero_si512(); + + // Get the sign + __m512i sign = _mm512_and_si512(input_bits, _mm512_set1_epi32(0x80000000)); + + // Get the unsigned input + input_bits = _mm512_xor_si512(input_bits, sign); + + // Calculate the mask for inf, nan and denorm + __mmask16 greater_than_fp8_max = + _mm512_cmpge_epi32_mask(input_bits, _mm512_set1_epi32(fp8_max)); + __mmask16 greater_than_fp32_inf = + _mm512_cmpgt_epi32_mask(input_bits, _mm512_set1_epi32(fp32_inf)); + __mmask16 less_than_normal = _mm512_cmpgt_epi32_mask( + _mm512_set1_epi32((UINT32_C(113) << 23)), input_bits); + __m512i temp_bits_for_denorm = _mm512_setzero_si512(); + if (less_than_normal) { + __m512i denorm_mask_512i = _mm512_set1_epi32(denorm_mask); + temp_bits_for_denorm = _mm512_castps_si512(_mm512_add_ps( + _mm512_castsi512_ps(input_bits), + _mm512_castsi512_ps(denorm_mask_512i))); + 
temp_bits_for_denorm = + _mm512_sub_epi32(temp_bits_for_denorm, denorm_mask_512i); + } + + // Step 1: Norm Val + __m512i mant_odd_mask = + _mm512_and_epi32(_mm512_srli_epi32(input_bits, 21), _mm512_set1_epi32(1)); + input_bits = _mm512_add_epi32( + input_bits, _mm512_set1_epi32(((uint32_t)(15 - 127) << 23) + 0xFFFFF)); + input_bits = _mm512_add_epi32(input_bits, mant_odd_mask); + result = _mm512_srli_epi32(input_bits, 21); + + // Step 2: INF and NAN + if (greater_than_fp8_max) { + result = _mm512_mask_mov_epi32( + result, greater_than_fp8_max, _mm512_set1_epi8(0x7C)); + if (greater_than_fp32_inf) { + result = _mm512_mask_mov_epi32( + result, greater_than_fp32_inf, _mm512_set1_epi8(0x7F)); + } + } + + // Step 3: Denorm val + if (less_than_normal) { + result = + _mm512_mask_mov_epi32(result, less_than_normal, temp_bits_for_denorm); + } + + // Step 4: restore sign + result = _mm512_or_si512(result, _mm512_srli_epi32(sign, 24)); + + return _mm512_cvtepi32_epi8(result); +} + +static inline float fp8e5m2_to_fp32_scalar(uint8_t val) { + __m512i v = _mm512_set1_epi8(val); + __m128i v_128 = _mm512_castsi512_si128(v); + __m512 o; + cvtfp8e5m2_fp32(v_128, o); + return _mm512_cvtss_f32(o); +} + +static inline uint8_t fp32_to_fp8e5m2_scalar(float val) { + __m512 v = _mm512_set1_ps(val); + __m128i o = cvtfp32_fp8e5m2(v); + return static_cast(_mm_cvtsi128_si32(o)); +} + +template +class Vectorizedf8 { + static_assert( + std::integral_constant < bool, + std::is_same_v || std::is_same_v < T, + at::Float8_e5m2 >> ::value, + "Support only float8 e4m3."); + + private: + __m512i values; + template + Vectorized inline binary_compare(const VectorizedType& b, Op op) const { + __m512 a0, a1, a2, a3; + __m512 b0, b1, b2, b3; + __m512 o0, o1, o2, o3; + if constexpr (std::is_same_v) { + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 0), a0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 1), a1); + 
cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 2), a2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 3), a3); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3); + } else { + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 0), a0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 1), a1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 2), a2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 3), a3); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3); + } + + o0 = op(a0, b0); + o1 = op(a1, b1); + o2 = op(a2, b2); + o3 = op(a3, b3); + __m128i o128_0, o128_1, o128_2, o128_3; + if constexpr (std::is_same_v) { + o128_0 = cvtfp32_fp8e4m3(o0); + o128_1 = cvtfp32_fp8e4m3(o1); + o128_2 = cvtfp32_fp8e4m3(o2); + o128_3 = cvtfp32_fp8e4m3(o3); + } else { + o128_0 = cvtfp32_fp8e5m2(o0); + o128_1 = cvtfp32_fp8e5m2(o1); + o128_2 = cvtfp32_fp8e5m2(o2); + o128_3 = cvtfp32_fp8e5m2(o3); + } + + __m512i result = _mm512_setzero_si512(); + result = _mm512_inserti32x4(result, o128_0, 0); + result = _mm512_inserti32x4(result, o128_1, 1); + result = _mm512_inserti32x4(result, o128_2, 2); + result = _mm512_inserti32x4(result, o128_3, 3); + + return result; + } + + public: + using value_type = uint8_t; + using size_type = int; + static constexpr size_type size() { + return 64; + } + Vectorizedf8() {} + Vectorizedf8(__m512i v) : values(v) {} + Vectorizedf8(T val) { + value_type uw = val.x; + values = _mm512_set1_epi8(uw); + } + operator __m512i() const { + return values; + } + T& operator[](int idx) = delete; + const T& operator[](int idx) const = delete; + static Vectorized loadu(const void* ptr, int16_t count = size()) { + if (count == 
size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else if (count == 16) { + // Fast path if only load element number of 16 + __m128i input_128 = + _mm_loadu_si128(reinterpret_cast(ptr)); + return _mm512_castsi128_si512(input_128); + } else { + __mmask64 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi8(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + if (count == 16) { + // Fast path if only store element number of 16 + _mm_storeu_si128( + reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values)); + } else { + __mmask64 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi8(ptr, mask, values); + } + } + } + + Vectorized abs() const { + return _mm512_andnot_si512(_mm512_set1_epi8(0x80), values); + } + + Vectorized inline operator==(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator!=(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator>(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator>=(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = 
_mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator<(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator<=(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } +}; + +template <> +class Vectorized : public Vectorizedf8 { + public: + using Vectorizedf8::Vectorizedf8; + + using value_type = Float8_e4m3fn; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template < + typename T, + typename Op, + std::enable_if_t< + std::is_same_v || + std::is_same_v, + int> = 0> +static inline Vectorized binary_fp8_op_as_fp32( + const Vectorized& a, + const Vectorized& b, + Op op) { + __m512 a0, a1, a2, a3; + __m512 b0, b1, b2, b3; + __m512 o0, o1, o2, o3; + if constexpr (std::is_same_v) { + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 0), a0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 0), b0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 1), a1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 1), b1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 2), a2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 2), b2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 3), a3); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 3), b3); + } else { + 
cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 0), a0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 0), b0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 1), a1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 1), b1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 2), a2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 2), b2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 3), a3); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 3), b3); + } + o0 = op(a0, b0); + o1 = op(a1, b1); + o2 = op(a2, b2); + o3 = op(a3, b3); + + __m128i o128_0, o128_1, o128_2, o128_3; + if constexpr (std::is_same_v) { + o128_0 = cvtfp32_fp8e4m3(o0); + o128_1 = cvtfp32_fp8e4m3(o1); + o128_2 = cvtfp32_fp8e4m3(o2); + o128_3 = cvtfp32_fp8e4m3(o3); + } else { + o128_0 = cvtfp32_fp8e5m2(o0); + o128_1 = cvtfp32_fp8e5m2(o1); + o128_2 = cvtfp32_fp8e5m2(o2); + o128_3 = cvtfp32_fp8e5m2(o3); + } + + __m512i result = _mm512_setzero_si512(); + result = _mm512_inserti32x4(result, o128_0, 0); + result = _mm512_inserti32x4(result, o128_1, 1); + result = _mm512_inserti32x4(result, o128_2, 2); + result = _mm512_inserti32x4(result, o128_3, 3); + + return result; +} + +// Refer to +// https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +, +// -, *, /, planed to be deleted in the future and here is just to make compiler +// happy +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} + +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} + +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} + +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + 
return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} + +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +class Vectorized : public Vectorizedf8 { + public: + using Vectorizedf8::Vectorizedf8; + + using value_type = Float8_e5m2; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +// Refer to +// https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +, +// -, *, /, planed to be deleted in the future and here is just to make compiler +// happy +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} + +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} + +Vectorized inline operator*( + 
const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} + +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} + +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h new file mode 100644 index 0000000000000000000000000000000000000000..b1cbdff7f74ba642689f50a26f7a7ede38541cb2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h @@ -0,0 +1,2115 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#ifdef CPU_CAPABILITY_AVX512 + +struct Vectorizedi { + protected: + __m512i values; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + static inline __m512i invert(const __m512i& v) { + const auto ones = _mm512_set1_epi64(-1); + return _mm512_xor_si512(ones, v); + } + + public: + Vectorizedi() {} + Vectorizedi(__m512i v) : values(v) {} + operator __m512i() const { + return values; + } +}; + +#else + +struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined + +#endif // CPU_CAPABILITY_AVX512 + +#ifdef CPU_CAPABILITY_AVX512 + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int64_t; + using size_type = int; + static constexpr size_type size() { + return 8; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int64_t v) { + values = _mm512_set1_epi64(v); + } + Vectorized( + int64_t val1, + int64_t val2, + int64_t val3, + int64_t val4, + int64_t val5, + int64_t val6, + int64_t val7, + int64_t val8) { + values = _mm512_setr_epi64(val1, val2, val3, val4, val5, val6, val7, val8); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm512_mask_blend_epi64(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mask_ = _mm512_cmp_epi64_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi64(mask_, a.values, b.values); + } + template + static Vectorized arange( + int64_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 
6 * step, + base + 7 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int64_t count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else { + __mmask8 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi64(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask8 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi64(ptr, mask, values); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + Vectorized abs() const { + auto is_larger_mask = _mm512_cmpgt_epi64_mask(zero_vector, values); + auto is_larger = + _mm512_mask_set1_epi64(zero_vector, is_larger_mask, 0xFFFFFFFFFFFFFFFF); + auto inverse = _mm512_xor_si512(values, is_larger); + return _mm512_sub_epi64(inverse, is_larger); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi64(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = 
_mm512_cmpeq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; +template <> +class Vectorized : public Vectorizedi { + private: + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; + + public: + using value_type = int32_t; + static constexpr int size() { + return 16; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int32_t v) { + values = _mm512_set1_epi32(v); + } + Vectorized( + int32_t val1, + int32_t val2, + int32_t val3, + int32_t val4, + int32_t val5, + int32_t val6, + int32_t val7, + int32_t val8, + int32_t val9, + int32_t val10, 
+ int32_t val11, + int32_t val12, + int32_t val13, + int32_t val14, + int32_t val15, + int32_t val16) { + values = _mm512_setr_epi32( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm512_mask_blend_epi32(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi32(0xFFFFFFFF); + auto mask_ = _mm512_cmp_epi32_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi32(mask_, a.values, b.values); + } + template + static Vectorized arange( + int32_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int32_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int32_t count) { + if (count == size()) { 
+ return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else { + __mmask16 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi32(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask16 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi32(ptr, mask, values); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi32(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi32(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + int32_t reduce_add() const { + return _mm512_reduce_add_epi32(values); + } + int32_t reduce_max() const { + return _mm512_reduce_max_epi32(values); + } + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>(const 
Vectorized& other) const { + auto mask = _mm512_cmpgt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + int64_t i; + // int32_t and float have same size +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_vec = + _mm512_loadu_si512(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_ps(input_vec); + _mm512_storeu_ps(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t* src, double* dst, int64_t n) { + int64_t i; + // int32_t has half the size of double +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_256_vec = + _mm256_loadu_si256(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_pd(input_256_vec); + _mm512_storeu_pd(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + using value_type = int16_t; + static 
constexpr int size() { + return 32; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int16_t v) { + values = _mm512_set1_epi16(v); + } + Vectorized( + int16_t val1, + int16_t val2, + int16_t val3, + int16_t val4, + int16_t val5, + int16_t val6, + int16_t val7, + int16_t val8, + int16_t val9, + int16_t val10, + int16_t val11, + int16_t val12, + int16_t val13, + int16_t val14, + int16_t val15, + int16_t val16, + int16_t val17, + int16_t val18, + int16_t val19, + int16_t val20, + int16_t val21, + int16_t val22, + int16_t val23, + int16_t val24, + int16_t val25, + int16_t val26, + int16_t val27, + int16_t val28, + int16_t val29, + int16_t val30, + int16_t val31, + int16_t val32) { + values = _mm512_set_epi16( + val32, + val31, + val30, + val29, + val28, + val27, + val26, + val25, + val24, + val23, + val22, + val21, + val20, + val19, + val18, + val17, + val16, + val15, + val14, + val13, + val12, + val11, + val10, + val9, + val8, + val7, + val6, + val5, + val4, + val3, + val2, + val1); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm512_mask_blend_epi16(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange( + int16_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * 
step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int16_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else { + __mmask32 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi16(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count 
== size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask32 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi16(ptr, mask, values); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi16(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + + Vectorized eq(const Vectorized& other) const; + 
Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template +class Vectorized8 : public Vectorizedi { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + + protected: + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; + + public: + using value_type = T; + static constexpr int size() { + return 64; + } + using Vectorizedi::Vectorizedi; + Vectorized8() {} + Vectorized8(T v) { + values = _mm512_set1_epi8(v); + } + Vectorized8( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16, + T val17, + T val18, + T val19, + T val20, + T val21, + T val22, + T val23, + T val24, + T val25, + T val26, + T val27, + T val28, + T val29, + T val30, + T val31, + T val32, + T val33, + T val34, + T val35, + T val36, + T val37, + T val38, + T val39, + T val40, + T val41, + T val42, + T val43, + T val44, + T val45, + T val46, + T val47, + T val48, + T val49, + T val50, + T val51, + T val52, + T val53, + T val54, + T val55, + T val56, + T val57, + T val58, + T val59, + T val60, + T val61, + T val62, + T val63, + T val64) { + values = _mm512_set_epi8( + val64, + val63, + val62, + val61, + val60, + val59, + val58, + val57, + val56, + val55, + val54, + val53, + val52, + val51, + val50, + val49, + val48, + val47, + val46, + val45, + val44, + val43, + val42, + val41, + val40, + val39, + val38, + val37, + val36, + val35, + val34, + val33, + val32, + val31, + val30, + val29, + val28, + val27, + val26, + val25, + val24, + val23, + val22, + val21, + val20, + val19, + val18, + val17, + val16, + val15, + val14, + val13, + val12, + val11, + val10, + val9, + val8, + val7, + val6, + val5, + val4, + val3, + val2, + 
val1); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi8(mask, a.values, b.values); + } + template + static Vectorized arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step, + base + 32 * step, + base + 33 * step, + base + 34 * step, + base + 35 * step, + base + 36 * step, + base + 37 * step, + base + 38 * step, + base + 39 * step, + base + 40 * step, + base + 41 * step, + base + 42 * step, + base + 43 * step, + base + 44 * step, + base + 45 * step, + base + 46 * step, + base + 47 * step, + base + 48 * step, + base + 49 * step, + base + 50 * step, + base + 51 * step, + base + 52 * step, + base + 53 * step, + base + 54 * step, + base + 55 * step, + base + 56 * step, + base + 57 * step, + base + 58 * step, + base + 59 * step, + base + 60 * step, + base + 61 * step, + base + 62 * step, + base + 63 * step); + } + static Vectorized set(Vectorized a, Vectorized b, T count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, 
b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + case 32: + return blend<0xFFFFFFFF>(a, b); + case 33: + return blend<0x1FFFFFFFF>(a, b); + case 34: + return blend<0x3FFFFFFFF>(a, b); + case 35: + return blend<0x7FFFFFFFF>(a, b); + case 36: + return blend<0xFFFFFFFFF>(a, b); + case 37: + return blend<0x1FFFFFFFFF>(a, b); + case 38: + return blend<0x3FFFFFFFFF>(a, b); + case 39: + return blend<0x7FFFFFFFFF>(a, b); + case 40: + return blend<0xFFFFFFFFFF>(a, b); + case 41: + return blend<0x1FFFFFFFFFF>(a, b); + case 42: + return blend<0x3FFFFFFFFFF>(a, b); + case 43: + return blend<0x7FFFFFFFFFF>(a, b); + case 44: + return blend<0xFFFFFFFFFFF>(a, b); + case 45: + return blend<0x1FFFFFFFFFFF>(a, b); + case 46: + return blend<0x3FFFFFFFFFFF>(a, b); + case 47: + return blend<0x7FFFFFFFFFFF>(a, b); + case 48: + return blend<0xFFFFFFFFFFFF>(a, b); + case 49: + return blend<0x1FFFFFFFFFFFF>(a, b); + case 50: + return blend<0x3FFFFFFFFFFFF>(a, b); + case 51: + return blend<0x7FFFFFFFFFFFF>(a, b); + case 52: + return blend<0xFFFFFFFFFFFFF>(a, b); + case 53: + return blend<0x1FFFFFFFFFFFFF>(a, b); + case 54: + return 
blend<0x3FFFFFFFFFFFFF>(a, b); + case 55: + return blend<0x7FFFFFFFFFFFFF>(a, b); + case 56: + return blend<0xFFFFFFFFFFFFFF>(a, b); + case 57: + return blend<0x1FFFFFFFFFFFFFF>(a, b); + case 58: + return blend<0x3FFFFFFFFFFFFFF>(a, b); + case 59: + return blend<0x7FFFFFFFFFFFFFF>(a, b); + case 60: + return blend<0xFFFFFFFFFFFFFFF>(a, b); + case 61: + return blend<0x1FFFFFFFFFFFFFFF>(a, b); + case 62: + return blend<0x3FFFFFFFFFFFFFFF>(a, b); + case 63: + return blend<0x7FFFFFFFFFFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu_one_fourth(const void* ptr) { + // Fast path if only load element number of 16. + // Note: We didn't merge it as fast path of loadu(const void* ptr, T count), + // Because loadu(const void* ptr, T count) requires zero initialization for + // upper 384 bits. However, by using _mm512_castsi128_si512, the upper 384 + // bits of the result are undefined. + // TODO We can use _mm512_zextsi128_si512 in the furture, + // since gcc 9.3 doesn't support it now. + __m128i input_128 = _mm_loadu_si128(reinterpret_cast(ptr)); + return _mm512_castsi128_si512(input_128); + } + static Vectorized loadu(const void* ptr, T count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else if (count == 16) { + // Fast path if only load element number of 16 + return loadu_one_fourth(ptr); + } else { + __mmask64 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi8(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + if (count == 16) { + // Fast path if only store element number of 16 + _mm_storeu_si128( + reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values)); + } else { + __mmask64 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi8(ptr, mask, values); + } + } + } + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi8(0); + } + Vectorized conj() const { + return *this; + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epi8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + + Vectorized neg() const; + + Vectorized abs() const { + return _mm512_abs_epi8(values); + } + + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<=(const Vectorized& other) 
const { + auto mask = _mm512_cmple_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epu8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + + Vectorized neg() const; + + Vectorized abs() const { + return *this; + } + + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + 
} + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi64(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi16(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi8(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi64(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi32(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi16(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi8(a, b); +} + +// Negation. 
Defined here so we can utilize operator- +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi64(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi16(a, b); +} + +template +Vectorized inline int_elementwise_binary_512( + const Vectorized& a, + const Vectorized& b, + Op op) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] = op(values_a[i], values_b[i]); + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying int8_t +#ifndef CPU_CAPABILITY_AVX512 + return int_elementwise_binary_512(a, b, std::multiplies()); +#else + __m512i mask00FF = _mm512_set1_epi16(0x00FF); + __m512i a_lo = _mm512_srai_epi16(_mm512_slli_epi16(a, 8), 8); + __m512i b_lo = _mm512_srai_epi16(_mm512_slli_epi16(b, 8), 8); + __m512i a_hi = _mm512_srai_epi16(a, 8); + __m512i b_hi = _mm512_srai_epi16(b, 8); + __m512i res_lo = _mm512_and_si512(_mm512_mullo_epi16(a_lo, b_lo), mask00FF); + __m512i res_hi = _mm512_slli_epi16(_mm512_mullo_epi16(a_hi, b_hi), 8); + __m512i res = _mm512_or_si512(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline operator*( + const 
Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying uint8_t +#ifndef CPU_CAPABILITY_AVX512 + return int_elementwise_binary_512(a, b, std::multiplies()); +#else + __m512i mask00FF = _mm512_set1_epi16(0x00FF); + __m512i a_lo = _mm512_and_si512(a, mask00FF); + __m512i b_lo = _mm512_and_si512(b, mask00FF); + __m512i a_hi = _mm512_srli_epi16(a, 8); + __m512i b_hi = _mm512_srli_epi16(b, 8); + __m512i res_lo = _mm512_and_si512(_mm512_mullo_epi16(a_lo, b_lo), mask00FF); + __m512i res_hi = _mm512_slli_epi16(_mm512_mullo_epi16(a_hi, b_hi), 8); + __m512i res = _mm512_or_si512(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi64(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi32(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi16(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi8(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epu8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi64(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi32(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi16(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epu8(a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const 
Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi64(max_val, _mm512_max_epi64(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi32(max_val, _mm512_max_epi32(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi16(max_val, _mm512_max_epi16(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi8(max_val, _mm512_max_epi8(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epu8(max_val, _mm512_max_epu8(a, min_val)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi64(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi32(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi16(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi8(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epu8(max_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi64(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi32(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return 
_mm512_max_epi16(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi8(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epu8(min_val, a); +} + +template +std::enable_if_t< + !(std::is_same_v || std::is_same_v), + Vectorized< + int32_t>> inline convert_to_int32(const T* ptr, int count = Vectorized::size()) { + return Vectorized::loadu(ptr, count); +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const int8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm512_cvtepi8_epi32( + _mm_loadu_si128(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm512_cvtepi8_epi32(_mm512_castsi512_si128(a)); + } +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const uint8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm512_cvtepu8_epi32(_mm512_castsi512_si128(a)); + } +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, 
std::divides()); +} + +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_si512(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_si512(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_si512(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator~(const Vectorized& a) { + return _mm512_xor_si512(a, _mm512_set1_epi32(-1)); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + 
+inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + 
+inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template < + bool left_shift, + typename T, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + int> = 0> +Vectorized inline shift_512_8( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for shifting int8_t/uint8_t, so emulating + // it instead. + + // Control masks for shuffle operation, treating 512 bits as an + // array of 8-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. 
+ __m512i ctl_0_1 = _mm512_set_epi8( + 62, + 0x80, + 60, + 0x80, + 58, + 0x80, + 56, + 0x80, + 54, + 0x80, + 52, + 0x80, + 50, + 0x80, + 48, + 0x80, + 46, + 0x80, + 44, + 0x80, + 42, + 0x80, + 40, + 0x80, + 38, + 0x80, + 36, + 0x80, + 34, + 0x80, + 32, + 0x80, + 30, + 0x80, + 28, + 0x80, + 26, + 0x80, + 24, + 0x80, + 22, + 0x80, + 20, + 0x80, + 18, + 0x80, + 16, + 0x80, + 14, + 0x80, + 12, + 0x80, + 10, + 0x80, + 8, + 0x80, + 6, + 0x80, + 4, + 0x80, + 2, + 0x80, + 0, + 0x80); + __m512i ctl_1_0 = _mm512_set_epi8( + 0x80, + 63, + 0x80, + 61, + 0x80, + 59, + 0x80, + 57, + 0x80, + 55, + 0x80, + 53, + 0x80, + 51, + 0x80, + 49, + 0x80, + 47, + 0x80, + 45, + 0x80, + 43, + 0x80, + 41, + 0x80, + 39, + 0x80, + 37, + 0x80, + 35, + 0x80, + 33, + 0x80, + 31, + 0x80, + 29, + 0x80, + 27, + 0x80, + 25, + 0x80, + 23, + 0x80, + 21, + 0x80, + 19, + 0x80, + 17, + 0x80, + 15, + 0x80, + 13, + 0x80, + 11, + 0x80, + 9, + 0x80, + 7, + 0x80, + 5, + 0x80, + 3, + 0x80, + 1); + + // Masks for bitwise and operation, treating 512 bits as an array of + // 8-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m512i keep_0 = _mm512_set1_epi16(0xFF); + __m512i keep_1 = _mm512_set1_epi16(0xFF00); + + // Take each 8-bit element with idx%2==0 from input array to be + // shifted and extend it to 16 bits so that 0s are added to the + // right. Then, perform shifting on this 16-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. 
+ // + // Note that number of bits to shift for is extended to 16 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m512i a0 = _mm512_shuffle_epi8(a, ctl_0_1); + __m512i b0 = _mm512_and_si512(b, keep_0); + __m512i c0; + if (left_shift) + c0 = _mm512_sllv_epi16(a0, b0); + else if constexpr (std::is_same_v) + c0 = _mm512_srav_epi16(a0, b0); + else + c0 = _mm512_srlv_epi16(a0, b0); + c0 = _mm512_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. + __m512i a1 = _mm512_and_si512(a, keep_1); + __m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0); + __m512i c1; + if (left_shift) + c1 = _mm512_sllv_epi16(a1, b1); + else if constexpr (std::is_same_v) + c1 = _mm512_srav_epi16(a1, b1); + else + c1 = _mm512_srlv_epi16(a1, b1); + c1 = _mm512_and_si512(c1, keep_1); + + // Merge partial results into the final result. 
+ __m512i c = _mm512_or_si512(c0, c1); + + return c; +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sllv_epi16(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm512_srav_epi64(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm512_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm512_srav_epi16(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_mask.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_mask.h new file mode 100644 index 0000000000000000000000000000000000000000..1c8baea16b4873567ff5bb0c0235d3d70d6137be --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_mask.h @@ -0,0 +1,390 @@ +#pragma once + +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template 
+struct VecMaskLoad< + T, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + dst_n, + mask_t, + dst_n, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + 
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n && dst_n >= 1) && + (std::is_same_v || std::is_same_v)>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; + VectorizedN tmp_vec; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int_mask = VecMask(tmp_vec).template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + 1, + mask_t, + 1, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = 
_mm512_set1_epi32(0xFFFFFFFF); + auto int_mask = vec_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + auto zero = _mm_set1_epi8(0); + auto temp = _mm_mask_loadu_epi8(zero, mmask, ptr); + return Vectorized( + _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0)); + } +}; + +template +struct VecMaskLoad< + data_t, + 2, + mask_t, + 1, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + at::vec::Vectorized zero_vec(0); + auto int_mask = vec_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + at::vec::VectorizedN result; + if constexpr (std::is_same_v) { + result[0] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_pd(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } else { + result[0] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_epi64(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_ps(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castps_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castpd_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask 
apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; + +template <> +inline bool VecMask::all_zero() const { + __mmask16 mask = _mm512_test_epi32_mask(mask_[0], mask_[0]); + return mask == 0; +} + +template <> +inline bool VecMask::is_masked(int i) const { + return _mm512_movepi32_mask(mask_[0]) & (1 << i); +} + +template <> +inline bool VecMask::all_masked() const { + __mmask16 mask = _mm512_movepi32_mask(mask_[0]); + return mask 
== 0xffff; +} + +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = + all_zero && (_mm512_test_epi64_mask(vec_mask[i], vec_mask[i]) == 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 8) { + return _mm512_movepi64_mask(vec_mask[j]) & (1 << (i - j * 8)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + for (int i = 0; i < N; ++i) { + all_masked = all_masked && (_mm512_movepi64_mask(vec_mask[i]) == 0xff); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + +#define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ + T, N, return_type, method, args_def, args) \ + template <> \ + inline return_type VecMask::method args_def const { \ + return cast().method args; \ + } + +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ()) + +#undef VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h new file mode 100644 index 0000000000000000000000000000000000000000..09789acaae16eab62b45eb6e130902e64ca930f3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h @@ -0,0 +1,1528 @@ +#pragma once + 
+// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. 
+ +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +#ifdef _MSC_VER +__declspec(align(64)) struct Vectorizedqi { + protected: + __m512i vals; +#else +struct Vectorizedqi { + protected: + __m512i vals __attribute__((aligned(64))); +#endif + + public: + Vectorizedqi() {} + Vectorizedqi(__m512i v) : vals(v) {} + operator __m512i() const { + return vals; + } +}; + +template +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + T min_val, + T max_val); + +template <> +inline __m512i pack_saturate_and_clamp( + __m512i first [[maybe_unused]], + __m512i second [[maybe_unused]], + int32_t min_val [[maybe_unused]], + int32_t max_val [[maybe_unused]]) { + // This function is for linkage only, will not be used + TORCH_CHECK(false, "pack_saturate_and_clamp is not supported"); + return __m512i{}; +} + +template <> +inline __m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + int8_t min_val, + int8_t max_val) { + __m512i packed_and_sat = _mm512_packs_epi16(first, second); + return _mm512_max_epi8( + _mm512_set1_epi8(min_val), + _mm512_min_epi8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + +template <> +inline __m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + uint8_t min_val, + uint8_t max_val) { + __m512i packed_and_sat = _mm512_packus_epi16(first, second); + return _mm512_max_epu8( + _mm512_set1_epi8(min_val), + _mm512_min_epu8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + float>> inline convert_int8_to_float(at::vec::Vectorized src) { + // Note: this function only convert inputs number of elements equal to + // at::vec::Vectorized.size() Only handle first 16*8 bits + __m128i input_128 = _mm512_castsi512_si128(src); + // Convert from 16*uint8/int8 to 16*int32 + __m512i input_512_extended; + if constexpr (std::is_same_v) + input_512_extended = 
_mm512_cvtepu8_epi32(input_128); + else + input_512_extended = _mm512_cvtepi8_epi32(input_128); + // Convert from 16*int32 to 16*float32 + return _mm512_cvtepi32_ps(input_512_extended); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + T>> inline convert_float_to_int8(at::vec::Vectorized src) { + // Convert from float32 to int32 with truncation + __m512i x_values_int32 = _mm512_cvttps_epi32(src); + + // Convert from int32 to int16 using signed saturation + __m512i xy_packed_v = _mm512_packs_epi32(x_values_int32, x_values_int32); + + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + // Convert from int16 to uint8/int8 using unsigned saturation + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, xy_packed_v, min_val, max_val); + __m512i permute_mask_v = _mm512_set_epi32( + 0x0f, + 0x0b, + 0x07, + 0x03, + 0x0e, + 0x0a, + 0x06, + 0x02, + 0x0d, + 0x09, + 0x05, + 0x01, + 0x0c, + 0x08, + 0x04, + 0x00); + return _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); +} + +template +__FORCE_INLINE void QuantizeAvx512( + const float* src, + T* dst, + int len, + float inverse_scale, + int64_t zero_point) { + constexpr int VLEN = 16; + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + const __m512i min_v = _mm512_set1_epi32(min_val); + const __m512i max_v = _mm512_set1_epi32(max_val); + // This is the largest int32 value < int32_max exactly representable in float + constexpr int32_t int32_float_max_val = + std::numeric_limits::max() - 127; + int i = 0; + __m512 inverse_scale_v = _mm512_set1_ps(inverse_scale); + // clang-format off + static const __m512i shuffle_mask_v = _mm512_set_epi8( + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 
0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00); + // clang-format on + __m512i permute_mask_v = _mm512_set_epi32( + 0x0f, + 0x0b, + 0x07, + 0x03, + 0x0e, + 0x0a, + 0x06, + 0x02, + 0x0d, + 0x09, + 0x05, + 0x01, + 0x0c, + 0x08, + 0x04, + 0x00); + __m512i permute_mask_l8_v = _mm512_set_epi32( + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x0c, + 0x08, + 0x04, + 0x00); + int len_aligned = len / (VLEN * 4) * (VLEN * 4); + for (; i < len_aligned; i += 4 * VLEN) { + // x + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + // If the floating point value is greater than int32_max, + // _mm512_cvtps_epi32 converts them to -ve. Clip at int32_float_max_val to + // Clip at int32_float_max_val to avoid this. + x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // y + __m512 y_vals = _mm512_load_ps(src + i + VLEN); + __m512 y_transformed_v = _mm512_mul_ps(y_vals, inverse_scale_v); + y_transformed_v = + _mm512_min_ps(y_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // z + __m512 z_vals = _mm512_load_ps(src + i + 2 * VLEN); + __m512 z_transformed_v = _mm512_mul_ps(z_vals, inverse_scale_v); + z_transformed_v = + _mm512_min_ps(z_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // w + __m512 w_vals = _mm512_load_ps(src + i + 3 * VLEN); + __m512 w_transformed_v = _mm512_mul_ps(w_vals, inverse_scale_v); + w_transformed_v = + _mm512_min_ps(w_transformed_v, _mm512_set1_ps(int32_float_max_val)); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_transformed_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_transformed_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_transformed_v); + + // add zero point + 
x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + y_rounded_v = _mm512_add_epi32(y_rounded_v, _mm512_set1_epi32(zero_point)); + z_rounded_v = _mm512_add_epi32(z_rounded_v, _mm512_set1_epi32(zero_point)); + w_rounded_v = _mm512_add_epi32(w_rounded_v, _mm512_set1_epi32(zero_point)); + + __m512i xy_packed_v = _mm512_packs_epi32(x_rounded_v, y_rounded_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_rounded_v, w_rounded_v); + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i), xyzw_clamped_v); + } + + // Additional 8-lane AVX512 version to take advantage when len is smaller + // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM) + for (; i < len / VLEN * VLEN; i += VLEN) { + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v); + x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + __m512i x_clipped_v = + _mm512_max_epi32(min_v, _mm512_min_epi32(max_v, x_rounded_v)); + + x_clipped_v = _mm512_shuffle_epi8(x_clipped_v, shuffle_mask_v); + x_clipped_v = _mm512_permutexvar_epi32(permute_mask_l8_v, x_clipped_v); + _mm_storeu_si128( + reinterpret_cast<__m128i*>(dst + i), + _mm512_castsi512_si128(x_clipped_v)); + } + + for (; i < len; ++i) { + float transformed = src[i] * inverse_scale; + + // Not exactly the same behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. 
+ // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. + transformed = zero_point + std::nearbyint(transformed); + float clipped = + std::min(std::max(transformed, float(min_val)), float(max_val)); + dst[i] = clipped; + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + using size_type = int; + static constexpr size_type size() { + return 16; + } + + static constexpr int float_num_vecs() { + return 1; + } + + static constexpr int int_num_vecs() { + return 1; + } + + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint32& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi32(uw); + } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + __m512 float_vals = _mm512_cvtepi32_ps(vals); + return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m512 float_vals = _mm512_cvtepi32_ps(vals); + return {(Vectorized(float_vals) - zero_point) * scale}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + Vectorized retval; + auto rhs_data = (__m512)rhs[0]; + at::native::quantize_vec( + scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 16); + return retval; + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi32(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi32(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epi32( + _mm512_max_epi32(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {_mm512_sub_epi32(vals, b)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + + __m512 scaled = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier_v); + __m512i rounded = _mm512_cvtps_epi32(scaled); + return _mm512_add_epi32(rounded, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = 
_mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +/* + * Convert values from int32 back to int8/uint8 + */ +template +__m512i RequantizeAvx512( + const std::array, 4>& inp, + __m512 multiplier, + __m512i zp) { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + __m512i permute_mask_v = _mm512_set_epi32( + 0x0f, + 0x0b, + 0x07, + 0x03, + 0x0e, + 0x0a, + 0x06, + 0x02, + 0x0d, + 0x09, + 0x05, + 0x01, + 0x0c, + 0x08, + 0x04, + 0x00); + __m512 x_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier); + __m512 y_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[1]), multiplier); + __m512 z_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[2]), multiplier); + __m512 w_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[3]), multiplier); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_scaled_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_scaled_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_scaled_v); + + /* Add zero point */ + __m512i x_v = _mm512_add_epi32(x_rounded_v, zp); + __m512i y_v = _mm512_add_epi32(y_rounded_v, zp); + __m512i z_v = _mm512_add_epi32(z_rounded_v, zp); + __m512i w_v = _mm512_add_epi32(w_rounded_v, zp); + + /* Pack to int16_t and saturate */ + __m512i xy_packed_v = _mm512_packs_epi32(x_v, y_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_v, w_v); + + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + /* + * 
xyzw_clamped_v has results in the following layout so we need to + * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 x8-11 y8-11 z8-11 w8-11 + * x12-15 y12-15 z12-15 w12-15 + */ + xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + return xyzw_clamped_v; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::qint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + + Vectorized() {} + Vectorized(__m512i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + // This is needed because the compiler emits awful code for the default + // constructor for moving the enum + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + + // This is added to avoid error: definition of implicit copy assignment + // operator for 'Vectorized' is deprecated because it has a + // user-declared copy constructor [-Werror,-Wdeprecated-copy] + Vectorized& operator=(const Vectorized&) = default; + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. 
We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + private: + __m512i cvtepi8_epi32(__m128i epi8_vals) const { + return _mm512_cvtepi8_epi32(epi8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_neg_zp_premul) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul); + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], 
vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + int8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epi8(_mm512_max_epi8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = 
_mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512i int32_val0 = cvtepi8_epi32(int_val0); + __m512i int32_val1 = cvtepi8_epi32(int_val1); + __m512i int32_val2 = cvtepi8_epi32(int_val2); + __m512i int32_val3 = cvtepi8_epi32(int_val3); + +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]); +#else + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); +#endif + + __m512i int32_b0 = cvtepi8_epi32(int_b0); + __m512i int32_b1 = cvtepi8_epi32(int_b1); + __m512i int32_b2 = cvtepi8_epi32(int_b2); + __m512i int32_b3 = cvtepi8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return 
RequantizeAvx512(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::quint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::quint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + + // This is added to avoid error: definition of implicit copy assignment + // operator for 'Vectorized' is deprecated because it has a + // user-declared copy constructor [-Werror,-Wdeprecated-copy] + Vectorized& operator=(const Vectorized&) = default; + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. 
We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + private: + __m512i cvtepu8_epi32(__m128i epu8_vals) const { + return _mm512_cvtepu8_epi32(epu8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); + + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + 
__m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + uint8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epu8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epu8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epu8(_mm512_max_epu8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], 
vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512i int32_val0 = cvtepu8_epi32(int_val0); + __m512i int32_val1 = cvtepu8_epi32(int_val1); + __m512i int32_val2 = cvtepu8_epi32(int_val2); + __m512i int32_val3 = cvtepu8_epi32(int_val3); + +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]); +#else + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); +#endif + + __m512i int32_b0 = cvtepu8_epi32(int_b0); + __m512i int32_b1 = cvtepu8_epi32(int_b1); + __m512i int32_b2 = cvtepu8_epi32(int_b2); + __m512i int32_b3 = cvtepu8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return RequantizeAvx512(inp, multiplier_v, zero_point_v); 
+ } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#else + +// NOTE: These are low-performance implementations that we fall back on. + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + static constexpr int size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / 8; + } + + static constexpr int int_num_vecs() { + return size() / 8; + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (const auto i : c10::irange(size())) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul [[maybe_unused]]) const { + float_vec_return_type rv; + for (const auto i : c10::irange(float_num_vecs())) { + float tmp_vals[16]; + for (const auto j : c10::irange(16)) { + tmp_vals[j] = at::native::dequantize_val( + scale[j], zero_point[j], T(vals[16 * i + j])); + } + rv[i] = Vectorized( + tmp_vals[0], + tmp_vals[1], + tmp_vals[2], + tmp_vals[3], + tmp_vals[4], + tmp_vals[5], + tmp_vals[6], + tmp_vals[7], + tmp_vals[8], + tmp_vals[9], + tmp_vals[10], + tmp_vals[11], + tmp_vals[12], + tmp_vals[13], + tmp_vals[14], + tmp_vals[15]); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + Vectorized scale_zp_premul; + return 
dequantize(scale, zero_point, scale_zp_premul); + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + std::array qvals; + std::array float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (const auto i : c10::irange(size())) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = + std::nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + 
+template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + std::array qvals; + std::array float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = 
std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. 
+ for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + std::array qvals; + std::array float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = 
std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_AVX512) && !defined(MSVC) + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h new file mode 100644 index 0000000000000000000000000000000000000000..53cdbcd5ba8dfb3e095cef2397c1b466527bd9d9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h @@ -0,0 +1,1512 @@ +#pragma once +#if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && \ + defined(__ARM_FEATURE_SVE) +// Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117161 +#pragma GCC optimize("no-tree-vectorize") +#endif + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] +// +// Note [Do not compile initializers with AVX] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// If you define a static initializer in this file, the initialization will use +// AVX instructions because these object files are compiled with AVX enabled. +// We need to avoid non-trivial global data in these architecture specific files +// because there's no way to guard the global initializers with CPU capability +// detection. +// +// See https://github.com/pytorch/pytorch/issues/37577 for an instance +// of this bug in the past. 
+ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__GNUC__) +#define __FORCE_INLINE __attribute__((always_inline)) inline +#elif defined(_MSC_VER) +#define __FORCE_INLINE __forceinline +#endif + +#if defined(_MSC_FULL_VER) +/* +https://learn.microsoft.com/en-us/cpp/overview/compiler-versions?view=msvc-170 +Use _MSC_FULL_VER to identify current compiler is msvc, +Windows llvm will not have this definition. +*/ +#define __msvc_cl__ +#endif + +// These macros helped us unify vec_base.h +#ifdef CPU_CAPABILITY_AVX512 +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(64))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(64)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 64 +#define int_vector __m512i +#elif defined(__aarch64__) && \ + !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 +// SVE code expects 256-vectors; leave that set for SVE? 
+#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(16))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(16)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 16 +#else // CPU_CAPABILITY_AVX512 +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(32))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(32)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 32 +#define int_vector __m256i +#endif // CPU_CAPABILITY_AVX512 + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +// at::Half and at::BFloat16 should be treated as floating point +template +struct is_floating_point + : std::integral_constant< + bool, + std::is_floating_point_v || std::is_same_v || + std::is_same_v> {}; + +template +constexpr bool is_floating_point_v = is_floating_point::value; + +template +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> {}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; + +template +struct is_8bit_integer + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> { +}; + +template +constexpr bool is_8bit_integer_v = is_8bit_integer::value; + +template +struct int_of_size; + +#define DEFINE_INT_OF_SIZE(int_t) \ + template <> \ + struct int_of_size { \ + using type = int_t; \ + } + +DEFINE_INT_OF_SIZE(int64_t); +DEFINE_INT_OF_SIZE(int32_t); +DEFINE_INT_OF_SIZE(int16_t); +DEFINE_INT_OF_SIZE(int8_t); + +#undef DEFINE_INT_OF_SIZE + +template +using int_same_size_t = typename int_of_size::type; + +/** + * Detect at compile time whether Vectorized has an explicit + * specialization for T. (You are required to specialize this type + * whenever you specialize Vectorized). Useful for generic algorithms + * to decide whether to rely on a specialization being fast. 
For + * example, they might choose to handle reduced-precision floating + * point types directly if they're supported, or convert through float + * if not. + */ +#if defined(__s390x__) +template +#else +template +#endif +struct is_vec_specialized_for : std::bool_constant { +}; + +template +constexpr bool is_vec_specialized_for_v = is_vec_specialized_for::value; + +// NOTE: If you specialize Vectorized on a type, you must define all +// operations! You must also specialize is_vec_specialized_for for +// that type. + +// emulates Vectorized types +#if defined(__s390x__) +template +#else +template +#endif +struct Vectorized { + private: + __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; + + public: + using value_type = T; + using size_type = int; + + static constexpr size_type kSize = VECTOR_WIDTH / sizeof(T); + static constexpr size_type size() { + return kSize; + } + Vectorized() : values{static_cast(0)} {} + Vectorized(T val) { + for (int i = 0; i != size(); i++) { + values[i] = val; + } + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... 
vals) : values{vals...} {} + Vectorized(const T (&arr)[kSize]) { + std::memcpy(values, arr, sizeof(values)); + } + // This also implies const T& operator[](int idx) const + inline operator const T*() const { + return values; + } + // This also implies T& operator[](int idx) + inline operator T*() { + return values; + } + // Return the values as char* for type punning + auto as_bytes() const -> const char* { + return reinterpret_cast(values); + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + int64_t mask = mask_; + Vectorized vector; + for (const auto i : c10::irange(size())) { + if (mask & 0x01) { + vector[i] = b[i]; + } else { + vector[i] = a[i]; + } + mask = mask >> 1; + } + return vector; + } +// Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 +#if __GNUC__ <= 12 && !defined(__clang__) && defined(__ARM_FEATURE_SVE) + static Vectorized __attribute__((optimize("-fno-tree-loop-vectorize"))) + blendv( + const Vectorized& a, +#else + static Vectorized blendv( + const Vectorized& a, +#endif + const Vectorized& b, + const Vectorized& mask) { + Vectorized vector; + int_same_size_t buffer[size()]; + mask.store(buffer); +#if defined(__clang__) && __ARM_FEATURE_SVE +#pragma clang loop vectorize(disable) +#endif + for (const auto i : c10::irange(size())) { + if (buffer[i] & 0x01) { + vector[i] = b[i]; + } else { + vector[i] = a[i]; + } + } + return vector; + } + template // step sometimes requires a higher precision type + // (e.g., T=int, step_t=double) + static Vectorized arange( + T base = static_cast(0), + step_t step = static_cast(1)) { + Vectorized vector; + for (const auto i : c10::irange(size())) { + vector.values[i] = base + i * step; + } + return vector; + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + Vectorized vector; + for (const auto i : c10::irange(size())) { + if (i < count) { + vector[i] = b[i]; + } else { + vector[i] = a[i]; + } + } + return 
vector; + } + static Vectorized loadu(const void* ptr) { + Vectorized vector; + std::memcpy(vector.values, ptr, VECTOR_WIDTH); + return vector; + } + static Vectorized loadu(const void* ptr, int64_t count) { + Vectorized vector; + std::memcpy(vector.values, ptr, count * sizeof(T)); + return vector; + } + static Vectorized loadu_one_fourth(const void* ptr) { + static_assert( + std::is_same_v || std::is_same_v, + "For byte types only"); + return Vectorized::loadu(ptr, 8); + } + + void store(void* ptr, int count = size()) const { + std::memcpy(ptr, values, count * sizeof(T)); + } + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (values[i] == static_cast(0)) { + mask |= (1 << i); + } + } + return mask; + } + Vectorized isnan() const { + Vectorized vector; + for (int64_t i = 0; i != size(); i++) { + if (_isnan(values[i])) { + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); + } else { + std::memset(static_cast(vector.values + i), 0, sizeof(T)); + } + } + return vector; + } + bool has_inf_nan() const { + for (int64_t i = 0; i != size(); i++) { + if (_isnan(values[i]) || _isinf(values[i])) { + return true; + } + } + return false; + } +// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows +// Arm64 +// See +// https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 +#if defined(_WIN32) && defined(__aarch64__) && \ + ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) + Vectorized map(T (*const f)(T)) const { + Vectorized ret; + for (int64_t i = 0; i < size(); i++) { + ret[i] = f(values[i]); + if (++i < size()) + ret[i] = f(values[i]); + } + return ret; + } + T reduce(T (*const f)(T)) const { + T ret = 0; + for (int64_t i = 0; i < size(); i++) { + ret = f(ret, values[i]); + if (++i < size()) + ret = f(ret, values[i]); + } + return ret; + } +#else + Vectorized 
map(T (*const f)(T)) const { + Vectorized ret; + for (int64_t i = 0; i != size(); i++) { + ret[i] = f(values[i]); + } + return ret; + } + T reduce(T (*const f)(T)) const { + T ret = 0; + for (int64_t i = 0; i != size(); i++) { + ret = f(ret, values[i]); + } + return ret; + } +#endif + Vectorized map(T (*const f)(const T&)) const { + Vectorized ret; + for (int64_t i = 0; i != size(); i++) { + ret[i] = f(values[i]); + } + return ret; + } + T reduce(T (*const f)(const T&)) const { + T ret = 0; + for (int64_t i = 0; i != size(); i++) { + ret = f(ret, values[i]); + } + return ret; + } + template < + typename other_t_abs = T, + typename std::enable_if_t< + !is_floating_point_v && + !c10::is_complex::value, + int> = 0> + Vectorized abs() const { + // other_t_abs is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_abs must be T"); + return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); + } + template < + typename float_t_abs = T, + typename std::enable_if_t, int> = 0> + Vectorized abs() const { + // float_t_abs is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "float_t_abs must be T"); + // Specifically deal with floating-point because the generic code above + // won't handle -0.0 (which should result in 0.0) properly. + return map([](T x) -> T { return std::abs(x); }); + } + template < + typename complex_t_abs = T, + typename std::enable_if_t::value, int> = 0> + Vectorized abs() const { + // complex_t_abs is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "complex_t_abs must be T"); + // Specifically map() does not perform the type conversion needed by abs. 
+ return map([](T x) { return static_cast(std::abs(x)); }); + } + + template < + typename other_t_sgn = T, + typename std::enable_if_t::value, int> = 0> + Vectorized sgn() const { + return map(at::native::sgn_impl); + } + + template < + typename other_t_angle = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized angle() const { + // other_t_angle is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_angle must be T"); + return map(at::native::angle_impl); // compiler is unable to resolve the + // overload without + } + template < + typename complex_t_angle = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized angle() const { + // complex_t_angle is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_angle must be T"); + return map([](T x) { return static_cast(std::arg(x)); }); + } + template < + typename other_t_real = T, + typename std::enable_if_t::value, int> = 0> + Vectorized real() const { + // other_t_real is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_real must be T"); + return *this; + } + template < + typename complex_t_real = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized real() const { + // complex_t_real is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_real must be T"); + return map([](T x) { return static_cast(x.real()); }); + } + template < + typename other_t_imag = T, + typename std::enable_if_t::value, int> = 0> + Vectorized imag() const { + // other_t_imag is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_imag must be T"); + return Vectorized(0); + } + template < + typename complex_t_imag = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized imag() const { + // complex_t_imag is for SFINAE and clarity. Make sure it is not changed. 
+ static_assert( + std::is_same_v, "complex_t_imag must be T"); + return map([](T x) { return static_cast(x.imag()); }); + } + template < + typename other_t_conj = T, + typename std::enable_if_t::value, int> = 0> + Vectorized conj() const { + // other_t_conj is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_conj must be T"); + return *this; + } + template < + typename complex_t_conj = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized conj() const { + // complex_t_conj is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_conj must be T"); + return map([](T x) { return static_cast(std::conj(x)); }); + } + Vectorized acos() const { + return map(std::acos); + } + Vectorized acosh() const { + return map(std::acosh); + } + Vectorized asin() const { + return map(std::asin); + } + Vectorized asinh() const { + return map(std::asinh); + } + Vectorized atan() const { + return map(std::atan); + } + Vectorized atanh() const { + return map(std::atanh); + } + Vectorized atan2(const Vectorized& exp) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::atan2(values[i], exp[i]); + } + return ret; + } + template < + typename U = T, + typename std::enable_if_t, int> = 0> + Vectorized copysign(const Vectorized& sign) const { + Vectorized ret; + for (size_type i = 0; i < size(); i++) { + ret[i] = c10::copysign(values[i], sign[i]); + } + return ret; + } + Vectorized erf() const { + return map(std::erf); + } + Vectorized erfc() const { + return map(std::erfc); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return map(std::exp); + } + Vectorized exp2() const { + return map(exp2_impl); + } + Vectorized expm1() const { + return map(std::expm1); + } + Vectorized exp_u20() const { + return map(std::exp); + } + Vectorized frac() const { + return *this - this->trunc(); + } + template < + typename U = T, + 
typename std::enable_if_t, int> = 0> + Vectorized fmod(const Vectorized& q) const { + // U is for SFINAE purposes only. Make sure it is not changed. + static_assert(std::is_same_v, "U must be T"); + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::fmod(values[i], q[i]); + } + return ret; + } + Vectorized log() const { + return map(std::log); + } + Vectorized log10() const { + return map(std::log10); + } + Vectorized log1p() const { + return map(std::log1p); + } + template < + typename other_t_log2 = T, + typename std::enable_if_t::value, int> = 0> + Vectorized log2() const { + // other_t_log2 is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_log2 must be T"); + return map(std::log2); + } + template < + typename complex_t_log2 = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized log2() const { + // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_log2 must be T"); + const T log_2 = T(std::log(2.0)); + return Vectorized(map(std::log)) / Vectorized(log_2); + } + Vectorized ceil() const { + return map(at::native::ceil_impl); + } + Vectorized cos() const { + return map(std::cos); + } + Vectorized cosh() const { + return map(std::cosh); + } + Vectorized floor() const { + return map(at::native::floor_impl); + } + Vectorized hypot(const Vectorized& b) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::hypot(values[i], b[i]); + } + return ret; + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = calc_igamma(values[i], x[i]); + } + return ret; + } + Vectorized igammac(const Vectorized& x) const { + Vectorized ret; + for (const auto i : 
c10::irange(size())) { + ret[i] = calc_igammac(values[i], x[i]); + } + return ret; + } + Vectorized neg() const { + // NB: the trailing return type is needed because we need to coerce the + // return value back to T in the case of unary operator- incuring a + // promotion + return map([](T x) -> T { return -x; }); + } + Vectorized nextafter(const Vectorized& b) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::nextafter(values[i], b[i]); + } + return ret; + } + Vectorized round() const { + // We do not use std::round because we would like to round midway numbers to + // the nearest even integer. + return map(at::native::round_impl); + } + Vectorized sin() const { + return map(std::sin); + } + Vectorized sinh() const { + return map(std::sinh); + } + Vectorized tan() const { + return map(std::tan); + } + Vectorized tanh() const { + return map(std::tanh); + } + Vectorized trunc() const { + return map(at::native::trunc_impl); + } + Vectorized lgamma() const { + return map(std::lgamma); + } + Vectorized sqrt() const { + return map(std::sqrt); + } + Vectorized reciprocal() const { + return map([](T x) { return (T)(1) / x; }); + } + Vectorized rsqrt() const { + return map([](T x) { return (T)1 / std::sqrt(x); }); + } + Vectorized pow(const Vectorized& exp) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::pow(values[i], exp[i]); + } + return ret; + } + T reduce_add() const { + return reduce([](T x, T y) -> T { return x + y; }); + } + T reduce_max() const { + return reduce(std::max); + } + + private: + template + inline Vectorized binary_pred(const Vectorized& other, Op op) const { + // All bits are set to 1 if the pred is true, otherwise 0. 
+ Vectorized vector; + for (int64_t i = 0; i != size(); i++) { + if (op(values[i], other.values[i])) { + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); + } else { + std::memset(static_cast(vector.values + i), 0, sizeof(T)); + } + } + return vector; + } + + public: + Vectorized operator==(const Vectorized& other) const { + return binary_pred(other, std::equal_to()); + } + Vectorized operator!=(const Vectorized& other) const { + return binary_pred(other, std::not_equal_to()); + } + Vectorized operator>=(const Vectorized& other) const { + return binary_pred(other, std::greater_equal()); + } + Vectorized operator<=(const Vectorized& other) const { + return binary_pred(other, std::less_equal()); + } + Vectorized operator>(const Vectorized& other) const { + return binary_pred(other, std::greater()); + } + Vectorized operator<(const Vectorized& other) const { + return binary_pred(other, std::less()); + } + + private: + template + inline Vectorized binary_pred_bool(const Vectorized& other, Op op) + const { + // 1 if the pred is true, otherwise 0. 
+ Vectorized vector; + for (int i = 0; i != size(); ++i) { + vector[i] = static_cast(op(values[i], other.values[i])); + } + return vector; + } + + public: + Vectorized eq(const Vectorized& other) const { + return binary_pred_bool(other, std::equal_to()); + } + Vectorized ne(const Vectorized& other) const { + return binary_pred_bool(other, std::not_equal_to()); + } + Vectorized gt(const Vectorized& other) const { + return binary_pred_bool(other, std::greater()); + } + Vectorized ge(const Vectorized& other) const { + return binary_pred_bool(other, std::greater_equal()); + } + Vectorized lt(const Vectorized& other) const { + return binary_pred_bool(other, std::less()); + } + Vectorized le(const Vectorized& other) const { + return binary_pred_bool(other, std::less_equal()); + } +}; + +template +Vectorized inline operator-(const Vectorized& a) { + return a.neg(); +} + +// There is an implicit conversion that would make this work if +// these operators weren't template functions, but they are template +// functions (and can't be moved to be non-member friends defined in +// the class body as suggested in +// https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255 +// because we have a lot of disparate specializations of +// Vectorized). So, just explicitly make scalars work. 
+#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \ + template \ + Vectorized inline name(const Vectorized& a, T b) { \ + return name(a, Vectorized(b)); \ + } \ + template \ + Vectorized inline name(T a, const Vectorized& b) { \ + return name(Vectorized(a), b); \ + } +#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \ + VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op) + +template +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] + b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(+) + +template +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] - b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(-) + +template +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] * b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(*) + +template +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] / b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(/) + +template , int> = 0> +Vectorized inline operator%(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { + return a - a / b * b; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(%) + +template +Vectorized inline operator||( + const Vectorized& a, + const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] || b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(||) + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. 
+template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (a[i] > b[i]) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(maximum) + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (a[i] < b[i]) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. 
+ c[i] = a[i]; + } + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(minimum) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_vec, + const Vectorized& max_vec) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); + } + return c; +} + +#define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \ + template \ + Vectorized inline name( \ + const Vectorized& a, const Vectorized& b, T c) { \ + return name(a, b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + const Vectorized& a, T b, const Vectorized& c) { \ + return name(a, Vectorized(b), c); \ + } \ + \ + template \ + Vectorized inline name(const Vectorized& a, T b, T c) { \ + return name(a, Vectorized(b), Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + T a, const Vectorized& b, const Vectorized& c) { \ + return name(Vectorized(a), b, c); \ + } \ + \ + template \ + Vectorized inline name(T a, const Vectorized& b, T c) { \ + return name(Vectorized(a), b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name(T a, T b, const Vectorized& c) { \ + return name(Vectorized(a), Vectorized(b), c); \ + } + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(clamp) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_vec) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_max) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_vec) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] < min_vec[i] ? 
min_vec[i] : a[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_min) + +struct Vectorizedi; + +#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) +template +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { + int_vector buffer; +#if defined(CPU_CAPABILITY_AVX2) + int_vector a_buffer = + _mm256_load_si256(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm256_load_si256(reinterpret_cast((const T*)b)); +#elif defined(CPU_CAPABILITY_AVX512) + int_vector a_buffer = + _mm512_load_si512(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm512_load_si512(reinterpret_cast((const T*)b)); +#endif + buffer = op(a_buffer, b_buffer); + __at_align__ T results[Vectorized::size()]; + +#if defined(CPU_CAPABILITY_AVX2) + _mm256_store_si256(reinterpret_cast(results), buffer); +#elif defined(CPU_CAPABILITY_AVX512) + _mm512_store_si512(reinterpret_cast(results), buffer); +#endif + return Vectorized::loadu(results); +} + +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is + // always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); +#endif +} +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is + // always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); +#elif 
defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); +#endif +} +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is + // always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); +#endif +} + +#else + +template +auto load(char const* data) -> T { + T ret; + std::memcpy(&ret, data, sizeof(ret)); + return ret; +} + +template +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { + static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); + __at_align__ intmax_t buffer[element_no]; + static_assert( + VECTOR_WIDTH % sizeof(intmax_t) == 0, + "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); + static_assert( + sizeof(buffer) == sizeof(Vectorized), + "sizeof(buffer) must match sizeof(Vectorized)"); + // We should be using memcpy in order to respect the strict aliasing rule + // see: https://github.com/pytorch/pytorch/issues/66119 + // Using char* is defined in the C11 standard 6.5 Expression paragraph 7 + // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf) + const auto* a_data = a.as_bytes(); + const auto* b_data = b.as_bytes(); + // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t) + for (auto& out : buffer) { + out = op(load(a_data), load(b_data)); + a_data += sizeof(intmax_t); + b_data += sizeof(intmax_t); + } + assert(a_data == a.as_bytes() + sizeof(a)); + assert(b_data == b.as_bytes() + sizeof(b)); + return Vectorized::loadu(buffer); +} + +template < + class T, + 
typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return bitwise_binary_op(a, b, std::bit_and()); +} +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return bitwise_binary_op(a, b, std::bit_or()); +} +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return bitwise_binary_op(a, b, std::bit_xor()); +} + +#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^) + +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator~(const Vectorized& a) { + using int_t = int_same_size_t; + Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 + return a ^ ones; +} + +template +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + constexpr T max_shift = sizeof(T) * CHAR_BIT; + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + T shift = b[i]; + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { + c[i] = 0; + } else { + c[i] = static_cast>(a[i]) << shift; + } + } + return c; +} + +template +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + // right shift value to retain sign bit for signed and no bits for unsigned + constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v; + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + T shift = b[i]; + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { + c[i] = a[i] >> max_shift; + } else { + c[i] = a[i] >> shift; + } + } + return c; +} + +template +inline Vectorized& operator+=(Vectorized& a, const Vectorized& b) { + a = a + b; + return a; +} +template +inline Vectorized& 
operator-=(Vectorized& a, const Vectorized& b) { + a = a - b; + return a; +} +template +inline Vectorized& operator/=(Vectorized& a, const Vectorized& b) { + a = a / b; + return a; +} +template +inline Vectorized& operator%=(Vectorized& a, const Vectorized& b) { + a = a % b; + return a; +} +template +inline Vectorized& operator*=(Vectorized& a, const Vectorized& b) { + a = a * b; + return a; +} + +template +inline Vectorized& operator<<=(Vectorized& a, const Vectorized& b) { + a = a << b; + return a; +} + +template +inline Vectorized& operator>>=(Vectorized& a, const Vectorized& b) { + a = a >> b; + return a; +} + +template +inline Vectorized fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return a * b + c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd) + +template +inline Vectorized fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return a * b - c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub) + +template +Vectorized inline operator&&( + const Vectorized& a, + const Vectorized& b) { + Vectorized ret; + for (int i = 0; i != Vectorized::size(); i++) { + ret[i] = a[i] && b[i]; + } + return ret; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&&) + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + T>> inline gather(T const* base_addr, const Vectorized>& vindex) { + static constexpr int size = Vectorized::size(); + int_same_size_t index_arr[size]; + vindex.store(static_cast(index_arr)); + T buffer[size]; + for (const auto i : c10::irange(size)) { + buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; + } + return Vectorized::loadu(static_cast(buffer)); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + T const* base_addr, + const Vectorized>& vindex, + Vectorized& mask) { + static constexpr int size = Vectorized::size(); + T src_arr[size]; + int_same_size_t mask_arr[size]; // use int type so we can 
logical and + int_same_size_t index_arr[size]; + src.store(static_cast(src_arr)); + mask.store(static_cast(mask_arr)); + vindex.store(static_cast(index_arr)); + T buffer[size]; + for (const auto i : c10::irange(size)) { + if (mask_arr[i] & 0x01) { // check highest bit + buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; + } else { + buffer[i] = src_arr[i]; + } + } + mask = Vectorized(static_cast(0)); // "zero out" mask + return Vectorized::loadu(static_cast(buffer)); +} + +// Cast a given vector to another type without changing the bits representation. +// So a Vectorized of 512 bits containing all ones can be cast to a +// Vectorized of 512 bits containing all ones (i.e., eight negative +// 1s). A Vec of 256 bits containing all ones can be cast to a +// Vec of 256 bits containing all ones (i.e., four negative 1s). +// There is a struct here because we don't have static_if and I can't +// partially specialize a templated function. +template +struct CastImpl { + static inline Vectorized apply(const Vectorized& src) { + src_t src_arr[Vectorized::size()]; + src.store(static_cast(src_arr)); + return Vectorized::loadu(static_cast(src_arr)); + } +}; + +template +struct CastImpl { + static inline Vectorized apply(const Vectorized& src) { + return src; + } +}; + +template +inline Vectorized cast(const Vectorized& src) { + return CastImpl::apply(src); +} + +template > +inline Vectorized convert_to_int_of_same_size( + const Vectorized& src) { + static_assert(sizeof(T) == sizeof(IntType)); + static constexpr int size = Vectorized::size(); + + std::array src_arr = {}; + src.store(static_cast(src_arr.data())); + std::array buffer; + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const T& x) { + return static_cast(x); + }); + return Vectorized::loadu(static_cast(buffer.data())); +} + +template > +inline Vectorized convert_to_fp_of_same_size( + const Vectorized& src) { + static_assert(sizeof(T) == sizeof(IntType)); + static constexpr int size = 
Vectorized::size(); + + std::array src_arr; + src.store(static_cast(src_arr.data())); + std::array buffer; + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const IntType& x) { + return static_cast(x); + }); + return Vectorized::loadu(static_cast(buffer.data())); +} + +// clang-format off +// Example inputs for AVX512: +// a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// b Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// returns: +// Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} +// Example inputs for AVX2: a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} +// b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// clang-format on +template +inline std::enable_if_t< + Vectorized::size() % 2 == 0, + std::pair, Vectorized>> +deinterleave2(const Vectorized& a, const Vectorized& b) { + static constexpr int size = Vectorized::size(); + static constexpr int half_size = size / 2; + T a_arr[size]; + T b_arr[size]; + T buffer1[size]; + T buffer2[size]; + a.store(static_cast(a_arr)); + b.store(static_cast(b_arr)); + for (const auto i : c10::irange(half_size)) { + buffer1[i] = a_arr[i * 2]; + buffer1[half_size + i] = b_arr[i * 2]; + buffer2[i] = a_arr[i * 2 + 1]; + buffer2[half_size + i] = b_arr[i * 2 + 1]; + } + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2) + +// clang-format off +// inverse operation of deinterleave2 +// Example inputs for AVX512: +// a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, 
b15} +// returns, for AVX512: +// Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// Example inputs for AVX2 : a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// returns: Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} +// Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// clang-format on +template +inline std::enable_if_t< + Vectorized::size() % 2 == 0, + std::pair, Vectorized>> +interleave2(const Vectorized& a, const Vectorized& b) { + static constexpr int size = Vectorized::size(); + static constexpr int half_size = size / 2; + T a_arr[size]; + T b_arr[size]; + T buffer1[size]; + T buffer2[size]; + a.store(static_cast(a_arr)); + b.store(static_cast(b_arr)); + for (const auto i : c10::irange(half_size)) { + buffer1[i * 2] = a_arr[i]; + buffer1[i * 2 + 1] = b_arr[i]; + buffer2[i * 2] = a_arr[half_size + i]; + buffer2[i * 2 + 1] = b_arr[half_size + i]; + } + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(interleave2) + +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP +#undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC + +template +inline void convert(const src_T* src, dst_T* dst, int64_t n) { +#ifndef _MSC_VER +#pragma unroll +#endif + for ([[maybe_unused]] const auto i : c10::irange(n)) { + *dst = c10::convert(c10::load(src)); + src++; + dst++; + } +} + +template +inline Vectorized flip(const Vectorized& data) { + static constexpr int size = Vectorized::size(); + T output[size]; + T buffer[size]; + data.store(static_cast(buffer)); + for (const auto i : c10::irange(size)) { + output[i] = buffer[size - i - 1]; + } + return Vectorized::loadu(static_cast(output)); +} + +// Transpose the `src` buffer of type `T` and size (M,N) into the 
`dst` buffer. +// `ld_src` is the leading dimension of `src` and `ld_dst` is the leading +// dimension of `dst`. +template +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst, + int M, + int N) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + dst[j * ld_dst + i] = src[i * ld_src + j]; + } + } +} + +template +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +// additional headers for more operations that depend on vec_base +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..c91858641cf2f7c2ffc154453499910d2c9f30d1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + constexpr int count = std::min( + VectorizedN::size(), VectorizedN::size()); + __at_align__ src_t src_buf[VectorizedN::size()]; + src.store(src_buf); + __at_align__ dst_t dst_buf[VectorizedN::size()]; + for (int i = 0; i < count; i++) { + dst_buf[i] = static_cast(src_buf[i]); + } + return VectorizedN::loadu(dst_buf, count); + } +}; + +template +inline std::enable_if_t, Vectorized> convert( + const Vectorized& src) { + return src; +} + +template +inline std::enable_if_t, Vectorized> +convert(const Vectorized& src) { + return VecConvert::apply(src); +} + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + std::enable_if_t = 0> 
+inline VectorizedN convert(const VectorizedN& src) { + return VecConvert::apply(src); +} + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + bool keep = false, + std::enable_if_t = 0> +inline std::conditional_t, Vectorized> +convert(const VectorizedN& src) { + return VecConvert::apply(src); +} + +} // namespace CPU_CAPABILITY + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline std::tuple, Vectorized> convert_to_float( + const Vectorized&); + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline Vectorized convert_from_float( + const Vectorized&, + const Vectorized&); + +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h new file mode 100644 index 0000000000000000000000000000000000000000..972d3ee3929b9e30e3fda2e1ed706403277595c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h @@ -0,0 +1,156 @@ +#pragma once + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) +static inline uint16_t float2half_scalar(float val) { +#if defined(CPU_CAPABILITY_AVX2) +#if defined(_MSC_VER) + __m256 v = _mm256_set1_ps(val); + __m128i o = + _mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return static_cast(_mm_cvtsi128_si32(o)); +#else + return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); +#endif +#elif defined(CPU_CAPABILITY_AVX512) + __m512 v = _mm512_set1_ps(val); + __m256i o = + _mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return static_cast( + _mm_cvtsi128_si32(_mm256_castsi256_si128(o))); +#endif +} + +static inline float half2float_scalar(uint16_t val) { +#if defined(CPU_CAPABILITY_AVX2) +#if 
defined(_MSC_VER) + __m128i v = _mm_cvtsi32_si128(val); + __m256 o = _mm256_cvtph_ps(v); + return _mm256_cvtss_f32(o); +#else + return _cvtsh_ss(val); +#endif +#elif defined(CPU_CAPABILITY_AVX512) + __m256i v = + _mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512 o = _mm512_cvtph_ps(v); + return _mm512_cvtss_f32(o); +#endif +} + +#endif + +// Transpose a [2, 32] matrix to [32, 2] +// Note: the output leading dimension should be 2, +// that is, the output must be contiguous +template > +static inline void transpose_pad_2x32_block( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int krem = 2, + int nrem = 32) { +#if defined(CPU_CAPABILITY_AVX512) + __m512i r0, r1; + __m512i d0, d1; + // load + if (nrem < 32) { + __mmask32 mask_krem_v = (1LL << nrem) - 1; + r0 = _mm512_maskz_loadu_epi16(mask_krem_v, src); + // if krem is not 2, pad with zeros + if (krem == 2) { + r1 = _mm512_maskz_loadu_epi16(mask_krem_v, src + ld_src); + } else { + r1 = _mm512_setzero_si512(); + } + } else { + r0 = _mm512_loadu_si512(reinterpret_cast(src)); + if (krem == 2) { + r1 = _mm512_loadu_si512(reinterpret_cast(src + ld_src)); + } else { + r1 = _mm512_setzero_si512(); + } + } + // transpose + d0 = _mm512_unpacklo_epi16(r0, r1); + d1 = _mm512_unpackhi_epi16(r0, r1); + r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); + r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); + d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); + d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); + + // store + if (nrem < 16) { + __mmask32 mask_rem_v = (1LL << (nrem * 2)) - 1; + _mm512_mask_storeu_epi16(dst, mask_rem_v, d0); + } else if (nrem == 16) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); + } else if (nrem < 32) { + __mmask32 mask_rem_v = (1LL << (nrem * 2 - 32)) - 1; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); + _mm512_mask_storeu_epi16( + reinterpret_cast<__m512i*>(dst + 32), mask_rem_v, d1); + } else { + // normal store + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), 
d0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1); + } +#else + TORCH_CHECK( + false, + "transpose_pad_2x32_block is only supported when avx512 is supported") +#endif +} + +// To use AMX to accelerate GEMM, +// reorder the memory format [K, N] -> [K/2, N, 2] +// Note: If K % 2 != 0, pad K implicitly +template > +static inline void pack_vnni2( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int64_t K, + int64_t N) { +#if defined(CPU_CAPABILITY_AVX512) + int64_t bk = 0; + int64_t _K = K / 2 * 2; + int64_t _N = N / 32 * 32; + for (; bk < _K; bk += 2) { + int64_t bn = 0; + for (; bn < _N; bn += 32) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem); + } + } + if (K % 2 == 1) { + int64_t bn = 0; + for (; bn < _N; bn += 32) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem); + } + } +#else + TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported") +#endif +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h new file mode 100644 index 0000000000000000000000000000000000000000..e19d7f75388af07f43c6d47a7fdff5161d1b33a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h @@ -0,0 +1,300 @@ +#pragma once + +#include +#include +namespace at::vec { +inline namespace CPU_CAPABILITY { + +/** + * The `VecMask` class provides a convenient interface for working with + * vectorized masks in SIMD operations. 
It encapsulates a `Vectorized` + * mask that can be directly usable in masked vectorized operations. It provides + * various methods for manipulating and accessing the mask elements: + * 1. `from` and `to`: Conversion between a vector of boolean values and a + * vectorized mask. + * 2. `cast`: Casts the mask to a different base type. + * 3. `all_zero`: Checks if all mask elements are zero. + * 4. `is_masked`: Checks if a specific element is masked. + * 5. `loadu`: Loads data from memory using the mask. + * 6. `all_masked`: Checks if all mask elements are masked. + * + * Some helper template classes are provided to simplify the specialization of + * the `VecMask` for the specific CPU arch: + * 1. `VecMaskLoad`: Loads data from memory using the mask. + * 2. `VecMaskTo`: Converts the mask to boolean. + * 3. `VecMaskCast`: Casts the mask to a different base type. + * + */ +template +class VecMask; + +template < + typename data_t, + int data_n, + typename mask_t, + int mask_n, + typename Enabled = void> +struct VecMaskLoad { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + constexpr typename VecMask::size_type size = + VecMask::size(); + static_assert(VectorizedN::size() >= size); + __at_align__ data_t data[size]; + __at_align__ mask_t mask[size]; + auto mask_ = VectorizedN(vec_mask); + mask_.store(mask); + for (int i = 0; i < size; i++) { + data[i] = mask[i] ? 
ptr[i] : static_cast(0); + } + return VectorizedN::loadu(data, size); + } +}; + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> +struct VecMaskTo { + static inline VecMask apply( + const VecMask& vec_mask) { + auto zeros = VectorizedN(static_cast(0)); + auto ones = VectorizedN(static_cast(1)); + return VectorizedN::blendv( + zeros, ones, vec_mask.template cast()); + } +}; + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> +struct VecMaskCast { + static inline VecMask apply( + const VecMask& vec_mask) { + return VecMask::from(VectorizedN(vec_mask)); + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + return vec_mask; + } +}; + +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m == static_cast(0); + }); + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m != static_cast(0); + }); + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return mask[i] != static_cast(0); + } +}; + +template +class VecMask { + public: + using size_type = int; + static constexpr size_type size() { + return VectorizedN::size(); + } + + private: + VectorizedN mask_; + + public: + VecMask() : mask_(static_cast(0)) {} + VecMask(const VectorizedN& mask) : mask_(mask) {} + + template = 0> + VecMask(const Vectorized& mask) : mask_(mask) {} + + template + static VecMask from(const VectorizedN& b_vec) { + __at_align__ U b_buf[size()]; + if constexpr (size() >= VectorizedN::size()) { + b_vec.store(b_buf); + for (int i = 
VectorizedN::size(); i < size(); i++) { + b_buf[i] = static_cast(0); + } + } else { + b_vec.store(b_buf, size()); + } + return from(b_buf); + } + + template + static VecMask from(U b) { + using int_t = int_same_size_t; + T mask = b ? c10::bit_cast((int_t)(~(int_t)0)) : (T)0; + return VectorizedN(mask); + } + + template + static VecMask from(U* b) { + using int_t = int_same_size_t; + __at_align__ T mask[size()]; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (int i = 0; i < size(); i++) { + *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0; + } + return VectorizedN(VectorizedN::loadu(mask)); + } + + static VecMask blendv( + const VecMask& c, + const VecMask& b, + const VecMask& a) { + VectorizedN result = VectorizedN::blendv( + VectorizedN(c), VectorizedN(b), VectorizedN(a)); + return result; + } + + static VecMask set( + const VecMask& a, + const VecMask& b, + int64_t count = size()) { + VectorizedN result = VectorizedN::set( + VectorizedN(a), VectorizedN(b), count); + return result; + } + + void store(bool* b, int count = size()) { + constexpr int L = + (VectorizedN::size() + Vectorized::size() - 1) / + Vectorized::size(); + auto res = this->to(); + res.store(b, count); + return; + } + + template = 2, int> = 0> + inline VectorizedN to() const { + return VecMaskTo::apply(*this); + } + + template = 0> + inline Vectorized to() const { + return VecMaskTo::apply(*this); + } + + template + inline VecMask cast() const { + return VecMaskCast::apply(*this); + } + + inline bool all_zero() const { + return VecMaskCheck::all_zero(mask_); + } + + inline bool all_masked() const { + return VecMaskCheck::all_masked(mask_); + } + + inline bool is_masked(int i) const { + return VecMaskCheck::is_masked(mask_, i); + } + + inline operator VectorizedN() const { + return mask_; + } + + template = 0> + inline operator Vectorized() const { + return mask_[0]; + } + + inline Vectorized operator[](int i) const { + return mask_[i]; + } + + template < + typename U, + int L, + 
std::enable_if_t= 2 && VectorizedN::size() >= size(), int> = 0> + VectorizedN loadu(const U* ptr) const { + return VecMaskLoad::apply(ptr, *this); + } + + template < + typename U, + int L, + std::enable_if_t::size() >= size(), int> = 0> + Vectorized loadu(const U* ptr) const { + return VecMaskLoad::apply(ptr, *this); + } +}; + +#define VEC_MASK_DEFINE_UNARY_OP_GLOBAL(op) \ + template \ + inline VecMask op(const VecMask& a) { \ + return op(VectorizedN(a)); \ + } + +#define VEC_MASK_DEFINE_BINARY_OP_GLOBAL(op) \ + template < \ + typename T, \ + int N, \ + typename V, \ + int M, \ + std::enable_if_t::size() == VecMask::size(), int> = \ + 0> \ + inline VecMask op(const VecMask& a, const VecMask& b) { \ + return op( \ + VectorizedN(a), VectorizedN(b.template cast())); \ + } + +#define VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(op, EXPR) \ + template < \ + typename T, \ + int N, \ + typename V, \ + int M, \ + std::enable_if_t::size() == VecMask::size(), int> = \ + 0> \ + inline VecMask op(const VecMask& a, const VecMask& b) { \ + return EXPR; \ + } + +VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator*) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b)) + +#undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL +#undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL +#undef VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL + +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h new file mode 100644 index 0000000000000000000000000000000000000000..9725bf3eedb09270c04e71dd5bb8388fe57a20e0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h @@ -0,0 +1,405 @@ +#pragma once + +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +/** + * @brief A class template representing a vectorized type with + * `N * Vectorized::size()` elements, aiming to support vectors of + * arbitrary size. A specific use case of it is to represent vectors + * converted from data types with different sizes but with the same + * number of vector elements, e.g., `VectorizedN` can be + * a vector converted from two `Vectorized`, `VectorizedN` + * can be a vector converted from two `Vectorized` etc. + * + * It supports most of the operations of `Vectorized` + * and the implementation delegates to `Vectorized` with loops over `N`. + * + * @tparam T The underlying type of the vectorized elements. + * @tparam N The number of underlying `Vectorized`. 
+ */ +template +class VectorizedN { + public: + using value_type = T; + using size_type = int; + + static constexpr size_type size_T = sizeof(T); + static constexpr size_type size() { + return Vectorized::size() * N; + } + + private: + std::array, N> values; + + public: + // methods not implemented yet: + // variadic constructor, operator T*, as_bytes, zero_mask + +#define VECTORIZEDN_DEFINE_UNARY_OP(op) \ + VectorizedN op() const { \ + return unary_op([](const Vectorized& a) { return a.op(); }); \ + } + +#define VECTORIZEDN_DEFINE_BINARY_OP(op) \ + VectorizedN op(const VectorizedN& other) const { \ + return binary_op( \ + other, [](const Vectorized& a, const Vectorized& b) { \ + return a.op(b); \ + }); \ + } + + template + inline VectorizedN unary_op(Op op) const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i]); + } + return result; + } + + template + inline VectorizedN binary_op(const VectorizedN& other, Op op) + const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i], other.values[i]); + } + return result; + } + + template + inline VectorizedN ternary_op( + const VectorizedN& other, + const VectorizedN& other2, + Op op) const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i], other.values[i], other2.values[i]); + } + return result; + } + + VectorizedN() = default; + + explicit VectorizedN(T val) { + for (int i = 0; i < N; ++i) { + values[i] = Vectorized(val); + } + } + + template = 0> + VectorizedN(const Vectorized& val) : values({val}) {} + + template = 0> + VectorizedN(const Vectorized& val_0, const Vectorized& val_1) + : values({val_0, val_1}) {} + + template = 0> + inline operator Vectorized() const { + return values[0]; + } + + inline const Vectorized& operator[](int i) const { + return values[i]; + } + + inline 
Vectorized& operator[](int i) { + return values[i]; + } + + template + static VectorizedN blend( + const VectorizedN& a, + const VectorizedN& b) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = + Vectorized::template blend(a.values[i], b.values[i]); + } + return result; + } + + static VectorizedN blendv( + const VectorizedN& a, + const VectorizedN& b, + const VectorizedN& mask) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = + Vectorized::blendv(a.values[i], b.values[i], mask.values[i]); + } + return result; + } + + template + static VectorizedN arange( + T base = static_cast(0), + step_t step = static_cast(1)) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = Vectorized::arange(base, step); + base += step * Vectorized::size(); + } + return result; + } + + static VectorizedN set( + const VectorizedN& a, + const VectorizedN& b, + int64_t count = size()) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + if (count > 0) { + result.values[i] = Vectorized::set( + a.values[i], + b.values[i], + std::min(count, (int64_t)Vectorized::size())); + count -= Vectorized::size(); + } else { + result.values[i] = a.values[i]; + } + } + return result; + } + + static VectorizedN loadu(const void* ptr) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = Vectorized::loadu(ptr); + ptr = static_cast(ptr) + Vectorized::size(); + } + return result; + } + + static VectorizedN loadu(const void* ptr, int64_t count) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = Vectorized::loadu( + ptr, std::min(count, (int64_t)Vectorized::size())); + ptr = static_cast(ptr) + Vectorized::size(); + count -= Vectorized::size(); + if (count <= 0) { + break; + } + } + return result; + } + + void store(void* ptr) const { + for (int i = 0; i < N; ++i) { + values[i].store(ptr); + ptr = static_cast(ptr) + Vectorized::size(); + } + } + + void store(void* ptr, int count) const 
{ + for (int i = 0; i < N; ++i) { + values[i].store(ptr, std::min(count, (int)Vectorized::size())); + ptr = static_cast(ptr) + Vectorized::size(); + count -= Vectorized::size(); + if (count <= 0) { + break; + } + } + } + + bool has_inf_nan() const { + for (int i = 0; i < N; ++i) { + if (values[i].has_inf_nan()) { + return true; + } + } + return false; + } + + VectorizedN map(T (*const f)(T)) const { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = values[i].map(f); + } + return result; + } + + VectorizedN map(T (*const f)(const T&)) const { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = values[i].map(f); + } + return result; + } + + VECTORIZEDN_DEFINE_UNARY_OP(isnan) + VECTORIZEDN_DEFINE_UNARY_OP(abs) + VECTORIZEDN_DEFINE_UNARY_OP(sgn) + VECTORIZEDN_DEFINE_UNARY_OP(angle) + VECTORIZEDN_DEFINE_UNARY_OP(real) + VECTORIZEDN_DEFINE_UNARY_OP(imag) + VECTORIZEDN_DEFINE_UNARY_OP(conj) + VECTORIZEDN_DEFINE_UNARY_OP(acos) + VECTORIZEDN_DEFINE_UNARY_OP(acosh) + VECTORIZEDN_DEFINE_UNARY_OP(asin) + VECTORIZEDN_DEFINE_UNARY_OP(asinh) + VECTORIZEDN_DEFINE_UNARY_OP(atan) + VECTORIZEDN_DEFINE_UNARY_OP(atanh) + VECTORIZEDN_DEFINE_BINARY_OP(atan2) + VECTORIZEDN_DEFINE_BINARY_OP(copysign) + VECTORIZEDN_DEFINE_UNARY_OP(erf) + VECTORIZEDN_DEFINE_UNARY_OP(erfc) + VECTORIZEDN_DEFINE_UNARY_OP(erfinv) + VECTORIZEDN_DEFINE_UNARY_OP(exp) + VECTORIZEDN_DEFINE_UNARY_OP(exp2) + VECTORIZEDN_DEFINE_UNARY_OP(expm1) + VECTORIZEDN_DEFINE_UNARY_OP(exp_u20) + VECTORIZEDN_DEFINE_UNARY_OP(frac) + VECTORIZEDN_DEFINE_BINARY_OP(fmod) + VECTORIZEDN_DEFINE_UNARY_OP(log) + VECTORIZEDN_DEFINE_UNARY_OP(log10) + VECTORIZEDN_DEFINE_UNARY_OP(log1p) + VECTORIZEDN_DEFINE_UNARY_OP(log2) + VECTORIZEDN_DEFINE_UNARY_OP(ceil) + VECTORIZEDN_DEFINE_UNARY_OP(cos) + VECTORIZEDN_DEFINE_UNARY_OP(cosh) + VECTORIZEDN_DEFINE_UNARY_OP(floor) + VECTORIZEDN_DEFINE_BINARY_OP(hypot) + VECTORIZEDN_DEFINE_UNARY_OP(i0) + VECTORIZEDN_DEFINE_UNARY_OP(i0e) + 
VECTORIZEDN_DEFINE_UNARY_OP(digamma) + VECTORIZEDN_DEFINE_BINARY_OP(igamma) + VECTORIZEDN_DEFINE_BINARY_OP(igammac) + VECTORIZEDN_DEFINE_UNARY_OP(neg) + VECTORIZEDN_DEFINE_BINARY_OP(nextafter) + VECTORIZEDN_DEFINE_UNARY_OP(round) + VECTORIZEDN_DEFINE_UNARY_OP(sin) + VECTORIZEDN_DEFINE_UNARY_OP(sinh) + VECTORIZEDN_DEFINE_UNARY_OP(tan) + VECTORIZEDN_DEFINE_UNARY_OP(tanh) + VECTORIZEDN_DEFINE_UNARY_OP(trunc) + VECTORIZEDN_DEFINE_UNARY_OP(lgamma) + VECTORIZEDN_DEFINE_UNARY_OP(sqrt) + VECTORIZEDN_DEFINE_UNARY_OP(reciprocal) + VECTORIZEDN_DEFINE_UNARY_OP(rsqrt) + VECTORIZEDN_DEFINE_BINARY_OP(pow) + VECTORIZEDN_DEFINE_BINARY_OP(operator==) + VECTORIZEDN_DEFINE_BINARY_OP(operator!=) + VECTORIZEDN_DEFINE_BINARY_OP(operator>=) + VECTORIZEDN_DEFINE_BINARY_OP(operator<=) + VECTORIZEDN_DEFINE_BINARY_OP(operator>) + VECTORIZEDN_DEFINE_BINARY_OP(operator<) + VECTORIZEDN_DEFINE_BINARY_OP(eq) + VECTORIZEDN_DEFINE_BINARY_OP(ne) + VECTORIZEDN_DEFINE_BINARY_OP(gt) + VECTORIZEDN_DEFINE_BINARY_OP(ge) + VECTORIZEDN_DEFINE_BINARY_OP(lt) + VECTORIZEDN_DEFINE_BINARY_OP(le) + +#undef VECTORIZEDN_DEFINE_UNARY_OP +#undef VECTORIZEDN_DEFINE_BINARY_OP +}; + +#define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op(const VectorizedN& a) { \ + return a.unary_op([](const Vectorized& a) { return op(a); }); \ + } + +#define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op( \ + const VectorizedN& a, const VectorizedN& b) { \ + return a.binary_op(b, [](const Vectorized& a, const Vectorized& b) { \ + return op(a, b); \ + }); \ + } + +#define VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op( \ + const VectorizedN& a, \ + const VectorizedN& b, \ + const VectorizedN& c) { \ + return a.ternary_op( \ + b, \ + c, \ + [](const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& c) { return op(a, b, c); }); \ + } + +#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \ + template \ + inline VectorizedN& op( \ 
+ VectorizedN& a, const VectorizedN& b) { \ + a = a.binary_op(b, [](const Vectorized& a, const Vectorized& b) { \ + return op(a, b); \ + }); \ + return a; \ + } + +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmadd) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmsub) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(clamp) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^) +VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~) + +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=) + +#undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL +#undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL +#undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL + +template +inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN acc_vec) { + Vectorized vec_result = acc_vec[0]; + for (int i = 1; i < N; i++) { + vec_result = vec_fun(vec_result, acc_vec[i]); + } + return vec_reduce_all(vec_fun, vec_result); +} + +template +std::ostream& operator<<(std::ostream& stream, const VectorizedN& vec_n) { + stream << "vec_n["; + for (int i = 0; i < N; 
++i) { + if (i != 0) { + stream << ", "; + } + stream << vec_n[i]; + } + stream << ']'; + return stream; +} +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h new file mode 100644 index 0000000000000000000000000000000000000000..c64643546a2c1097a7a323dafc6cf5079d1b2fd9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include +#include + +#include + +// Use TORCH_CUDA_CPP_API or TORCH_CUDA_CU_API for exports from this folder diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b0b1412298d7b65bd0354d234bbddd09f19031a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh @@ -0,0 +1,47 @@ +#include + +#include + +namespace at::cuda { + +/** + Computes ceil(a / b) +*/ +template +__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) { + return (a + b - 1) / b; +} + +namespace { + +// Threads per block for our apply kernel +// FIXME: use occupancy calculator instead +constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512; +constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4; + +template +inline bool getApplyGrid(uint64_t totalElements, dim3& grid, c10::DeviceIndex curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) { + if (curDevice == -1) return false; + uint64_t numel_per_thread = static_cast(max_threads_per_block) * static_cast(step); + uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread); + uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0]; + if (numBlocks > maxGridX) + numBlocks = maxGridX; + grid = dim3(numBlocks); + return true; 
+} + +constexpr int getApplyBlocksPerSM() { + return AT_APPLY_BLOCKS_PER_SM; +} + +constexpr int getApplyBlockSize() { + return AT_APPLY_THREADS_PER_BLOCK; +} + +inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) { + return dim3(max_threads_per_block); +} + +} // anonymous namespace +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/AsmUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/AsmUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8bd897e64c4fdcdf6bd32c0b176ce2414ebe9438 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/AsmUtils.cuh @@ -0,0 +1,149 @@ +#pragma once +#include + +// Collection of direct PTX functions + +namespace at::cuda { + +template +struct Bitfield {}; + +template <> +struct Bitfield { + static __device__ __host__ __forceinline__ + unsigned int getBitfield(unsigned int val, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + unsigned int m = (1u << len) - 1u; + return (val >> pos) & m; +#else + unsigned int ret; + asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); + return ret; +#endif + } + + static __device__ __host__ __forceinline__ + unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + unsigned int m = (1u << len) - 1u; + toInsert &= m; + toInsert <<= pos; + m <<= pos; + + return (val & ~m) | toInsert; +#else + unsigned int ret; + asm("bfi.b32 %0, %1, %2, %3, %4;" : + "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len)); + return ret; +#endif + } +}; + +template <> +struct Bitfield { + static __device__ __host__ __forceinline__ + uint64_t getBitfield(uint64_t val, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + uint64_t m = (1u << len) - 1u; + return (val >> pos) & m; +#else + uint64_t ret; + 
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); + return ret; +#endif + } + + static __device__ __host__ __forceinline__ + uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + uint64_t m = (1u << len) - 1u; + toInsert &= m; + toInsert <<= pos; + m <<= pos; + + return (val & ~m) | toInsert; +#else + uint64_t ret; + asm("bfi.b64 %0, %1, %2, %3, %4;" : + "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len)); + return ret; +#endif + } +}; + +__device__ __forceinline__ int getLaneId() { +#if defined(USE_ROCM) + return __lane_id(); +#else + int laneId; + asm("mov.s32 %0, %%laneid;" : "=r"(laneId) ); + return laneId; +#endif +} + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskLt() { + const std::uint64_t m = (1ull << getLaneId()) - 1ull; + return m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskLt() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); + return mask; +} +#endif + +#if defined (USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskLe() { + std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1)); + return m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskLe() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask)); + return mask; +} +#endif + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskGt() { + const std::uint64_t m = getLaneMaskLe(); + return m ? 
~m : m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskGt() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask)); + return mask; +} +#endif + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskGe() { + const std::uint64_t m = getLaneMaskLt(); + return ~m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskGe() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask)); + return mask; +} +#endif + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/Atomic.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/Atomic.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f16be30f8b71b7c0eb0bdb962fabbc865fc72f84 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/Atomic.cuh @@ -0,0 +1,525 @@ +#pragma once + +#include +#include +#include + +#include + +#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + +template +struct AtomicFPOp; + +template <> +struct AtomicFPOp { + template + inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) { + unsigned int * address_as_ui = + (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + at::Half hsum; + do { + assumed = old; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + hsum = func(hsum, val); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); + hsum.x = (size_t)address & 2 ? 
(old >> 16) : (old & 0xffff); + return hsum; + } +}; + +template <> +struct AtomicFPOp { + template + inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) { + unsigned int * address_as_ui = + (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + at::BFloat16 bsum; + do { + assumed = old; + bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + bsum = func(bsum, val); + old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); + bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + return bsum.x; + } +}; + +template <> +struct AtomicFPOp { + template + inline __device__ double operator() (double * address, double val, const func_t& func) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull; + unsigned long long int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, func(val, assumed)); + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __longlong_as_double(old); + } +}; + +#define ATOMIC_INTEGER_IMPL(NAME) \ +template \ +struct Atomic##NAME##IntegerImpl; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + size_t offset = (size_t)address & 3; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + uint32_t old = *address_as_ui; \ + uint32_t shift = offset * 8; \ + uint32_t old_byte; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_byte = (old >> shift) & 0xff; \ + newval = static_cast(func(val, static_cast(old_byte))); \ + newval = (old & ~(0x000000ff << shift)) | (newval << shift); \ + old = 
atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + size_t offset = (size_t)address & 2; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + bool is_32_align = offset; \ + uint32_t old = *address_as_ui; \ + uint32_t old_bytes; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_bytes = is_32_align ? old >> 16 : old & 0xffff; \ + newval = static_cast(func(val, static_cast(old_bytes))); \ + newval = is_32_align ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + uint32_t * address_as_ui = (uint32_t *) (address); \ + uint32_t old = *address_as_ui; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + newval = static_cast(func(val, static_cast(old))); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; \ + \ +template \ +struct Atomic##NAME##IntegerImpl { \ + template \ + inline __device__ void operator()(T *address, T val, const func_t& func) { \ + unsigned long long * address_as_ui = (unsigned long long *) (address); \ + unsigned long long old = *address_as_ui; \ + unsigned long long newval; \ + unsigned long long assumed; \ + \ + do { \ + assumed = old; \ + newval = static_cast(func(val, static_cast(old))); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ + } \ +}; + + +# define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \ +inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \ +Atomic##NAME##IntegerImpl()(address, \ + val, \ + [](DTYPE a, DTYPE b) { \ + return OP; \ + 
}); \ +} \ + +ATOMIC_INTEGER_IMPL(Add) +GPU_ATOMIC_INTEGER(Add, a || b, bool) + +// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64) +inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) { + AtomicAddIntegerImpl()(address, + val, + [](uint8_t a, uint8_t b) { + return a + b; + }); +} + +inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) { + AtomicAddIntegerImpl()(address, + val, + [](int8_t a, int8_t b) { + return a + b; + }); +} + +inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) { + AtomicAddIntegerImpl()(address, + val, + [](int16_t a, int16_t b) { + return a + b; + }); +} + +inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) { + return atomicAdd(address, val); +} + +inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) { +#if defined(USE_ROCM) + __atomic_fetch_add(address, val, __ATOMIC_RELAXED); +#else + static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed"); + atomicAdd(reinterpret_cast(address), static_cast(val)); +#endif +} + +inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) { +#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) + return AtomicFPOp()(address, val, + [](at::Half hsum, at::Half val) { + return hsum + val; + }); +#else + return atomicAdd(reinterpret_cast<__half*>(address), val); +#endif +} + +inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) { +#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) +return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return bsum + val; + }); +#else + __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val)); + return *reinterpret_cast(&r); +#endif +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600) +// from CUDA C Programmic Guide +inline __device__ double 
atomicAdd(double* address, double val) +#if defined(__clang__) && defined(__CUDA__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wgcc-compat" + __attribute__((enable_if(true, ""))) +#pragma GCC diagnostic pop +#endif +{ + + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(val + __longlong_as_double(assumed)); + }); +} +#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__)) + +/* Note [hip-clang differences to hcc] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * The upcoming hip-clang compiler for ROCm differs from hcc in a few details. + * It exports the __HIP__ macro, we can hence differentiate between hcc and + * hip-clang. In the below, hcc only received support for atomicAdd with double + * typing after work week 18312. hip-clang had support from the first version. + * In general, the code-visible differences between hip-clang and hcc will be + * minimal. + */ + +#if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__ + // This needs to be defined for the host side pass + inline __device__ double atomicAdd(double *address, double val) { } +#endif +#endif + +inline __device__ double gpuAtomicAdd(double *address, double val) { + return atomicAdd(address, val); +} + +inline __device__ float gpuAtomicAdd(float *address, float val) { + return atomicAdd(address, val); +} + +template +inline __device__ void gpuAtomicAdd(c10::complex *address, c10::complex val) { + gpuAtomicAdd(&address->real_, val.real_); + gpuAtomicAdd(&address->imag_, val.imag_); +} + +/* Note [gpuAtomicAdd vs atomicAdd] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Some extensions such as torchvision call atomicAdd() + * directly and require non-library provided data type support. Only for these, we + * continue to provide atomicAdd overloads. 
+ */ +inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) { + return gpuAtomicAdd(address, val); +} + +inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) { + return gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(uint8_t *address, uint8_t val) { + gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(int8_t *address, int8_t val) { + gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(int16_t *address, int16_t val) { + gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(int64_t *address, int64_t val) { + gpuAtomicAdd(address, val); +} + +inline __device__ void atomicAdd(bool *address, bool val) { + gpuAtomicAdd(address, val); +} + +/* Note [explicitly non-returning atomics] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet(). + * Due to compiler limitations, callers must opt-in to guarantee the optimized instruction. + * This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd, + * therefore we need a new API 'gpuAtomicAddNoReturn'. 
+ */ +template +inline __device__ void gpuAtomicAddNoReturn(c10::complex *address, c10::complex val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); } + +/* Note [HIP unsafeAtomicAdd] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Use unsafeAtomicAdd instead of atomicAdd for fp32 and fp64. + * On HIP, atomicAdd is always correct but is a slow CAS loop. + * unsafeAtomicAdd will use HW instructions and is much faster, + * but the caller must guarantee the pointer is GPU memory. + * If the pointer is system memory, the result is a silent no-op. + * This guarantee is upheld by all PyTorch uses of unsafeAtomicAdd. + * AMD HIP atomic header file is named amd_hip_atomic.h and is + * under the LLVM compiler directory. 
+ */ +#if defined(USE_ROCM) +inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { +#if defined(__gfx908__) + atomicAddNoRet(address, val); +#else + (void)unsafeAtomicAdd(address, val); +#endif +} +inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { (void)unsafeAtomicAdd(address, val); } +#else +inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); } +inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); } +#endif + +// Atomic multiplication implementation. + +ATOMIC_INTEGER_IMPL(Mul) +GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int8_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int16_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int32_t) +GPU_ATOMIC_INTEGER(Mul, a * b, int64_t) + +inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half bsum, at::Half val) { + return bsum * val; + }); +} + +inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) { + return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return bsum * val; + }); +} + +inline __device__ double gpuAtomicMul(double * address, double val) { + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(val * __longlong_as_double(assumed)); + }); +} + +// Dont use a templated function for this since the addition function defaults to the CUDA built-in. 
+inline __device__ float gpuAtomicMul (float * address, float val) { + unsigned int* address_as_ull = (unsigned int*)address; + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(val * + __int_as_float(assumed))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __int_as_float(old); +} + +// Atomic maximum implementation. + +template +__host__ __device__ T safe_max(T a, T b) { + #if defined(__HIPCC__) + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm/hip/issues/2209 + T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b)); + #else + T max = at::_isnan(b) ? b : std::max(a, b); + #endif + + return max; +} + +ATOMIC_INTEGER_IMPL(Max) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t) +GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t) + +inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half bsum, at::Half val) { + return safe_max(bsum, val); + }); +} + +inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) { + return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return safe_max(bsum, val); + }); +} + +inline __device__ double gpuAtomicMax(double * address, double val) { + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(safe_max(val, __longlong_as_double(assumed))); + }); +} + +// Dont use a templated function for this since the addition function defaults to the CUDA built-in. 
+inline __device__ float gpuAtomicMax(float * address, float val) { + unsigned int* address_as_ull = (unsigned int*)address; + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(safe_max(val, __int_as_float(assumed)))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __int_as_float(old); +} + +// Atomic minimum implementation. + +template +__host__ __device__ T safe_min(T a, T b) { + #if defined(__HIPCC__) + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm/hip/issues/2209 + T min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b)); + #else + T min = at::_isnan(b) ? b : std::min(a, b); + #endif + + return min; +} + +ATOMIC_INTEGER_IMPL(Min) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t) +GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t) + +inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) { + return AtomicFPOp()(address, val, + [](at::Half bsum, at::Half val) { + return safe_min(bsum, val); + }); +} + +inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) { + return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return safe_min(bsum, val); + }); +} + +inline __device__ double gpuAtomicMin(double * address, double val) { + return AtomicFPOp()(address, val, + [](double val, unsigned long long int assumed) { + return __double_as_longlong(safe_min(val, __longlong_as_double(assumed))); + }); +} + +// Dont use a templated function for this since the addition function defaults to the CUDA built-in. 
+inline __device__ float gpuAtomicMin(float * address, float val) { + unsigned int* address_as_ull = (unsigned int*)address; + unsigned int old = *address_as_ull; + unsigned int assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(safe_min(val, __int_as_float(assumed)))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __int_as_float(old); +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..4dcdabf17b3b90ad598db48f7210e3932142aaad --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh @@ -0,0 +1,537 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// +// This file contains pointwise operation functions and kernels that +// work on both contiguous and non-contiguous tensor arguments of +// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without +// copying or temporary storage. +// + +/* + NOTE [ CUDA_tensor_applyN helpers ] + + The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4) + functions apply a pointwise operator to N tensor(s). + + The calling convention is + + 1. The template arguments should be, sequentially, + - First N typename args specify the scalar types of each of the N tensors. + - (Optional) `int step` arg specifies the number of elements processed + together at the same time. + Default is 1. + - A usually omitted (i.e., inferred) typename arg specifies the type of the + function/functor applied on `N * step` values in each iteration of each + CUDA thread. + 2. The arguments should be, sequentially, + - N tensors + - op: a function/functor that processes `N * step` values at the same time. 
+ - If `step == 1`, it must have signature + `void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where + `scalar*_t`s are the first N typename template args, and the inputs + are the `N` values from the `N` tensors retrieved at a common index. + - Otherwise, it must must have signature + void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&, // repeat `step` times + scalar2_t&, scalar2_t&, ..., scalar2_t&, // repeat `step` times + ..., + scalarN_t&, scalarN_t&, ..., scalarN_t&) // repeat `step` times + Different from `step == 1` case, it processes `N * step` values taken + from `step` common indices. Moreover, the first input `n` represents the + number of valid indices (it will always have `0 < n <= step`). It will + almost always be `step`, but at the boundary we may not have full `step` + elements and `n` can be a lesser value. + + E.g., if `step == 4` and `N == 2`, `op` could be + + [](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4, + scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) { + // Only process u1, ..., un and v1, ..., vn. + // So if `n == 3`, `u4` and `v4` need not to be considered. + } + + In both cases, the references can actually be const, but at least one of + them should be non-const in order to write the output. + - (Optional, but recommended) N TensorArgType args that specify for each + tensor whether `op` reads AND writes ] (i.e., TensorArgType::ReadWrite), + or only reads (i.e., TensorArgType::ReadOnly). + Default is TensorArgType::ReadWrite for first Tensor, and + TensorArgType::ReadOnly for the rest. 
+ + E.g., + + to compute a = b^2 for a and b of same dtype, we can call + + CUDA_tensor_apply2( + a, b, + [] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; } + ); + + to work on 2 values at the same time, we can call + + CUDA_tensor_apply2( + a, b, + [] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2, + const scalar2 &b_val1, const scalar2 &b_val2) { + // call special vectorized op here, or just do elementwise and enjoy unrolling... + // if n == 1, only process a_val1 and b_val1 + } + ); +*/ + +namespace at::cuda { + +// TODO: combine with TensorArg? So far that's been for debugging, and this is functional... +enum class TensorArgType { ReadWrite, ReadOnly }; + +namespace { + +// Rearrange dimensions for pointwise operations so that strides are in +// decreasing order as much as possible, so that kernels have better memory +// access patterns. +// +// For example, consider a binary operation on two "transposed" 2-dim tensors: +// sizes: 256 512 +// aInfo->strides: 1 256 +// bInfo->strides: 1 256 +// +// Given this, each concurrent memory access inside kernelPointwiseApply2() is +// exactly 256 elements apart, resulting in poor performance. +// +// This function exchanges dimensions so that memory access is contiguous: +// sizes: 512 256 +// aInfo->strides: 256 1 +// bInfo->strides: 256 1 +// +// (Actually, it becomes even better because now collapseDims() can turn each +// input into one contiguous array.) +// +// In general, given M (<=4) TensorInfo's with N dimensions, we can view each +// strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange +// strides[i] and [j] if +// (1) strides[i][k] < strides[j][k] for some k (0 <= k < M) +// (exchanging them will benefit input #k), and +// (2) strides[i][k] <= strieds[j][k] for all k +// (exchanging them will not make any input worse). 
+template +inline void rearrangeDims(detail::TensorInfo* aInfo, + detail::TensorInfo* bInfo = nullptr, + detail::TensorInfo* cInfo = nullptr, + detail::TensorInfo* dInfo = nullptr) { + int numInfos = 1; + int dims = aInfo->dims; + IndexType *sizes[4] = { aInfo->sizes, }; + IndexType *strides[4] = { aInfo->strides, }; + + if (bInfo != nullptr) { + ++numInfos; + if (bInfo->dims != dims) return; + sizes[1] = bInfo->sizes; + strides[1] = bInfo->strides; + } + + if (cInfo != nullptr) { + ++numInfos; + if (cInfo->dims != dims) return; + sizes[2] = cInfo->sizes; + strides[2] = cInfo->strides; + } + + if (dInfo != nullptr) { + ++numInfos; + if (dInfo->dims != dims) return; + sizes[3] = dInfo->sizes; + strides[3] = dInfo->strides; + } + + // Bail out if sizes do not match: we are using "deprecated pointwise + // behavior" among tensors of different shapes but same number of elements. + for (int i = 1; i < numInfos; ++i) { + for (int j = 0; j < dims; ++j) { + if (sizes[i][j] != sizes[0][j]) return; + } + } + + for (int i = 0; i < dims - 1; ++i) { + // No need to consider dimensions of size 1. + if (sizes[0][i] == 1) continue; + + for (int j = i + 1; j < dims; ++j) { + if (sizes[0][j] == 1) continue; + + // Compare the relative sizes of strides between dim #i and dim #j. 
+ bool hasIncreasingStrides = false; + bool hasDecreasingStrides = false; + + for (int k = 0; k < numInfos; k++) { + IndexType stride_i = strides[k][i]; + IndexType stride_j = strides[k][j]; + if (stride_i < stride_j) { + hasIncreasingStrides = true; + } else if (stride_i > stride_j) { + hasDecreasingStrides = true; + } + } + + if (hasIncreasingStrides && !hasDecreasingStrides) { + for (int k = 0; k < numInfos; k++) { + IndexType size = sizes[k][i]; + sizes[k][i] = sizes[k][j]; + sizes[k][j] = size; + + IndexType stride = strides[k][i]; + strides[k][i] = strides[k][j]; + strides[k][j] = stride; + } + } + } + } +} + +// The `remaining_steps` argument is used to support Op that operates on +// multiple elements at the same time. Generally, the strategy of ApplyOpN is to +// 1. Initialize `remaining_steps = step`, where `step` is the template arg of +// CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the +// number of elements in bound for this call. It will almost always equal to +// `step` except at boundaries. +// 2. If `remaining_steps > 0` convert the current linearIndex to offset (if in +// bound), and recursively call `ApplyOpN` with `remaining_steps - 1`. +// 3. At `remaining_steps = 0`, +// if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`; +// if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tesor1_valstep, +// tensor2_val1, tensor2_val2, ..., tesor2_valstep, +// ... +// tensorN_val1, tensorN_val2, ..., tesorN_valstep);` +// +// See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like. + +template +struct ApplyOp1 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, const Op &op, int n, + IndexType linearIndex, Offsets... aOffsets) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = sizeof...(Offsets) < n ? 
+ detail::IndexToOffset::get(linearIndex, a) : 0; + + ApplyOp1::apply( + a, op, n, linearIndex + 1, aOffsets..., aOffset + ); +} +}; + +// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). +// We don't need to pass in how many elements need to processed in this case. +template +struct ApplyOp1 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, const Op &op, + int n, IndexType linearIndex, Offset offset) { + op(a.data[offset]); +} +}; + +template +struct ApplyOp1 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, const Op &op, int n, + IndexType linearIndex, Offsets... offsets) { + op(n, a.data[offsets]...); +} +}; + +template +#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) +C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) +#endif +__global__ void kernelPointwiseApply1(detail::TensorInfo a, + IndexType totalElements, const Op op) { + for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x * step) { + ApplyOp1::apply( + a, op, ::min(step, static_cast(totalElements - linearIndex)), linearIndex); + } +} + + +template +struct ApplyOp2 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, + detail::TensorInfo &b, + const Op &op, int64_t n, IndexType linearIndex, + Offsets... aOffsets, Offsets... bOffsets) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = static_cast(sizeof...(Offsets)) < n ? + detail::IndexToOffset::get(linearIndex, a) : 0; + + // Convert `linearIndex` into an offset of `b` + const IndexType bOffset = static_cast(sizeof...(Offsets)) < n ? + detail::IndexToOffset::get(linearIndex, b) : 0; + + ApplyOp2::apply( + a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset + ); +} +}; + +// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). 
+// We don't need to pass in how many elements need to processed in this case. +template +struct ApplyOp2 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, + detail::TensorInfo &b, + const Op &op, int /*n*/, IndexType /*linearIndex*/, + Offset aOffset, Offset bOffset) { + op(a.data[aOffset], b.data[bOffset]); +} +}; + +template +struct ApplyOp2 { +__device__ __forceinline__ +static void apply(detail::TensorInfo &a, + detail::TensorInfo &b, + const Op &op, int n, IndexType linearIndex, + Offsets... aOffsets, Offsets... bOffsets) { + op(n, a.data[aOffsets]..., b.data[bOffsets]...); +} +}; + +template +#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) +C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) +#endif +__global__ void +kernelPointwiseApply2(detail::TensorInfo a, + detail::TensorInfo b, + IndexType totalElements, + const Op op) { + for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x * step) { + ApplyOp2::apply( + a, b, op, ::min(step, static_cast(totalElements - linearIndex)), + linearIndex); + } +} + +} // anonymous namespace + +template +inline bool CUDA_tensor_apply2(at::TensorBase a, + at::TensorBase b, + const Op op, + TensorArgType aType = TensorArgType::ReadWrite, + TensorArgType bType = TensorArgType::ReadOnly) { + TORCH_CHECK(a.device().is_cuda() && b.device().is_cuda(), + "CUDA_tensor_apply2: Expected tensors to have CUDA DeviceType, but got " + "tensors with type ", a.device().type(), " and ", b.device().type()); + int64_t totalElements = a.numel(); + + if (totalElements != b.numel()) { + return false; + } + + if (a.dim() > MAX_TENSORINFO_DIMS || + b.dim() > MAX_TENSORINFO_DIMS) { + return false; + } + + if (a.numel() == 0) { + // Empty tensor; do nothing + return true; + } + const dim3 block = getApplyBlock(max_threads_per_block); + + dim3 grid; + auto curDevice = current_device(); + if (curDevice == -1) return false; + if 
(!getApplyGrid(totalElements, grid, curDevice, max_threads_per_block)) { + return false; + } + + /* + Expands readable/writable tensors whose indices may be "overlapped." + This ensures that each element of the tensor is operated on once and only + once. + */ + TensorBase oldA; + TensorBase oldB; + + if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) { + // Must perform in contiguous space + oldA = std::exchange(a, a.contiguous()); + } + if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) { + // Must perform in contiguous space + oldB = std::exchange(b, b.contiguous()); + } + + // It is possible that the tensor dimensions are able to be collapsed, + // and thus we can reduce the actual code complexity of the copy by + // exploiting this knowledge statically, since the div/mod is the + // most expensive part of the operation, more so than memory accesses. + // For instance, when copying a non-contiguous to a contiguous tensor + // (or vice versa), the contiguous tensor can be collapsed to one + // dimension, and the loop to translate the linear index to the array + // index can be similarly collapsed. That is what this unrolling is for. 
+ +#define HANDLE_CASE(TYPE, A, B) \ + kernelPointwiseApply2 \ + <<>>( \ + aInfo, bInfo, static_cast(totalElements), op); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#define HANDLE_B_CASE(TYPE, A, B) { \ + switch (B) { \ + case 1: \ + HANDLE_CASE(TYPE, A, 1); \ + break; \ + case 2: \ + HANDLE_CASE(TYPE, A, 2); \ + break; \ + default: \ + HANDLE_CASE(TYPE, A, -1); \ + break; \ + } \ +} + +#define HANDLE_A_CASE(TYPE, A, B) { \ + switch (A) { \ + case 1: \ + HANDLE_B_CASE(TYPE, 1, B); \ + break; \ + case 2: \ + HANDLE_B_CASE(TYPE, 2, B); \ + break; \ + default: \ + HANDLE_B_CASE(TYPE, -1, B); \ + break; \ + } \ +} + + if (detail::canUse32BitIndexMath(a) && + detail::canUse32BitIndexMath(b)) { + detail::TensorInfo aInfo = + detail::getTensorInfo(a); + + detail::TensorInfo bInfo = + detail::getTensorInfo(b); + rearrangeDims(&aInfo, &bInfo); + aInfo.collapseDims(); + bInfo.collapseDims(); + + HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims); + } else { + detail::TensorInfo aInfo = + detail::getTensorInfo(a); + + detail::TensorInfo bInfo = + detail::getTensorInfo(b); + rearrangeDims(&aInfo, &bInfo); + aInfo.collapseDims(); + bInfo.collapseDims(); + + /* + Only instantiates the all 1D special case and the fallback all nD case for + large (64-bit indexed) tensors to reduce compilation time. + */ + if (aInfo.dims == 1 && bInfo.dims == 1) { + HANDLE_CASE(uint64_t, 1, 1); + } else { + HANDLE_CASE(uint64_t, -1, -1); + } + } +#undef HANDLE_CASE +#undef HANDLE_B_CASE +#undef HANDLE_A_CASE + + if (oldA.defined()) { + at::native::copy_ignoring_overlaps(oldA, a); + } + + if (oldB.defined()) { + at::native::copy_ignoring_overlaps(oldB, b); + } + + return true; +} + +/* Provides default step = 1 to CUDA_tensor_apply2. 
*/ +template +inline bool CUDA_tensor_apply2(const at::TensorBase &a, + const at::TensorBase &b, + const Op op, + TensorArgType aType = TensorArgType::ReadWrite, + TensorArgType bType = TensorArgType::ReadOnly) { + return CUDA_tensor_apply2(a, b, op, aType, bType); +} + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDABlas.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDABlas.h new file mode 100644 index 0000000000000000000000000000000000000000..b1dac2162dc421c89d555af6a8b475adbcc1bdcb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDABlas.h @@ -0,0 +1,417 @@ +#pragma once +/* + Provides a subset of CUDA BLAS functions as templates: + + gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc) + + gemv(transa, m, n, alpha, a, lda, x, incx, beta, y, incy) + + dot(n, x, incx, y, incy, result) + + where Dtype is double, float, at::Half or at::BFloat16 (ROCm, NOT for dot). + The functions are available in at::cuda::blas namespace. 
+ */ + +#include +#include + +namespace at::cuda::blas { + +// RAII guard that sets the CuBLAS pointer mode and restores it to +// its previous value when the guard is destroyed +class PointerModeGuard { +public: + PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) : + handle(handle) { + TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode)); + TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode)); + } + + ~PointerModeGuard() { + cublasSetPointerMode(handle, previous_mode); + } + +private: + cublasHandle_t handle; + cublasPointerMode_t previous_mode{}; +}; + +/* LEVEL 3 BLAS FUNCTIONS */ + +#define CUDABLAS_GEMM_ARGTYPES(Dtype) CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) + +#define CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype) \ + char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type alpha, \ + const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type beta,\ + C_Dtype *c, int64_t ldc + +#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc + +#define CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT \ + ((std::is_same::value || std::is_same::value) && std::is_same::value) + +template ::type* = nullptr> +inline void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm: not implemented"); +} + +template ::type* = nullptr> +void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)); + +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(double)); +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(float)); +template <> +void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)); +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); +template<> +void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)); +template<> +void 
gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)); + +template +inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm_internal: not implemented"); +} + +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)); +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)); +template <> +void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)); +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); +template<> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)); +template<> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)); + +enum GEMMAndBiasActivationEpilogue { + None, + RELU, + GELU, +}; + +// NOTE: GELU activation is not supported prior to CUDA 11.4 and will +// do nothing if passed in that case. 
+template +bool gemm_and_bias( + bool transpose_mat1, + bool transpose_mat2, + int64_t m, + int64_t n, + int64_t k, + at::opmath_type alpha_val, + const Dtype* mat1_ptr, + int64_t mat1_ld, + const Dtype* mat2_ptr, + int64_t mat2_ld, + const Dtype* bias, + C_Dtype* result_ptr, + int64_t result_ld, + GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None); + +void int8_gemm( + bool transpose_mat1, + bool transpose_mat2, + int64_t m, + int64_t n, + int64_t k, + const int8_t* mat1_ptr, + int64_t mat1_ld, + const int8_t* mat2_ptr, + int64_t mat2_ld, + int32_t* result_ptr, + int64_t result_ld); + +void scaled_gemm( + char transa, + char transb, + int64_t m, + int64_t n, + int64_t k, + const void* mat1_ptr, + const void* mat1_scale_ptr, + int64_t mat1_ld, + ScalarType mat1_dtype, + ScalarType mat1_scale_dtype, + const void* mat2_ptr, + const void* mat2_scale_ptr, + int64_t mat2_ld, + ScalarType mat2_dtype, + ScalarType mat2_scale_dtype, + const void* bias_ptr, + ScalarType bias_dtype, + void* result_ptr, + const void* result_scale_ptr, + int64_t result_ld, + ScalarType result_dtype, + bool use_fast_accum, + bool use_rowwise); + +#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) + +#define CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype) \ + char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type alpha, \ + const Dtype *a, int64_t lda, int64_t stridea, \ + const Dtype *b, int64_t ldb, int64_t strideb, \ + at::opmath_type beta, C_Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches + +#define CUDABLAS_BGEMM_ARGS(Dtype) \ + transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches + +template ::type* = nullptr> +inline void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm: not implemented"); +} + +template ::type* = nullptr> +void 
bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)); + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(double)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(float)); +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)); +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +template<> +void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)); +template<> +void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)); + +template +inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm_internal: not implemented"); +} + +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(double)); +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(float)); +template <> +void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)); +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +template<> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)); +template<> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)); + +#define CUDABLAS_TRSM_ARGTYPES(Dtype) \ + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \ + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \ + const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb + +template +inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsm: not implemented"); +} + +template <> +TORCH_CUDA_CU_API void trsm(CUDABLAS_TRSM_ARGTYPES(float)); +template <> +TORCH_CUDA_CU_API void trsm(CUDABLAS_TRSM_ARGTYPES(double)); +template <> +TORCH_CUDA_CU_API void 
trsm>(CUDABLAS_TRSM_ARGTYPES(c10::complex)); +template <> +TORCH_CUDA_CU_API void trsm>(CUDABLAS_TRSM_ARGTYPES(c10::complex)); + +#define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype) \ + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \ + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \ + const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb, \ + int batchCount + +template +inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented"); +} + +template <> +TORCH_CUDA_CU_API void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(float)); +template <> +TORCH_CUDA_CU_API void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(double)); +template <> +TORCH_CUDA_CU_API void trsmBatched>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex)); +template <> +TORCH_CUDA_CU_API void trsmBatched>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex)); + +/* LEVEL 2 BLAS FUNCTIONS */ + +#define CUDABLAS_GEMV_ARGTYPES(Dtype) \ + char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \ + const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy + +template +inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::gemv: not implemented"); +} + +template <> +void gemv(CUDABLAS_GEMV_ARGTYPES(double)); +template <> +void gemv(CUDABLAS_GEMV_ARGTYPES(float)); +template <> +void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex)); +template <> +void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex)); +template <> +void gemv(CUDABLAS_GEMV_ARGTYPES(at::Half)); +template <> +void gemv(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)); + +/* LEVEL 1 BLAS FUNCTIONS */ + +#define CUDABLAS_DOT_ARGTYPES(Dtype) \ + cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \ + int incy, Dtype *result + +template +inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::dot: not implemented"); +} + 
+template <> +void dot(CUDABLAS_DOT_ARGTYPES(double)); +template <> +void dot(CUDABLAS_DOT_ARGTYPES(float)); +template <> +void dot(CUDABLAS_DOT_ARGTYPES(at::Half)); +template <> +void dot(CUDABLAS_DOT_ARGTYPES(at::BFloat16)); +template <> +void dot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); +template <> +void dot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); + +template +inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::vdot: not implemented"); +} + +template <> +void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); +template <> +void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); + +#define CUDABLAS_GETRS_ARGTYPES(Dtype) \ + cublasHandle_t handle, cublasOperation_t trans, \ + int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \ + Dtype** dB_array, int ldb, int* info_array, int batchsize + +#define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \ + cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \ + Dtype **tau_array, int *info, int batchsize + +#define CUDABLAS_GETRF_ARGTYPES(Dtype) \ + int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize + +#define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \ + cublasHandle_t handle, cublasOperation_t trans, \ + int m, int n, int nrhs, Dtype** dA_array, int ldda, \ + Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize + +// HIP on Windows does not support getrs, geqrf, getrf, gels +#if !(defined(USE_ROCM) && defined(_MSC_VER)) + +template +void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented"); +} +template<> +TORCH_CUDA_CU_API void getrsBatched(CUDABLAS_GETRS_ARGTYPES(float)); +template<> +TORCH_CUDA_CU_API void getrsBatched(CUDABLAS_GETRS_ARGTYPES(double)); +template<> +TORCH_CUDA_CU_API void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex)); +template<> +TORCH_CUDA_CU_API void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex)); + +template +void 
geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented"); +} +template <> +TORCH_CUDA_CU_API void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)); +template <> +TORCH_CUDA_CU_API void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double)); +template <> +TORCH_CUDA_CU_API void geqrfBatched>( + CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)); +template <> +TORCH_CUDA_CU_API void geqrfBatched>( + CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex)); + +template +void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::getrfBatched: not implemented"); +} +template<> +TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(float)); +template<> +TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(double)); +template<> +TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); +template<> +TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex)); + +template +void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype), "at::cuda::blas::gelsBatched: not implemented"); +} +template<> +TORCH_CUDA_CU_API void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(double)); +template<> +TORCH_CUDA_CU_API void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(float)); +template<> +TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex)); +template<> +TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex)); + +#else // !(defined(USE_ROCM) && defined(_MSC_VER)) + +template +void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::getrsBatched: not supported for HIP on Windows"); +} + +template +void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::geqrfBatched: not supported for HIP on Windows"); +} + +template +void 
getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not supported for HIP on Windows"); +} + +template +void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::gelsBatched: not supported for HIP on Windows"); +} + +#endif // !(defined(USE_ROCM) && defined(_MSC_VER)) + +} // namespace at::cuda::blas diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAConfig.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAConfig.h new file mode 100644 index 0000000000000000000000000000000000000000..24c1dbe367b7204d983af93606f0ff6c95fd69d1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAConfig.h @@ -0,0 +1,20 @@ +#pragma once + +// Test these using #if AT_CUDNN_ENABLED(), not #ifdef, so that it's +// obvious if you forgot to include Config.h +// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined +// +// NB: This header MUST NOT be included from other headers; it should +// only be included from C++ files. 
+#define AT_CUDNN_ENABLED() 1 +#define AT_CUSPARSELT_ENABLED() 1 +#define AT_HIPSPARSELT_ENABLED() 0 +#define AT_ROCM_ENABLED() 0 +#define AT_MAGMA_ENABLED() 1 + +// Needed for hipMAGMA to correctly identify implementation +#if (AT_ROCM_ENABLED() && AT_MAGMA_ENABLED()) +#define HAVE_HIP 1 +#endif + +#define NVCC_FLAGS_EXTRA "-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_120,code=sm_120" diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAContext.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAContext.h new file mode 100644 index 0000000000000000000000000000000000000000..0cb024dd701b284502965cba681f1f9beb214592 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAContext.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +// Preserved for BC, as many files depend on these includes +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAContextLight.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAContextLight.h new file mode 100644 index 0000000000000000000000000000000000000000..65019bb6097c9b210fdd660975153b0dbad4b901 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAContextLight.h @@ -0,0 +1,102 @@ +#pragma once +// Light-weight version of CUDAContext.h with fewer transitive includes + +#include +#include + +#include +#include +#include + +// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also +// added bf16 support +#include + +#ifdef CUDART_VERSION +#include +#endif + +#if defined(USE_CUDSS) +#include +#endif + +#if defined(USE_ROCM) +#include +#endif + +#include +#include + +namespace c10 { +struct Allocator; +} + +namespace at::cuda { + +/* +A common CUDA interface for 
ATen. + +This interface is distinct from CUDAHooks, which defines an interface that links +to both CPU-only and CUDA builds. That interface is intended for runtime +dispatch and should be used from files that are included in both CPU-only and +CUDA builds. + +CUDAContext, on the other hand, should be preferred by files only included in +CUDA builds. It is intended to expose CUDA functionality in a consistent +manner. + +This means there is some overlap between the CUDAContext and CUDAHooks, but +the choice of which to use is simple: use CUDAContext when in a CUDA-only file, +use CUDAHooks otherwise. + +Note that CUDAContext simply defines an interface with no associated class. +It is expected that the modules whose functions compose this interface will +manage their own state. There is only a single CUDA context/state. +*/ + +/** + * DEPRECATED: use device_count() instead + */ +inline int64_t getNumGPUs() { + return c10::cuda::device_count(); +} + +/** + * CUDA is available if we compiled with CUDA, and there are one or more + * devices. If we compiled with CUDA but there is a driver problem, etc., + * this function will report CUDA is not available (rather than raise an error.) 
+ */ +inline bool is_available() { + return c10::cuda::device_count() > 0; +} + +TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties(); + +TORCH_CUDA_CPP_API int warp_size(); + +TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device); + +TORCH_CUDA_CPP_API bool canDeviceAccessPeer( + c10::DeviceIndex device, + c10::DeviceIndex peer_device); + +TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator(); + +/* Handles */ +TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle(); +TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); +TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); + +TORCH_CUDA_CPP_API void clearCublasWorkspaces(); +TORCH_CUDA_CPP_API std::map, at::DataPtr>& cublas_handle_stream_to_workspace(); +TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize(); + +#if defined(CUDART_VERSION) || defined(USE_ROCM) +TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle(); +#endif + +#if defined(USE_CUDSS) +TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle(); +#endif + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDADataType.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDADataType.h new file mode 100644 index 0000000000000000000000000000000000000000..6ee6346732fa985974c9f01e47f7cb5b04b5ff17 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDADataType.h @@ -0,0 +1,102 @@ +#pragma once + +#include + +#include +#include + +namespace at::cuda { + +template +cudaDataType getCudaDataType() { + static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType."); + return {}; +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16F; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_32F; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_64F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_16F; 
+} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_32F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_64F; +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_8U; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_8I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_32I; +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_64I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16BF; +} + +inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) { + switch (scalar_type) { + case c10::ScalarType::Byte: + return CUDA_R_8U; + case c10::ScalarType::Char: + return CUDA_R_8I; + case c10::ScalarType::Int: + return CUDA_R_32I; + case c10::ScalarType::Half: + return CUDA_R_16F; + case c10::ScalarType::Float: + return CUDA_R_32F; + case c10::ScalarType::Double: + return CUDA_R_64F; + case c10::ScalarType::ComplexHalf: + return CUDA_C_16F; + case c10::ScalarType::ComplexFloat: + return CUDA_C_32F; + case c10::ScalarType::ComplexDouble: + return CUDA_C_64F; + case c10::ScalarType::Short: + return CUDA_R_16I; + case c10::ScalarType::Long: + return CUDA_R_64I; + case c10::ScalarType::BFloat16: + return CUDA_R_16BF; +#if !defined(USE_ROCM) || ROCM_VERSION >= 60300 + case c10::ScalarType::Float8_e4m3fn: + return CUDA_R_8F_E4M3; + case c10::ScalarType::Float8_e5m2: + return CUDA_R_8F_E5M2; +#endif +#if defined(USE_ROCM) + case c10::ScalarType::Float8_e4m3fnuz: + return HIP_R_8F_E4M3_FNUZ; + case c10::ScalarType::Float8_e5m2fnuz: + return HIP_R_8F_E5M2_FNUZ; +#endif +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) + case c10::ScalarType::Float4_e2m1fn_x2: + return CUDA_R_4F_E2M1; +#endif + default: + TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.") + } +} + +} // namespace at::cuda diff 
--git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDADevice.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDADevice.h new file mode 100644 index 0000000000000000000000000000000000000000..ba9a5eb849a091be6e86658a8c7af87a0a3fbb8f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDADevice.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +#include +#include + +namespace at::cuda { + +inline Device getDeviceFromPtr(void* ptr) { + cudaPointerAttributes attr{}; + + AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr)); + +#if !defined(USE_ROCM) + TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered, + "The specified pointer resides on host memory and is not registered with any CUDA device."); +#endif + + return {c10::DeviceType::CUDA, static_cast(attr.device)}; +} + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAEvent.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAEvent.h new file mode 100644 index 0000000000000000000000000000000000000000..dc8e2c4f70ff635213698de130ffad39aa9c95f3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAEvent.h @@ -0,0 +1,249 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +/* +* `cudaEventExternal` is a torch-specific flag that is used to +* indicate that the CUDAEvent will be used only for synchronization +* with work outside of the cuda graph, rather than creation of +* cross-stream dependencies within a cuda graph. 
Resources: +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47 +* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e +*/ +#define cudaEventExternal 0x08 + +namespace at::cuda { + +/* +* CUDAEvents are movable not copyable wrappers around CUDA's events. +* +* CUDAEvents are constructed lazily when first recorded unless it is +* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this +* device is acquired from the first recording stream. However, if reconstructed +* from a handle, the device should be explicitly specified; or if ipc_handle() is +* called before the event is ever recorded, it will use the current device. +* Later streams that record the event must match this device. +*/ +struct TORCH_CUDA_CPP_API CUDAEvent { + // Constructors + // Default value for `flags` is specified below - it's cudaEventDisableTiming + CUDAEvent() noexcept = default; + CUDAEvent(unsigned int flags) noexcept : flags_{flags} {} + + CUDAEvent( + DeviceIndex device_index, const cudaIpcEventHandle_t* handle) : device_index_(device_index) { + CUDAGuard guard(device_index_); + + AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle)); + is_created_ = true; + } + + // Note: event destruction done on creating device to avoid creating a + // CUDA context on other devices. + ~CUDAEvent() { + try { + if (is_created_) { + CUDAGuard guard(device_index_); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast(event_)); + } + AT_CUDA_CHECK(cudaEventDestroy(event_)); + } + } catch (...) 
{ /* No throw */ } + } + + CUDAEvent(const CUDAEvent&) = delete; + CUDAEvent& operator=(const CUDAEvent&) = delete; + + CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); } + CUDAEvent& operator=(CUDAEvent&& other) noexcept { + if (this != &other) { + moveHelper(std::move(other)); + } + return *this; + } + + operator cudaEvent_t() const { return event(); } + + // Less than operator (to allow use in sets) + friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { + return left.event_ < right.event_; + } + + std::optional device() const { + if (is_created_) { + return at::Device(at::kCUDA, device_index_); + } else { + return {}; + } + } + + bool isCreated() const { return is_created_; } + DeviceIndex device_index() const {return device_index_;} + cudaEvent_t event() const { return event_; } + + // Note: cudaEventQuery can be safely called from any device + bool query() const { + if (!is_created_) { + return true; + } + + cudaError_t err = cudaEventQuery(event_); + if (err == cudaSuccess) { + return true; + } else if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + + return false; + } + + void record() { record(getCurrentCUDAStream()); } + + void recordOnce(const CUDAStream& stream) { + if (!was_recorded_) record(stream); + } + + // Note: cudaEventRecord must be called on the same device as the event. + void record(const CUDAStream& stream) { + if (!is_created_) { + createEvent(stream.device_index()); + } + + TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_, + " does not match recording stream's device ", stream.device_index(), "."); + CUDAGuard guard(device_index_); + +#ifndef USE_ROCM + // it is an error to use cudaEventRecordExternal when not doing stream capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? 
cudaEventRecordExternal : cudaEventRecordDefault; + AT_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags)); +#else + AT_CUDA_CHECK(cudaEventRecord(event_, stream)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record(at::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream()) + ); + } + was_recorded_ = true; + } + + // Note: cudaStreamWaitEvent must be called on the same device as the stream. + // The event has no actual GPU resources associated with it. + void block(const CUDAStream& stream) { + if (is_created_) { + CUDAGuard guard(stream.device_index()); +#ifndef USE_ROCM + // it is an error to use cudaEventWaitExternal when not doing stream capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventWaitExternal : cudaEventWaitDefault; + AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags)); +#else + AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait(at::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream()) + ); + } + } + } + + // Note: cudaEventElapsedTime can be safely called from any device + float elapsed_time(const CUDAEvent& other) const { + TORCH_CHECK_VALUE( + !(flags_ & cudaEventDisableTiming) && !(other.flags_ & cudaEventDisableTiming), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + is_created_ && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + + float time_ms = 0; + // We do not strictly have to set the device index to the same as our event, + // but if we don't and the current device is 
not initialized, it will + // create a new cuda context, which will consume a lot of memory. + CUDAGuard guard(device_index_); + // raise cudaErrorNotReady if either event is recorded but not yet completed + AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_)); + return time_ms; + } + + // Note: cudaEventSynchronize can be safely called from any device + void synchronize() const { + if (is_created_) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast(event_)); + } + AT_CUDA_CHECK(cudaEventSynchronize(event_)); + } + } + + // Note: cudaIpcGetEventHandle must be called on the same device as the event + void ipc_handle(cudaIpcEventHandle_t * handle) { + if (!is_created_) { + // this CUDAEvent object was initially constructed from flags but event_ + // is not created yet. + createEvent(getCurrentCUDAStream().device_index()); + } + CUDAGuard guard(device_index_); + AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_)); + } + +private: + unsigned int flags_ = cudaEventDisableTiming; + bool is_created_ = false; + bool was_recorded_ = false; + bool external_ = false; + DeviceIndex device_index_ = -1; + cudaEvent_t event_{}; + + void createEvent(DeviceIndex device_index) { + external_ = (flags_ & cudaEventExternal) != 0; +#ifdef USE_ROCM + TORCH_CHECK(!external_, "External events are disallowed in rocm"); +#endif + flags_ &= ~cudaEventExternal; + device_index_ = device_index; + CUDAGuard guard(device_index_); + AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast(event_)); + } + is_created_ = true; + } + + void moveHelper(CUDAEvent&& other) { + std::swap(flags_, other.flags_); + std::swap(is_created_, other.is_created_); + std::swap(was_recorded_, 
other.was_recorded_); + std::swap(device_index_, other.device_index_); + std::swap(event_, other.event_); + } +}; + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..b0b77cb822a85f89442c9d77d00a8d4919b1a113 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h @@ -0,0 +1,180 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +namespace at { + +namespace cuda { +struct CUDAGraph; +} + +/** + * Note [CUDA Graph-safe RNG states] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * + * Strategy: + * ~~~~~~~~~ + * (It helps to look at + * cuda/detail/PhiloxCudaStateRaw.cuh and + * cuda/detail/UnpackRaw.cuh + * while you read this.) + * + * A CUDA graph containing multiple RNG ops behaves like a + * single giant kernel from the perspective of ops external + * to the graph. During graph capture, logic in CUDAGeneratorImpl + * records the total of all offset increments that occur in the + * graphed region, and records the final total as the offset for + * the entire graph. + * + * When the graph reruns, the logic that reruns it + * increments this device's CUDA generator's offset + * by that total. + * + * Meanwhile, within the graph, at capture time, instead of + * populating PhiloxCudaStates with the uint64_t offset pulled + * directly from the global state, PhiloxCudaState uses a pointer + * to a one-element stream-local int64_t device tensor + * holding an initial offset value, and a uint64_t holding an + * intra-graph offset. (The intra-graph offset starts from zero + * when capture begins.) In each consumer kernel, + * at::cuda::philox::unpack computes the offset to use for this kernel + * as intra-graph offset + *initial offset. 
+ * + * When the graph reruns, the logic that reruns it first + * fill_s the initial offset tensor with this device's + * CUDA generator's current offset. + * + * The control flow above ensures graphed execution is bitwise + * identical to eager execution as long as RNG ops are enqueued + * from a single thread, even if RNG ops and graphs containing + * RNG ops are enqueued and run simultaneously on multiple streams. + * + * Usage: + * ~~~~~~ + * PhiloxCudaState in this file, and unpack() in + * cuda/CUDAGraphsUtils.cuh allow non-divergent use of + * CUDAGeneratorImpl whether graph capture is underway or not. + * + * Each PhiloxCudaState instance should be used for one and only one + * consumer kernel. + * + * Example (see e.g. native/cuda/Dropout.cu): + * + * #include + * #include + * + * __global__ void kernel(..., PhiloxCudaState philox_args) { + * auto seeds = at::cuda::philox::unpack(philox_args); + * IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; + * curandStatePhilox4_32_10_t state; + * curand_init(std::get<0>(seeds), // seed + * idx, // per-thread subsequence + * std::get<1>(seeds), // offset in subsequence + * &state); + * ... + * } + * + * host_caller(...) { + * PhiloxCudaState rng_engine_inputs; + * { + * // See Note [Acquire lock when using random generators] + * std::lock_guard lock(gen->mutex_); + * + * // gen could be HostState or DevState here! No divergent code needed! 
+ * rng_engine_inputs = gen->philox_cuda_state(offset_increment); + * } + * kernel<<<...>>>(..., rng_engine_inputs); + * } + * + */ + +struct CUDAGeneratorState : public c10::intrusive_ptr_target { + uint64_t seed_; + uint64_t philox_offset_per_thread_; + uint32_t offset_intragraph_; + bool capturing_{}; + std::unordered_set registered_graphs_; + at::TensorBase seed_extragraph_{}; + at::TensorBase offset_extragraph_{}; + + CUDAGeneratorState( + uint64_t seed = default_rng_seed_val, + uint64_t philox_offset_per_thread = 0, + uint32_t offset_intragraph = 0) + : seed_(seed), + philox_offset_per_thread_(philox_offset_per_thread), + offset_intragraph_(offset_intragraph) {} + + void increase(uint64_t increment); + + void register_graph(cuda::CUDAGraph* graph); + void unregister_graph(cuda::CUDAGraph* graph); + + void capture_prologue(); + // capture_epilogue returns the wholegraph_increment + uint64_t capture_epilogue(); + void replay_prologue(uint64_t wholegraph_increment); + c10::intrusive_ptr clone(); +}; + +struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl { + // Constructors + CUDAGeneratorImpl(DeviceIndex device_index = -1); + CUDAGeneratorImpl( + DeviceIndex device_index, + c10::intrusive_ptr state_); + ~CUDAGeneratorImpl() override = default; + + // CUDAGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; + void graphsafe_set_state( + const c10::intrusive_ptr& state) override; + c10::intrusive_ptr graphsafe_get_state() const override; + + void set_philox_offset_per_thread(uint64_t offset); + uint64_t philox_offset_per_thread() const; + + void register_graph(cuda::CUDAGraph* graph); + void unregister_graph(cuda::CUDAGraph* graph); + 
+ // Generates a PhiloxCudaState with a specified increment, and increment + // current state + PhiloxCudaState philox_cuda_state(uint64_t increment); + + bool reset_rnn_state() { + return !no_reset_rnn_state_.test_and_set(); + } + + // Temporarily accommodates call sites that use philox_engine_inputs. + // Allows incremental refactor of call sites to use philox_cuda_state. + std::pair philox_engine_inputs(uint64_t increment); + + static c10::DeviceType device_type(); + + private: + CUDAGeneratorImpl* clone_impl() const override; + + c10::intrusive_ptr state_; + std::atomic_flag no_reset_rnn_state_{}; +}; + +namespace cuda::detail { + +TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator( + DeviceIndex device_index = -1); +TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1); + +} // namespace cuda::detail +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAGraph.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAGraph.h new file mode 100644 index 0000000000000000000000000000000000000000..c8cae16b624fe90bb8b98165276f0d00f56a5562 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAGraph.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at { + +struct Generator; +struct CUDAGeneratorImpl; +struct CUDAGeneratorState; + +namespace cuda { + +// Standalone way to get a unique mempool id usable as a pool=... 
argument +// to CUDAGraph::capture_begin +TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle(); + +struct TORCH_CUDA_CPP_API CUDAGraph { + CUDAGraph(bool keep_graph=false); + ~CUDAGraph(); + + // See Note [Explicit Registration of Generators to the CUDA Graph] + void register_generator_state(c10::intrusive_ptr state); + void register_generator_state(const at::Generator& generator); + void capture_begin( + MempoolId_t pool = {0, 0}, + cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal); + void capture_end(); + void instantiate(); + void replay(); + void reset(); + MempoolId_t pool(); + void enable_debug_mode(); + void debug_dump(const std::string& debug_path); + cudaGraph_t raw_cuda_graph(); + + protected: + cudaGraph_t graph_ = nullptr; + cudaGraphExec_t graph_exec_ = nullptr; + + // internal states so reset() can do its best cleaning up + + // Set to true in capture_end if cudaStreamEndCapture succeeded + // Set back to false after instantiate() unless keep_graph=True or + // enable_debug_mode() was called on any CUDAGraph instance. + bool has_graph_ = false; + // Set to true in capture_end if cudaStreamEndCapture succeeded + bool capture_ended_ = false; + // Set to true in capture_end if cudaGraphInstantiate succeeded + bool has_graph_exec_ = false; + + // the ID assigned by cuda during graph capture, + // used to identify when a stream is participating in capture + CaptureId_t capture_id_ = -1; + + // uuid used to request a particular private mempool from CUDACachingAllocator. + // By default, this will be set to {id_, 0}. + // + // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_ + // will be set to the other graph's mempool_id_, and therefore share a mempool with the + // other graph. + // + // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(), + // it will share a mempool with any other captures that used "pool=handle". 
+ // + // Sharing a mempool across graphs saves memory, and it's safe if you + // know you'll replay those graphs in the same order you captured them. + MempoolId_t mempool_id_; + + // Stream on which capture began + at::cuda::CUDAStream capture_stream_; + + // multiple generator states and their wholegraph_increments in this graph + // that are managed by the CUDA Graph + ska::flat_hash_map, uint64_t> + captured_generator_states_; + + // Device where capture occurred. Right now, for simplicity, we require all ops + // in a capture to run on the same device, but this is a limitation of CUDAGraph, + // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device + // captures if needed. + // init capture_dev_ as UNDEFINED_DEVICE to check that it stores the real device id in the destructor + static constexpr c10::DeviceIndex UNDEFINED_DEVICE = -1; + c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE}; + + bool keep_graph_; +}; + +} // namespace cuda +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d3a5b306eeea49b47d45ae5f28944b6dfd8ac4dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +// c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten. +// This file adds utils used by aten only. + +namespace at::cuda { + +using CaptureId_t = c10::cuda::CaptureId_t; +using CaptureStatus = c10::cuda::CaptureStatus; + +// Use this version where you don't want to create a CUDA context if none exists. 
+inline CaptureStatus currentStreamCaptureStatus() { + // don't create a context if we don't have to + if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) { + return c10::cuda::currentStreamCaptureStatusMayInitCtx(); + } else { + return CaptureStatus::None; + } +} + +inline void assertNotCapturing(const std::string& attempt) { + auto status = currentStreamCaptureStatus(); + TORCH_CHECK(status == CaptureStatus::None, + attempt, + " during CUDA graph capture. If you need this call to be captured, " + "please file an issue. " + "Current cudaStreamCaptureStatus: ", + status); +} + +inline void errorIfCapturingCudnnBenchmark(const std::string& version_specific) { + auto status = currentStreamCaptureStatus(); + TORCH_CHECK(status == CaptureStatus::None, + "Current cudaStreamCaptureStatus: ", + status, + "\nCapturing ", + version_specific, + "is prohibited. Possible causes of this error:\n" + "1. No warmup iterations occurred before capture.\n" + "2. The convolutions you're trying to capture use dynamic shapes, " + "in which case capturing them is generally prohibited."); +} + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDASparse.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDASparse.h new file mode 100644 index 0000000000000000000000000000000000000000..736fbe4ae50dac94f9022c9e827393d4a7e9cf5a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDASparse.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#if defined(USE_ROCM) +#include +#define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch) +#endif + +// cuSparse Generic API added in CUDA 10.1 +// Windows support added in CUDA 11.0 +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32))) +#define AT_USE_CUSPARSE_GENERIC_API() 1 +#else +#define AT_USE_CUSPARSE_GENERIC_API() 0 
+#endif + +// cuSparse Generic API descriptor pointers were changed to const in CUDA 12.0 +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \ + (CUSPARSE_VERSION < 12000) +#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1 +#else +#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0 +#endif + +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \ + (CUSPARSE_VERSION >= 12000) +#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1 +#else +#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0 +#endif + +#if defined(USE_ROCM) +// hipSparse const API added in v2.4.0 +#if HIPSPARSE_VERSION >= 200400 +#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1 +#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0 +#define AT_USE_HIPSPARSE_GENERIC_API() 1 +#else +#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0 +#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1 +#define AT_USE_HIPSPARSE_GENERIC_API() 1 +#endif +#else // USE_ROCM +#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0 +#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0 +#define AT_USE_HIPSPARSE_GENERIC_API() 0 +#endif // USE_ROCM + +// cuSparse Generic API spsv function was added in CUDA 11.3.0 +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500) +#define AT_USE_CUSPARSE_GENERIC_SPSV() 1 +#else +#define AT_USE_CUSPARSE_GENERIC_SPSV() 0 +#endif + +// cuSparse Generic API spsm function was added in CUDA 11.3.1 +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11600) +#define AT_USE_CUSPARSE_GENERIC_SPSM() 1 +#else +#define AT_USE_CUSPARSE_GENERIC_SPSM() 0 +#endif + +// cuSparse Generic API sddmm function was added in CUDA 11.2.1 (cuSparse version 11400) +#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11400) +#define AT_USE_CUSPARSE_GENERIC_SDDMM() 1 +#else +#define AT_USE_CUSPARSE_GENERIC_SDDMM() 0 +#endif + +// BSR triangular solve functions were added in hipSPARSE 1.11.2 (ROCm 4.5.0) +#if defined(CUDART_VERSION) || 
defined(USE_ROCM) +#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1 +#else +#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0 +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h new file mode 100644 index 0000000000000000000000000000000000000000..a098496491d155567d31454b35760c9eb3d941da --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h @@ -0,0 +1,320 @@ +#pragma once + +/* + Provides a subset of cuSPARSE functions as templates: + + csrgeam2(...) + + where scalar_t is double, float, c10::complex or c10::complex. + The functions are available in at::cuda::sparse namespace. +*/ + +#include +#include + +// NOLINTBEGIN(misc-misplaced-const) +namespace at::cuda::sparse { + +#define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \ + const cusparseMatDescr_t descrA, int nnzA, \ + const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \ + const int *csrSortedColIndA, const scalar_t *beta, \ + const cusparseMatDescr_t descrB, int nnzB, \ + const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \ + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \ + const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \ + const int *csrSortedColIndC, size_t *pBufferSizeInBytes + +template +inline void csrgeam2_bufferSizeExt( + CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void csrgeam2_bufferSizeExt( + CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float)); +template <> +void csrgeam2_bufferSizeExt( + CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double)); +template <> +void csrgeam2_bufferSizeExt>( + CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex)); +template <> +void csrgeam2_bufferSizeExt>( + 
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex)); + +#define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \ + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \ + int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \ + const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \ + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \ + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace + +template +inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) { + TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz( + handle, + m, + n, + descrA, + nnzA, + csrSortedRowPtrA, + csrSortedColIndA, + descrB, + nnzB, + csrSortedRowPtrB, + csrSortedColIndB, + descrC, + csrSortedRowPtrC, + nnzTotalDevHostPtr, + workspace)); +} + +#define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \ + const cusparseMatDescr_t descrA, int nnzA, \ + const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \ + const int *csrSortedColIndA, const scalar_t *beta, \ + const cusparseMatDescr_t descrB, int nnzB, \ + const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \ + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \ + scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \ + void *pBuffer + +template +inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::csrgeam2: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(float)); +template <> +void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(double)); +template <> +void csrgeam2>( + CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex)); +template <> +void csrgeam2>( + CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \ + int kb, int 
nnzb, const scalar_t *alpha, \ + const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc + +template +inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrmm: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrmm(CUSPARSE_BSRMM_ARGTYPES(float)); +template <> +void bsrmm(CUSPARSE_BSRMM_ARGTYPES(double)); +template <> +void bsrmm>(CUSPARSE_BSRMM_ARGTYPES(c10::complex)); +template <> +void bsrmm>(CUSPARSE_BSRMM_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, int mb, int nb, int nnzb, \ + const scalar_t *alpha, const cusparseMatDescr_t descrA, \ + const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \ + int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y + +template +inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrmv: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrmv(CUSPARSE_BSRMV_ARGTYPES(float)); +template <> +void bsrmv(CUSPARSE_BSRMV_ARGTYPES(double)); +template <> +void bsrmv>(CUSPARSE_BSRMV_ARGTYPES(c10::complex)); +template <> +void bsrmv>(CUSPARSE_BSRMV_ARGTYPES(c10::complex)); + +#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() + +#define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, int mb, int nnzb, \ + const cusparseMatDescr_t descrA, scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsv2Info_t info, int *pBufferSizeInBytes + +template +inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsv2_bufferSize: not 
implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float)); +template <> +void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double)); +template <> +void bsrsv2_bufferSize>( + CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex)); +template <> +void bsrsv2_bufferSize>( + CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, int mb, int nnzb, \ + const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer + +template +inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsv2_analysis: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float)); +template <> +void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double)); +template <> +void bsrsv2_analysis>( + CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex)); +template <> +void bsrsv2_analysis>( + CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \ + const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \ + cusparseSolvePolicy_t policy, void *pBuffer + +template +inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsv2_solve: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float)); +template <> +void 
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double)); +template <> +void bsrsv2_solve>( + CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex)); +template <> +void bsrsv2_solve>( + CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \ + int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsm2Info_t info, int *pBufferSizeInBytes + +template +inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float)); +template <> +void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double)); +template <> +void bsrsm2_bufferSize>( + CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex)); +template <> +void bsrsm2_bufferSize>( + CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \ + int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \ + const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \ + bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer + +template +inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsm2_analysis: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float)); +template <> +void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double)); +template <> +void bsrsm2_analysis>( + CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex)); +template <> 
+void bsrsm2_analysis>( + CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex)); + +#define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \ + cusparseHandle_t handle, cusparseDirection_t dirA, \ + cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \ + int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \ + const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \ + int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \ + scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer + +template +inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) { + TORCH_INTERNAL_ASSERT( + false, + "at::cuda::sparse::bsrsm2_solve: not implemented for ", + typeid(scalar_t).name()); +} + +template <> +void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float)); +template <> +void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double)); +template <> +void bsrsm2_solve>( + CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex)); +template <> +void bsrsm2_solve>( + CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex)); + +#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE + +} // namespace at::cuda::sparse +// NOLINTEND(misc-misplaced-const) diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h new file mode 100644 index 0000000000000000000000000000000000000000..7fc482f2a3fbd024e9c80afd0af8c19a13261de5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h @@ -0,0 +1,288 @@ +#pragma once + +#include +#include +#include + +#include + +#if defined(USE_ROCM) +#include +#endif + +namespace at::cuda::sparse { + +template +struct CuSparseDescriptorDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDASPARSE_CHECK(destructor(x)); + } + } +}; + +template +class CuSparseDescriptor { + public: + T* descriptor() const { + return descriptor_.get(); + } + T* descriptor() { + 
return descriptor_.get(); + } + + protected: + std::unique_ptr> descriptor_; +}; + +#if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() +template +struct ConstCuSparseDescriptorDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDASPARSE_CHECK(destructor(x)); + } + } +}; + +template +class ConstCuSparseDescriptor { + public: + T* descriptor() const { + return descriptor_.get(); + } + T* descriptor() { + return descriptor_.get(); + } + + protected: + std::unique_ptr> descriptor_; +}; +#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS + +#if defined(USE_ROCM) +using cusparseMatDescr = std::remove_pointer_t; +using cusparseDnMatDescr = std::remove_pointer_t; +using cusparseDnVecDescr = std::remove_pointer_t; +using cusparseSpMatDescr = std::remove_pointer_t; +using cusparseSpMatDescr = std::remove_pointer_t; +using cusparseSpGEMMDescr = std::remove_pointer_t; +#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() +using bsrsv2Info = std::remove_pointer_t; +using bsrsm2Info = std::remove_pointer_t; +#endif +#endif + +// NOTE: This is only needed for CUDA 11 and earlier, since CUDA 12 introduced +// API for const descriptors +cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr); + +class TORCH_CUDA_CPP_API CuSparseMatDescriptor + : public CuSparseDescriptor { + public: + CuSparseMatDescriptor() { + cusparseMatDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } + + CuSparseMatDescriptor(bool upper, bool unit) { + cusparseFillMode_t fill_mode = + upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER; + cusparseDiagType_t diag_type = + unit ? 
CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT; + cusparseMatDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor)); + TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode)); + TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type)); + descriptor_.reset(raw_descriptor); + } +}; + +#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() + +class TORCH_CUDA_CPP_API CuSparseBsrsv2Info + : public CuSparseDescriptor { + public: + CuSparseBsrsv2Info() { + bsrsv2Info_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; + +class TORCH_CUDA_CPP_API CuSparseBsrsm2Info + : public CuSparseDescriptor { + public: + CuSparseBsrsm2Info() { + bsrsm2Info_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; + +#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE + +#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API() + +cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type); + +#if AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() +class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor + : public CuSparseDescriptor { + public: + explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1); +}; + +class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor + : public CuSparseDescriptor { + public: + explicit CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1); + cusparseDnMatDescr* unsafe_mutable_descriptor() const { + return const_cast(descriptor()); + } + cusparseDnMatDescr* unsafe_mutable_descriptor() { + return const_cast(descriptor()); + } +}; + +class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor + : public CuSparseDescriptor { + public: + explicit CuSparseDnVecDescriptor(const Tensor& input); +}; + +class TORCH_CUDA_CPP_API 
CuSparseSpMatDescriptor + : public CuSparseDescriptor {}; + +#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() + class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor + : public ConstCuSparseDescriptor< + cusparseDnMatDescr, + &cusparseDestroyDnMat> { + public: + explicit CuSparseDnMatDescriptor( + const Tensor& input, + int64_t batch_offset = -1); + }; + + class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor + : public ConstCuSparseDescriptor< + const cusparseDnMatDescr, + &destroyConstDnMat> { + public: + explicit CuSparseConstDnMatDescriptor( + const Tensor& input, + int64_t batch_offset = -1); + cusparseDnMatDescr* unsafe_mutable_descriptor() const { + return const_cast(descriptor()); + } + cusparseDnMatDescr* unsafe_mutable_descriptor() { + return const_cast(descriptor()); + } + }; + + class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor + : public ConstCuSparseDescriptor< + cusparseDnVecDescr, + &cusparseDestroyDnVec> { + public: + explicit CuSparseDnVecDescriptor(const Tensor& input); + }; + + class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor + : public ConstCuSparseDescriptor< + cusparseSpMatDescr, + &cusparseDestroySpMat> {}; +#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS() + +class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor + : public CuSparseSpMatDescriptor { + public: + explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1); + + std::tuple get_size() { + int64_t rows = 0, cols = 0, nnz = 0; + TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize( + this->descriptor(), + &rows, + &cols, + &nnz)); + return std::make_tuple(rows, cols, nnz); + } + + void set_tensor(const Tensor& input) { + auto crow_indices = input.crow_indices(); + auto col_indices = input.col_indices(); + auto values = input.values(); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous()); + 
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous()); + TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers( + this->descriptor(), + crow_indices.data_ptr(), + col_indices.data_ptr(), + values.data_ptr())); + } + +#if AT_USE_CUSPARSE_GENERIC_SPSV() + void set_mat_fill_mode(bool upper) { + cusparseFillMode_t fill_mode = + upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER; + TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute( + this->descriptor(), + CUSPARSE_SPMAT_FILL_MODE, + &fill_mode, + sizeof(fill_mode))); + } + + void set_mat_diag_type(bool unit) { + cusparseDiagType_t diag_type = + unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT; + TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute( + this->descriptor(), + CUSPARSE_SPMAT_DIAG_TYPE, + &diag_type, + sizeof(diag_type))); + } +#endif +}; + +#if AT_USE_CUSPARSE_GENERIC_SPSV() +class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor + : public CuSparseDescriptor { + public: + CuSparseSpSVDescriptor() { + cusparseSpSVDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; +#endif + +#if AT_USE_CUSPARSE_GENERIC_SPSM() +class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor + : public CuSparseDescriptor { + public: + CuSparseSpSMDescriptor() { + cusparseSpSMDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; +#endif + +class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor + : public CuSparseDescriptor { + public: + CuSparseSpGEMMDescriptor() { + cusparseSpGEMMDescr_t raw_descriptor = nullptr; + TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } +}; + +#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API() + +} // namespace at::cuda::sparse diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e4e89ea1cdb77da1d7866ffe99c64dabfd735d27 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace at { +template <> +inline __half* Tensor::data() const { + return reinterpret_cast<__half*>(data()); +} +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..d5f65dd6a572f9cab15dbae9983df8037b724a99 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CUDAUtils.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at::cuda { + +// Check if every tensor in a list of tensors matches the current +// device. +inline bool check_device(ArrayRef ts) { + if (ts.empty()) { + return true; + } + Device curDevice = Device(kCUDA, current_device()); + for (const Tensor& t : ts) { + if (t.device() != curDevice) return false; + } + return true; +} + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..b9486314b1c21b7940a4d1040360fe37a83737cd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::cuda { + +// +// A caching allocator for CUDA host allocations (pinned memory). +// +// This provides a drop-in replacement for THCudaHostAllocator, which re-uses +// freed pinned (page-locked) memory allocations. 
This avoids device +// synchronizations due to cudaFreeHost calls. +// +// To ensure correct behavior, THCCachingHostAllocator_recordEvent must be +// called anytime a pointer from this allocator is used in a cudaMemcpyAsync +// call between host and device, and passed the corresponding context from the +// allocation. This is currently invoked by at::native::copy_kernel_cuda. +// +C10_DEPRECATED_MESSAGE( + "at::cuda::getCachingHostAllocator() is deprecated. Please use at::getHostAllocator(at::kCUDA) instead.") +inline TORCH_CUDA_CPP_API at::HostAllocator* getCachingHostAllocator() { + return at::getHostAllocator(at::kCUDA); +} + +// Records an event in the specified stream. The allocation corresponding to the +// input `ptr`/`ctx` will not be re-used until the event has occurred. +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_recordEvent(...) is deprecated. Please use at::getHostAllocator(at::kCUDA)->record_event(...) instead.") +inline TORCH_CUDA_CPP_API bool CachingHostAllocator_recordEvent( + void* ptr, + void* ctx, + c10::cuda::CUDAStream stream) { + return getHostAllocator(at::kCUDA)->record_event(ptr, ctx, stream.unwrap()); +} + +// Releases cached pinned memory allocations via cudaHostFree +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_emptyCache() is deprecated. Please use at::getHostAllocator(at::kCUDA)->empty_cache() instead.") +inline TORCH_CUDA_CPP_API void CachingHostAllocator_emptyCache() { + getHostAllocator(at::kCUDA)->empty_cache(); +} + +C10_DEPRECATED_MESSAGE( + "at::cuda::HostAlloc(...) is deprecated. Please use at::getHostAllocator(at::kCUDA)->allocate(...) instead.") +inline TORCH_CUDA_CPP_API at::DataPtr HostAlloc(size_t size) { + return getHostAllocator(at::kCUDA)->allocate(size); +} + +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_getStats() is deprecated. 
Please use at::getHostAllocator(at::kCUDA)->get_stats() instead.") +inline TORCH_CUDA_CPP_API at::HostStats CachingHostAllocator_getStats() { + return getHostAllocator(at::kCUDA)->get_stats(); +} + +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_resetAccumulatedStats() is deprecated. Please use at::getHostAllocator(at::kCUDA)->reset_accumulated_stats() instead.") +inline TORCH_CUDA_CPP_API void CachingHostAllocator_resetAccumulatedStats() { + getHostAllocator(at::kCUDA)->reset_accumulated_stats(); +} + +C10_DEPRECATED_MESSAGE( + "at::cuda::CachingHostAllocator_resetPeakStats() is deprecated. Please use at::getHostAllocator(at::kCUDA)->reset_peak_stats() instead.") +inline TORCH_CUDA_CPP_API void CachingHostAllocator_resetPeakStats() { + getHostAllocator(at::kCUDA)->reset_peak_stats(); +} + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c0a2fc47c0069bed4e15beec249162a556f7dc3d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include + +__device__ __forceinline__ unsigned int ACTIVE_MASK() +{ +#if !defined(USE_ROCM) + return __activemask(); +#else +// will be ignored anyway + return 0xffffffff; +#endif +} + +__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) { +#if !defined(USE_ROCM) + return __syncwarp(mask); +#endif +} + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate) +{ +return __ballot(predicate); +} +#else +__device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __ballot_sync(mask, predicate); +#else + return __ballot(predicate); +#endif +} +#endif + +template +__device__ __forceinline__ T 
WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template +__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __shfl_sync(mask, value, srcLane, width); +#else + return __shfl(value, srcLane, width); +#endif +} + +template +__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __shfl_up_sync(mask, value, delta, width); +#else + return __shfl_up(value, delta, width); +#endif +} + +template +__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return __shfl_down_sync(mask, value, delta, width); +#else + return __shfl_down(value, delta, width); +#endif +} + +#if defined(USE_ROCM) +template<> +__device__ __forceinline__ int64_t WARP_SHFL_DOWN(int64_t value, unsigned int delta, int width , unsigned int mask) +{ + //(HIP doesn't support int64_t). 
Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ + int2 a = *reinterpret_cast(&value); + a.x = __shfl_down(a.x, delta); + a.y = __shfl_down(a.y, delta); + return *reinterpret_cast(&a); +} +#endif + +template<> +__device__ __forceinline__ c10::Half WARP_SHFL_DOWN(c10::Half value, unsigned int delta, int width, unsigned int mask) +{ + return c10::Half(WARP_SHFL_DOWN(value.x, delta, width, mask), c10::Half::from_bits_t{}); +} + +template +__device__ __forceinline__ c10::complex WARP_SHFL_DOWN(c10::complex value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if !defined(USE_ROCM) + return c10::complex( + __shfl_down_sync(mask, value.real_, delta, width), + __shfl_down_sync(mask, value.imag_, delta, width)); +#else + return c10::complex( + __shfl_down(value.real_, delta, width), + __shfl_down(value.imag_, delta, width)); +#endif +} + +/** + * For CC 3.5+, perform a load using __ldg + */ +template +__device__ __forceinline__ T doLdg(const T* p) { +#if __CUDA_ARCH__ >= 350 && !defined(USE_ROCM) + return __ldg(p); +#else + return *p; +#endif +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/EmptyTensor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/EmptyTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..2fd88a94b75d2ca72133c253eca800295a1771aa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/EmptyTensor.h @@ -0,0 +1,44 @@ +#pragma once +#include + +namespace at::detail { + +TORCH_CUDA_CPP_API TensorBase empty_cuda( + IntArrayRef size, + ScalarType dtype, + std::optional device_opt, + std::optional memory_format_opt); + +TORCH_CUDA_CPP_API TensorBase empty_cuda( + IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TORCH_CUDA_CPP_API TensorBase empty_cuda( + IntArrayRef size, + const TensorOptions 
&options); + +TORCH_CUDA_CPP_API TensorBase empty_strided_cuda( + IntArrayRef size, + IntArrayRef stride, + ScalarType dtype, + std::optional device_opt); + +TORCH_CUDA_CPP_API TensorBase empty_strided_cuda( + IntArrayRef size, + IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +TORCH_CUDA_CPP_API TensorBase empty_strided_cuda( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions &options); + + +} // namespace at::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/Exceptions.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/Exceptions.h new file mode 100644 index 0000000000000000000000000000000000000000..cec013f006dbc2d7e8b87a30716247cfc8849366 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/Exceptions.h @@ -0,0 +1,230 @@ +#pragma once + +#include +#include +#include + +#if !defined(USE_ROCM) +#include +#else +#include +#endif + +#if defined(USE_CUDSS) +#include +#endif + +#include +#include +#include + + +namespace c10 { + +class CuDNNError : public c10::Error { + using Error::Error; +}; + +} // namespace c10 + +#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) \ + do { \ + auto error_object = EXPR; \ + if (!error_object.is_good()) { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN Frontend error: ", error_object.get_message()); \ + } \ + } while (0) \ + +#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__) + +// See Note [CHECK macro] +#define AT_CUDNN_CHECK(EXPR, ...) \ + do { \ + cudnnStatus_t status = EXPR; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + if (status == CUDNN_STATUS_NOT_SUPPORTED) { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN error: ", \ + cudnnGetErrorString(status), \ + ". 
This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \ + } else { \ + TORCH_CHECK_WITH(CuDNNError, false, \ + "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__); \ + } \ + } \ + } while (0) + +namespace at::cuda::blas { +C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error); +} // namespace at::cuda::blas + +#define TORCH_CUDABLAS_CHECK(EXPR) \ + do { \ + cublasStatus_t __err = EXPR; \ + TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS, \ + "CUDA error: ", \ + at::cuda::blas::_cublasGetErrorEnum(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +const char *cusparseGetErrorString(cusparseStatus_t status); + +#define TORCH_CUDASPARSE_CHECK(EXPR) \ + do { \ + cusparseStatus_t __err = EXPR; \ + TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS, \ + "CUDA error: ", \ + cusparseGetErrorString(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +#if defined(USE_CUDSS) +namespace at::cuda::cudss { +C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error); +} // namespace at::cuda::solver + +#define TORCH_CUDSS_CHECK(EXPR) \ + do { \ + cudssStatus_t __err = EXPR; \ + if (__err == CUDSS_STATUS_EXECUTION_FAILED) { \ + TORCH_CHECK_LINALG( \ + false, \ + "cudss error: ", \ + at::cuda::cudss::cudssGetErrorMessage(__err), \ + ", when calling `" #EXPR "`", \ + ". This error may appear if the input matrix contains NaN. ");\ + } else { \ + TORCH_CHECK( \ + __err == CUDSS_STATUS_SUCCESS, \ + "cudss error: ", \ + at::cuda::cudss::cudssGetErrorMessage(__err), \ + ", when calling `" #EXPR "`. 
"); \ + } \ + } while (0) +#else +#define TORCH_CUDSS_CHECK(EXPR) EXPR +#endif + +namespace at::cuda::solver { +#if !defined(USE_ROCM) + +C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status); + +constexpr const char* _cusolver_backend_suggestion = \ + "If you keep seeing this error, you may use " \ + "`torch.backends.cuda.preferred_linalg_library()` to try " \ + "linear algebra operators with other supported backends. " \ + "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library"; + +// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue. +#define TORCH_CUSOLVER_CHECK(EXPR) \ + do { \ + cusolverStatus_t __err = EXPR; \ + if (__err == CUSOLVER_STATUS_INVALID_VALUE) { \ + TORCH_CHECK_LINALG( \ + false, \ + "cusolver error: ", \ + at::cuda::solver::cusolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`", \ + ". This error may appear if the input matrix contains NaN. ", \ + at::cuda::solver::_cusolver_backend_suggestion); \ + } else { \ + TORCH_CHECK( \ + __err == CUSOLVER_STATUS_SUCCESS, \ + "cusolver error: ", \ + at::cuda::solver::cusolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`. ", \ + at::cuda::solver::_cusolver_backend_suggestion); \ + } \ + } while (0) + +#else // defined(USE_ROCM) + +C10_EXPORT const char* hipsolverGetErrorMessage(hipsolverStatus_t status); + +constexpr const char* _hipsolver_backend_suggestion = \ + "If you keep seeing this error, you may use " \ + "`torch.backends.cuda.preferred_linalg_library()` to try " \ + "linear algebra operators with other supported backends. 
" \ + "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library"; + +#define TORCH_CUSOLVER_CHECK(EXPR) \ + do { \ + hipsolverStatus_t __err = EXPR; \ + if (__err == HIPSOLVER_STATUS_INVALID_VALUE) { \ + TORCH_CHECK_LINALG( \ + false, \ + "hipsolver error: ", \ + at::cuda::solver::hipsolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`", \ + ". This error may appear if the input matrix contains NaN. ", \ + at::cuda::solver::_hipsolver_backend_suggestion); \ + } else { \ + TORCH_CHECK( \ + __err == HIPSOLVER_STATUS_SUCCESS, \ + "hipsolver error: ", \ + at::cuda::solver::hipsolverGetErrorMessage(__err), \ + ", when calling `" #EXPR "`. ", \ + at::cuda::solver::_hipsolver_backend_suggestion); \ + } \ + } while (0) +#endif +} // namespace at::cuda::solver + +#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR) + +// For CUDA Driver API +// +// This is here instead of in c10 because NVRTC is loaded dynamically via a stub +// in ATen, and we need to use its nvrtcGetErrorString. +// See NOTE [ USE OF NVRTC AND DRIVER API ]. +#if !defined(USE_ROCM) + +#define AT_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + const char* err_str; \ + [[maybe_unused]] CUresult get_error_str_err = \ + at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \ + if (get_error_str_err != CUDA_SUCCESS) { \ + TORCH_CHECK(false, "CUDA driver error: unknown error"); \ + } else { \ + TORCH_CHECK(false, "CUDA driver error: ", err_str); \ + } \ + } \ + } while (0) + +#else + +#define AT_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + TORCH_CHECK(false, "CUDA driver error: ", static_cast(__err)); \ + } \ + } while (0) + +#endif + +// For CUDA NVRTC +// +// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE, +// incorrectly produces the error string "NVRTC unknown error." +// The following maps it correctly. 
+// +// This is here instead of in c10 because NVRTC is loaded dynamically via a stub +// in ATen, and we need to use its nvrtcGetErrorString. +// See NOTE [ USE OF NVRTC AND DRIVER API ]. +#define AT_CUDA_NVRTC_CHECK(EXPR) \ + do { \ + nvrtcResult __err = EXPR; \ + if (__err != NVRTC_SUCCESS) { \ + if (static_cast(__err) != 7) { \ + TORCH_CHECK(false, "CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \ + } else { \ + TORCH_CHECK(false, "CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \ + } \ + } \ + } while (0) diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/NumericLimits.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/NumericLimits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7081e94837caa7d5050128e0bfe19aa67f93cd39 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/NumericLimits.cuh @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include +#include + +// NumericLimits.cuh is a holder for numeric limits definitions of commonly used +// types. This header is very specific to ROCm HIP and may be removed in the future. +// This header is derived from the legacy THCNumerics.cuh. + +// The lower_bound and upper_bound constants are same as lowest and max for +// integral types, but are -inf and +inf for floating point types. They are +// useful in implementing min, max, etc. + +namespace at { + +template +struct numeric_limits { +}; + +// WARNING: the following at::numeric_limits definitions are there only to support +// HIP compilation for the moment. Use std::numeric_limits if you are not +// compiling for ROCm. +// from @colesbury: "The functions on numeric_limits aren't marked with +// __device__ which is why they don't work with ROCm. CUDA allows them +// because they're constexpr." + +namespace { + // ROCm doesn't like INFINITY too. 
+ constexpr double inf = INFINITY; +} + +template <> +struct numeric_limits { + static inline __host__ __device__ bool lowest() { return false; } + static inline __host__ __device__ bool max() { return true; } + static inline __host__ __device__ bool lower_bound() { return false; } + static inline __host__ __device__ bool upper_bound() { return true; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ uint8_t lowest() { return 0; } + static inline __host__ __device__ uint8_t max() { return UINT8_MAX; } + static inline __host__ __device__ uint8_t lower_bound() { return 0; } + static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int8_t lowest() { return INT8_MIN; } + static inline __host__ __device__ int8_t max() { return INT8_MAX; } + static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; } + static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int16_t lowest() { return INT16_MIN; } + static inline __host__ __device__ int16_t max() { return INT16_MAX; } + static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; } + static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int32_t lowest() { return INT32_MIN; } + static inline __host__ __device__ int32_t max() { return INT32_MAX; } + static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; } + static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; } +}; + +template <> +struct numeric_limits { +#ifdef _MSC_VER + static inline __host__ __device__ int64_t lowest() { return _I64_MIN; } + static inline __host__ __device__ int64_t max() { return _I64_MAX; } + static inline __host__ __device__ int64_t lower_bound() 
{ return _I64_MIN; } + static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; } +#else + static inline __host__ __device__ int64_t lowest() { return INT64_MIN; } + static inline __host__ __device__ int64_t max() { return INT64_MAX; } + static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; } + static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; } +#endif +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); } + static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); } + static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); } + static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ float lowest() { return -FLT_MAX; } + static inline __host__ __device__ float max() { return FLT_MAX; } + static inline __host__ __device__ float lower_bound() { return -static_cast(inf); } + static inline __host__ __device__ float upper_bound() { return static_cast(inf); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ double lowest() { return -DBL_MAX; } + static inline __host__ __device__ double max() { return DBL_MAX; } + static inline __host__ __device__ 
double lower_bound() { return -inf; } + static inline __host__ __device__ double upper_bound() { return inf; } +}; + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h new file mode 100644 index 0000000000000000000000000000000000000000..5b63a855f3f46e8738f01749120f508d4fd7902e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h @@ -0,0 +1,12 @@ +#include +#include +#include + +namespace at::cuda { +namespace detail { +void init_p2p_access_cache(int64_t num_devices); +} + +TORCH_CUDA_CPP_API bool get_p2p_access(c10::DeviceIndex source_dev, c10::DeviceIndex dest_dev); + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PhiloxCudaState.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PhiloxCudaState.h new file mode 100644 index 0000000000000000000000000000000000000000..dc4b131aa9c4b468092c5819f2c4aa321be8ca30 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PhiloxCudaState.h @@ -0,0 +1,5 @@ +#pragma once + +#include + +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PhiloxUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PhiloxUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0b161b70d96f428512cf6c26c3acc1654da74bf0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PhiloxUtils.cuh @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PinnedMemoryAllocator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PinnedMemoryAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..079013e56055a2769b84884ce4c02eb7ff45e1b5 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/PinnedMemoryAllocator.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace at::cuda { + +inline TORCH_CUDA_CPP_API at::HostAllocator* getPinnedMemoryAllocator() { + return at::getHostAllocator(at::kCUDA); +} +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ScanUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ScanUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..fd8b9e91d48815761e7234b0d6fbcaab7f62fcee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ScanUtils.cuh @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include + +// Collection of in-kernel scan / prefix sum utilities + +namespace at::cuda { + +// Inclusive prefix sum for binary vars using intra-warp voting + +// shared memory +template +__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) { + // Within-warp, we use warp voting. +#if defined (USE_ROCM) + unsigned long long int vote = WARP_BALLOT(in); + T index = __popcll(getLaneMaskLe() & vote); + T carry = __popcll(vote); +#else + T vote = WARP_BALLOT(in); + T index = __popc(getLaneMaskLe() & vote); + T carry = __popc(vote); +#endif + + int warp = threadIdx.x / C10_WARP_SIZE; + + // Per each warp, write out a value + if (getLaneId() == 0) { + smem[warp] = carry; + } + + __syncthreads(); + + // Sum across warps in one thread. 
This appears to be faster than a + // warp shuffle scan for CC 3.0+ + if (threadIdx.x == 0) { + int current = 0; + for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) { + T v = smem[i]; + smem[i] = binop(smem[i], current); + current = binop(current, v); + } + } + + __syncthreads(); + + // load the carry from the preceding warp + if (warp >= 1) { + index = binop(index, smem[warp - 1]); + } + + *out = index; + + if (KillWARDependency) { + __syncthreads(); + } +} + +// Exclusive prefix sum for binary vars using intra-warp voting + +// shared memory +template +__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) { + inclusiveBinaryPrefixScan(smem, in, out, binop); + + // Inclusive to exclusive + *out -= (T) in; + + // The outgoing carry for all threads is the last warp's sum + *carry = smem[at::ceil_div(blockDim.x, C10_WARP_SIZE) - 1]; + + if (KillWARDependency) { + __syncthreads(); + } +} + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/Sleep.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/Sleep.h new file mode 100644 index 0000000000000000000000000000000000000000..ef5e83a832f739e19f13837500824c984013812e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/Sleep.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include + +namespace at::cuda { + +// enqueues a kernel that spins for the specified number of cycles +TORCH_CUDA_CU_API void sleep(int64_t cycles); + +// flushes instruction cache for ROCm; no-op for CUDA +TORCH_CUDA_CU_API void flush_icache(); + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ThrustAllocator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ThrustAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..85783c303534ec9f6190c2d6bb44a8faafd680f7 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/ThrustAllocator.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +namespace at::cuda { + +/// Allocator for Thrust to re-route its internal device allocations +/// to the THC allocator +class ThrustAllocator { +public: + typedef char value_type; + + char* allocate(std::ptrdiff_t size) { + return static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(size)); + } + + void deallocate(char* p, size_t size) { + c10::cuda::CUDACachingAllocator::raw_delete(p); + } +}; + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub-RadixSortPairs.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub-RadixSortPairs.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bd40deb4125b20f58efad860eaff25e8717cf82d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub-RadixSortPairs.cuh @@ -0,0 +1,74 @@ +#pragma once + +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +namespace at::cuda::cub::detail { + +template +void radix_sort_pairs_impl( + const key_t* keys_in, + key_t* keys_out, + const OpaqueType* values_in, + OpaqueType* values_out, + int64_t n, + bool descending, + int64_t begin_bit, + int64_t end_bit) { + TORCH_CHECK( + n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::cuda_type::type; + + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr keys_out_owner; + + if (keys_out == nullptr) { + keys_out_owner = allocator->allocate(n * sizeof(key_t)); + keys_out = reinterpret_cast(keys_out_owner.get()); + } + + const key_t_* keys_in_ = reinterpret_cast(keys_in); + key_t_* keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + 
c10::cuda::getCurrentCUDAStream()); + } else { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } +} + +#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \ + template void radix_sort_pairs_impl( \ + const key_t* keys_in, \ + key_t* keys_out, \ + const OpaqueType* values_in, \ + OpaqueType* values_out, \ + int64_t n, \ + bool descending, \ + int64_t begin_bit, \ + int64_t end_bit); + +#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \ + AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8) + +} // namespace at::cuda::cub::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e1d452bac4a6032d480cf7256131a2bff7b13b52 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub.cuh @@ -0,0 +1,625 @@ +#pragma once +#include + +#include +#include +#include +#include + +#include +#include + +#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE() + +#include + +#else + +// include cub in a safe manner, see: +// https://github.com/pytorch/pytorch/pull/55292 +#undef CUB_NS_POSTFIX //undef to avoid redefinition warnings +#undef CUB_NS_PREFIX +#undef CUB_NS_QUALIFIER +#define CUB_NS_PREFIX namespace at_cuda_detail { +#define CUB_NS_POSTFIX } +#define CUB_NS_QUALIFIER ::at_cuda_detail::cub +#include +#undef CUB_NS_POSTFIX +#undef CUB_NS_PREFIX +#undef CUB_NS_QUALIFIER + +#endif + +#include +#include +#include + +// handle the temporary storage and 'twice' calls for cub API +#define CUB_WRAPPER(func, ...) 
do { \ + size_t temp_storage_bytes = 0; \ + AT_CUDA_CHECK(func(nullptr, temp_storage_bytes, __VA_ARGS__)); \ + auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \ + auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \ + AT_CUDA_CHECK(func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__));\ +} while (false) + +#ifdef USE_ROCM +#define NO_ROCM(x) +#define ROCM_HIPCUB(x) ::hipcub +#else +#define NO_ROCM(x) x +#define ROCM_HIPCUB(x) x +#endif + +#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM) + +#if !defined(USE_ROCM) +namespace at_cuda_detail { +#endif + +// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16 + +template <> +struct ROCM_HIPCUB(cub)::FpLimits +{ + static __host__ __device__ __forceinline__ c10::BFloat16 Max() { + unsigned short max_word = 0x7F7F; + return reinterpret_cast(max_word); + } + + static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() { + unsigned short lowest_word = 0xFF7F; + return reinterpret_cast(lowest_word); + } +}; + +template <> +struct ROCM_HIPCUB(cub)::NumericTraits: + ROCM_HIPCUB(cub)::BaseTraits {}; + +#if !defined(USE_ROCM) +} // namespace at_cuda_detail +#endif + +#endif + +#if !defined(USE_ROCM) +namespace at::native { +namespace cub = ::at_cuda_detail::cub; +} // namespace at::native +#endif + +namespace at::cuda::cub { + +namespace detail { + +template +struct cuda_type { + using type = T; +}; +template<> +struct cuda_type { + using type = __half; +}; + +#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16() + +template<> +struct cuda_type { + using type = __nv_bfloat16; +}; + +#elif defined(USE_ROCM) + +template<> +struct cuda_type { + using type = hip_bfloat16; +}; + +#endif + +} // namespace detail + +template +inline void segmented_sort_pairs( + const key_t *keys_in, key_t *keys_out, + const value_t *values_in, value_t *values_out, + int64_t num_elements, int64_t num_segments, + OffsetIteratorT begin_offsets, OffsetIteratorT 
end_offsets, + bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8 +) { + TORCH_CHECK(num_elements <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + TORCH_CHECK(num_segments <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::cuda_type::type; + + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr keys_out_owner; + + if (keys_out == nullptr) { + keys_out_owner = allocator->allocate(num_elements * sizeof(key_t)); + keys_out = reinterpret_cast(keys_out_owner.get()); + } + + const key_t_ *keys_in_ = reinterpret_cast(keys_in); + key_t_ *keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending, + keys_in_, keys_out_, values_in, values_out, + num_elements, num_segments, begin_offsets, end_offsets, + begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); + } else { + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs, + keys_in_, keys_out_, values_in, values_out, + num_elements, num_segments, begin_offsets, end_offsets, + begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); + } +} + +#if CUB_SUPPORTS_UNIQUE_BY_KEY() +template +inline void unique_by_key( + KeysInputIteratorT keys_in, ValuesInputIteratorT values_in, + ValuesOutputIteratorT values_out, + NumSelectedIteratorT num_selected, int64_t num_input_items) +{ + // TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed. 
+ using KeyT = typename std::iterator_traits::value_type; + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr keys_out_owner; + keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT)); + auto keys_out_ = static_cast(keys_out_owner.get()); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey, + keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream()); +} +#endif + +namespace impl { + +template +C10_LAUNCH_BOUNDS_1(1) +__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){ + // NOTE: out here not the final scan output, but an intermediate of the accumulation type. + using acc_t = typename std::iterator_traits::value_type; + *out = scan_op(static_cast(*a), static_cast(*b)); +} + +#if !CUB_SUPPORTS_FUTURE_VALUE() +template +struct chained_iterator { + using iterator_category = std::random_access_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = ValueT; + using pointer = ValueT*; + using reference = ValueT&; + + InputIteratorT iter; + ValueT *first; + difference_type offset = 0; + + __device__ ValueT operator[](difference_type i) { + i += offset; + if (i == 0) { + return *first; + } else { + return ValueT(iter[i - 1]); + } + } + __device__ chained_iterator operator+(difference_type i) { + return chained_iterator{iter, first, i}; + } + __device__ ValueT operator*() { + return (*this)[0]; + } +}; +#endif + +// even though cub is supposed to support tensors with int_max elements, in reality it doesn't, +// so split at int_max/2 +constexpr int max_cub_size = std::numeric_limits::max() / 2 + 1; // 2**30 +} + +// non synchronizing cub call +// even though cub is supposed to support tensors with int_max elements, in reality it doesn't, +// so split at int_max/2 +template +inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) { +#if 
defined(USE_ROCM) + //For ROCm, use hipCUB chained iterators + CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan, + input, + output, + scan_op, + num_items, + at::cuda::getCurrentCUDAStream()); + C10_HIP_KERNEL_LAUNCH_CHECK(); +#else + // non synchronizing cub call + // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, + // so split at int_max/2 + int size_cub = std::min(num_items, max_cub_size); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, + input, + output, + scan_op, + size_cub, + at::cuda::getCurrentCUDAStream()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + using input_t = typename std::iterator_traits::value_type; + for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) { + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr first_elem = allocator->allocate(sizeof(input_t)); + auto first_elem_ptr = reinterpret_cast(first_elem.get()); + + size_cub = std::min(num_items - i, max_cub_size); + impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + output + i - 1, + input + i, + first_elem_ptr, + scan_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#if !CUB_SUPPORTS_FUTURE_VALUE() + using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator; + using tuple = typename ArgIndexInputIterator::value_type; + auto input_iter_transform = [=] __device__ (const tuple &x)->input_t { + if (x.key == 0) { + return *first_elem_ptr; + } else { + return x.value; + } + }; + auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator( + ArgIndexInputIterator(input + i), input_iter_transform); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, + input_, + output + i, + scan_op, + size_cub, + at::cuda::getCurrentCUDAStream()); +#else + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, + input + i + 1, + output + i, + scan_op, + ::at_cuda_detail::cub::FutureValue(first_elem_ptr), + size_cub, + 
at::cuda::getCurrentCUDAStream()); +#endif + } +#endif +} + +# if defined(CUDA_VERSION) || defined(USE_ROCM) + +template +struct BlockPrefixCallbackOp +{ + public: + T running_total; + + __host__ __device__ BlockPrefixCallbackOp(T running_total) : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __host__ __device__ T operator()(T block_aggregate) + { + T old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +template +__global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem, int iters_per_cta) { + int64_t offset = BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x; + int64_t remaining = nelem - offset; + if (remaining <= 0) { + return; + } + + d_in += offset; + d_out += offset; + + using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad; + + // Specialize BlockStore type for our thread block (uses warp-striped loads for coalescing, then transposes in shared + // memory to a blocked arrangement) + using BlockStoreT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockStore; + + // Specialize BlockScan type for our thread block + using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan; + using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce; + + + // Shared memory + __shared__ union TempStorage + { + typename BlockLoadT::TempStorage load; + typename BlockStoreT::TempStorage store; + typename BlockScanT::TempStorage scan; + typename BlockReduceT::TempStorage reduce; + } temp_storage; + + // load agg and reduce my starting value + T agg_data; + agg_data = threadIdx.x >= blockIdx.x ? 
T(0) : agg[threadIdx.x]; + // if there are fewer threads than previous values to be read, + // read another value + if (threadIdx.x + blockDim.x < blockIdx.x) { + agg_data += agg[threadIdx.x + blockDim.x]; + } + T aggregate = BlockReduceT(temp_storage.reduce).Sum(agg_data); + __syncthreads(); + BlockPrefixCallbackOp prefix_op(aggregate); + + + // Per-thread tile data + T data[ITEMS_PER_THREAD]; + + for (int i=0; i= BLOCK_THREADS * ITEMS_PER_THREAD) { + BlockLoadT(temp_storage.load).Load(d_in, data); + } else { + #pragma unroll + for (int j=0; j= BLOCK_THREADS * ITEMS_PER_THREAD) { + BlockStoreT(temp_storage.store).Store(d_out, data); + } else { + BlockStoreT(temp_storage.store).Store(d_out, data, remaining); + } + d_in += BLOCK_THREADS * ITEMS_PER_THREAD; + d_out += BLOCK_THREADS * ITEMS_PER_THREAD; + remaining -= BLOCK_THREADS * ITEMS_PER_THREAD; + if (remaining <= 0) return; + __syncthreads(); + } + +} + +template +struct TransformFunctor { + __device__ aggT operator()(T value) const { + if constexpr (!nonzero) { + return value; + } else { + return (value != T(0)) ? 
1 : 0; + } + } +}; + +template +__global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int iters_per_cta){ + int64_t offset = BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta * (int64_t)blockIdx.x; + int64_t remaining = nelem - offset; + if (remaining <= 0) { + return; + } + d_in += offset; + + using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad; + using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce; + // Shared memory + __shared__ union TempStorage + { + typename BlockLoadT::TempStorage load; + typename BlockReduceT::TempStorage reduce; + } temp_storage; + aggT data[ITEMS_PER_THREAD]; + aggT agg_val = 0; + TransformFunctor transform_functor; + auto iter_in = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator, const T*>(d_in, transform_functor); + for (int i=0; i= BLOCK_THREADS * ITEMS_PER_THREAD) { + BlockLoadT(temp_storage.load).Load(iter_in, data); + __syncthreads(); + agg_val += BlockReduceT(temp_storage.reduce).Sum(data); + + } else { + BlockLoadT(temp_storage.load).Load(iter_in, data, remaining, aggT(0)); + __syncthreads(); + agg_val += BlockReduceT(temp_storage.reduce).Sum(data); + } + iter_in += BLOCK_THREADS * ITEMS_PER_THREAD; + remaining -= BLOCK_THREADS * ITEMS_PER_THREAD; + if (remaining <= 0) { + // for nonzeros we need to write out last blocks + // accumulated value to be able to compute + // total number of nonzeros + if (nonzero && threadIdx.x == 0) { + agg[blockIdx.x] = agg_val; + } + return; + } + __syncthreads(); + + } + if (threadIdx.x == 0) { + agg[blockIdx.x] = agg_val; + } + +} + +template +struct NonZeroOp { + __host__ __device__ __forceinline__ int operator()(const T& a) const { + return (a != T(0)); + } +}; + +template +constexpr int block_threads(){ + if constexpr (size >=16) { + return 128; + } else if constexpr (size >=8) { + return 256; + } else { + return 512; + } +} + +template +inline void inclusive_deterministic_scan(const scalar_t * input, scalar_t * output, ScanOpT scan_op, 
int64_t num_items) { + static_assert(std::is_same_v>, ""); + constexpr int BLOCK_THREADS = block_threads(); + constexpr int ITEMS_PER_THREAD = 16; + auto grid_size = (num_items + BLOCK_THREADS * ITEMS_PER_THREAD - 1) / (BLOCK_THREADS * ITEMS_PER_THREAD); + const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + + const int iters_per_cta = (grid_size + num_sms - 1)/num_sms; + grid_size = std::min(num_sms, grid_size); + // simple reduction in scan kernel handles at most 2 items per thread + TORCH_INTERNAL_ASSERT(2 * BLOCK_THREADS >= grid_size); + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + auto agg = allocator.allocate(grid_size * sizeof(scalar_t)); + calc_block_sums + <<>>( + input, (scalar_t*)agg.get(), num_items, iters_per_cta); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + final_scan_kernel + <<>>( + input, output, (scalar_t*)agg.get(), num_items, iters_per_cta); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +#endif + +template +inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) { +#if defined(USE_ROCM) + //For ROCm, use hipCUB chained iterators + CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan, + input, + output, + scan_op, + init_value, + num_items, + at::cuda::getCurrentCUDAStream()); + C10_HIP_KERNEL_LAUNCH_CHECK(); +#else + // non synchronizing cub call + // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, + // so split at int_max/2 + int size_cub = std::min(num_items, max_cub_size); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, + input, + output, + scan_op, + init_value, + size_cub, + at::cuda::getCurrentCUDAStream()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) { + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr first_elem = allocator->allocate(sizeof(InitValueT)); + auto 
first_elem_ptr = reinterpret_cast(first_elem.get()); + + size_cub = std::min(num_items - i, max_cub_size); + impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + output + i - 1, + input + i - 1, + first_elem_ptr, + scan_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#if !CUB_SUPPORTS_FUTURE_VALUE() + auto input_ = impl::chained_iterator{ + input + i, first_elem_ptr}; + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, + input_, + output + i, + scan_op, + size_cub, + at::cuda::getCurrentCUDAStream()); +#else + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, + input + i, + output + i, + scan_op, + ::at_cuda_detail::cub::FutureValue(first_elem_ptr), + size_cub, + at::cuda::getCurrentCUDAStream()); +#endif + } +#endif +} + +#if CUB_SUPPORTS_SCAN_BY_KEY() + +template +inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub InclusiveSumByKey does not support more than INT_MAX elements"); +#if !defined(USE_ROCM) + CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey, + keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream()); +#else + CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey, + keys, input, output, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream()); +#endif +} + +template +inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub InclusiveSumByKey does not support more than INT_MAX elements"); +#if !defined(USE_ROCM) + CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey, + keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream()); +#else + CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey, + 
keys, input, output, scan_op, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream()); +#endif +} + +#endif + +template +void unique(InputIteratorT input, OutputIteratorT output, + NumSelectedIteratorT num_selected_out, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub unique does not support more than INT_MAX elements"); + CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique, + input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream()); +} + +template +void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out, + LengthOutputIteratorT length_out, int64_t num_items) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub run_length_encode does not support more than INT_MAX elements"); + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode, + input, output, counts_out, length_out, num_items, + at::cuda::getCurrentCUDAStream()); +} + +template +void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) { + TORCH_CHECK(num_items <= std::numeric_limits::max(), + "cub reduce does not support more than INT_MAX elements"); + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce, + input, output, num_items, op, init, + at::cuda::getCurrentCUDAStream()); + +} + +} // namespace at::cuda::cub diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub.h new file mode 100644 index 0000000000000000000000000000000000000000..b18f1438a7088ccea4ae5cc9e1ddb114d7941015 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub.h @@ -0,0 +1,87 @@ +#pragma once +#include +#include +#include + +// NOTE: These templates are intentionally not defined in this header, +// which aviods re-compiling them for each translation unit. 
If you get +// a link error, you need to add an explicit instantiation for your +// types in cub.cu + +namespace at::cuda::cub { + +inline int get_num_bits(uint64_t max_key) { + int num_bits = 1; + while (max_key > 1) { + max_key >>= 1; + num_bits++; + } + return num_bits; +} + +namespace detail { + +// radix_sort_pairs doesn't interact with value_t other than to copy +// the data, so we can save template instantiations by reinterpreting +// it as an opaque type. +template struct alignas(N) OpaqueType { char data[N]; }; + +template +void radix_sort_pairs_impl( + const key_t *keys_in, key_t *keys_out, + const OpaqueType *values_in, OpaqueType *values_out, + int64_t n, bool descending, int64_t begin_bit, int64_t end_bit); + +} // namespace detail + +template +void radix_sort_pairs( + const key_t *keys_in, key_t *keys_out, + const value_t *values_in, value_t *values_out, + int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8) { + static_assert(std::is_trivially_copyable_v || + AT_ROCM_ENABLED(), // ROCm incorrectly fails this check for vector types + "radix_sort_pairs value type must be trivially copyable"); + // Make value type opaque, so all inputs of a certain size use the same template instantiation + using opaque_t = detail::OpaqueType; + static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0, + "This size of value_t is not instantiated. 
Please instantiate it in cub.cu" + " and modify this check."); + static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned"); + detail::radix_sort_pairs_impl( + keys_in, keys_out, + reinterpret_cast(values_in), + reinterpret_cast(values_out), + n, descending, begin_bit, end_bit); +} + +template +void radix_sort_keys( + const key_t *keys_in, key_t *keys_out, + int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8); + +// NOTE: Intermediate sums will be truncated to input_t precision +template +void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n); + +template +void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) { + return inclusive_sum_truncating(input, output, n); +} + +// NOTE: Sums are done is common_type +template +void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n); + +template +void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) { + return exclusive_sum_in_common_type(input, output, n); +} + +void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n); +inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) { + return mask_exclusive_sum( + reinterpret_cast(mask), output_idx, n); +} + +} // namespace at::cuda::cub diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub_definitions.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub_definitions.cuh new file mode 100644 index 0000000000000000000000000000000000000000..db7cc9120a099e36e9a63a4f424d9ef37b9b87fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/cub_definitions.cuh @@ -0,0 +1,53 @@ +#pragma once + +#if !defined(USE_ROCM) +#include // for CUDA_VERSION +#endif + +#if !defined(USE_ROCM) +#include +#else +#define CUB_VERSION 200001 +#endif + +// cub sort support for __nv_bfloat16 is added to cub 1.13 in: +// 
https://github.com/NVIDIA/cub/pull/306
+#if CUB_VERSION >= 101300
+#define CUB_SUPPORTS_NV_BFLOAT16() true
+#else
+#define CUB_SUPPORTS_NV_BFLOAT16() false
+#endif
+
+// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
+// https://github.com/NVIDIA/cub/pull/326
+// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
+// starting from CUDA 11.5
+#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE)
+#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true
+#else
+#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
+#endif
+
+// cub support for UniqueByKey is added to cub 1.16 in:
+// https://github.com/NVIDIA/cub/pull/405
+#if CUB_VERSION >= 101600
+#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
+#else
+#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
+#endif
+
+// cub support for scan by key is added to cub 1.15
+// in https://github.com/NVIDIA/cub/pull/376
+// NOTE(review): this pair expands to 1/0 while the sibling feature macros
+// expand to true/false; both forms work in #if, but the inconsistency is
+// worth confirming upstream before changing.
+#if CUB_VERSION >= 101500
+#define CUB_SUPPORTS_SCAN_BY_KEY() 1
+#else
+#define CUB_SUPPORTS_SCAN_BY_KEY() 0
+#endif
+
+// cub support for cub::FutureValue is added to cub 1.15 in:
+// https://github.com/NVIDIA/cub/pull/305
+#if CUB_VERSION >= 101500
+#define CUB_SUPPORTS_FUTURE_VALUE() true
+#else
+#define CUB_SUPPORTS_FUTURE_VALUE() false
+#endif
diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h
new file mode 100644
index 0000000000000000000000000000000000000000..b0dac7a71e809d55c98dd18da2fe2048b0811ceb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h
+#pragma once
+
+// NOTE(review): the two #include targets below were lost when this diff was
+// extracted (angle-bracket contents stripped); restore from the shipped
+// header before use.
+#include
+
+#include
+
+// TODO: No need to have this whole header, we can just put it all in
+// the cpp file
+
+namespace at::cuda::detail {
+
+// Set the callback to initialize Magma, which is set by
+// torch_cuda_cu. This indirection is required so magma_init is called
+// in the same library where Magma will be used.
+TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
+
+
+// The real implementation of CUDAHooksInterface
+struct CUDAHooks : public at::CUDAHooksInterface {
+  CUDAHooks(at::CUDAHooksArgs) {}
+  void init() const override;
+  Device getDeviceFromPtr(void* data) const override;
+  bool isPinnedPtr(const void* data) const override;
+  // Generator access for a given device (-1 selects the current device).
+  const Generator& getDefaultGenerator(
+      DeviceIndex device_index = -1) const override;
+  Generator getNewGenerator(
+      DeviceIndex device_index = -1) const override;
+  // Capability queries: which backends/libraries this build can use.
+  bool hasCUDA() const override;
+  bool hasMAGMA() const override;
+  bool hasCuDNN() const override;
+  bool hasCuSOLVER() const override;
+  bool hasCuBLASLt() const override;
+  bool hasROCM() const override;
+  const at::cuda::NVRTC& nvrtc() const override;
+  DeviceIndex current_device() const override;
+  // This hooks object is only linked in when CUDA support is built,
+  // hence the unconditional true here.
+  bool isBuilt() const override {return true;}
+  bool isAvailable() const override {return hasCUDA();}
+  bool hasPrimaryContext(DeviceIndex device_index) const override;
+  Allocator* getCUDADeviceAllocator() const override;
+  Allocator* getPinnedMemoryAllocator() const override;
+  bool compiledWithCuDNN() const override;
+  bool compiledWithMIOpen() const override;
+  bool supportsDilatedConvolutionWithCuDNN() const override;
+  bool supportsDepthwiseConvolutionWithCuDNN() const override;
+  bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
+  bool hasCUDART() const override;
+  long versionCUDART() const override;
+  long versionCuDNN() const override;
+  long versionMIOpen() const override;
+  std::string showConfig() const override;
+  double batchnormMinEpsilonCuDNN() const override;
+  // cuFFT plan-cache management, per device.
+  int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
+  void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
+  int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
+  void cuFFTClearPlanCache(DeviceIndex device_index) const override;
+  int getNumGPUs() const override;
+  DeviceIndex deviceCount() const override;
+  DeviceIndex getCurrentDevice() const override;
+
+#ifdef USE_ROCM
+  // NOTE(review): "std::vector&" lost its element-type argument in
+  // extraction (presumably std::vector<std::string>); confirm against the
+  // shipped header.
+  bool isGPUArch(const std::vector& archs, DeviceIndex device_index = -1) const override;
+#endif
+  void deviceSynchronize(DeviceIndex device_index) const override;
+};
+
+} // at::cuda::detail
diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h
new file mode 100644
index 0000000000000000000000000000000000000000..1f80c863b63944a25aacb3aa8b95d0b82b6c110b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h
+// Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states.
+// These handles are tied to device, and these libraries requires/recommends not to
+// share handles across host threads.
+//
+// These libraries recommend using one handle per host thread. We may not want to do
+// this because threads are relatively light-weight, but creating and destroying
+// handles is expensive (destroying the handle causes synchronizations). DataParallel,
+// for example, creates new threads for each forward pass.
+//
+// This file implements a handle pool mechanism. The handle pool returns handles on
+// demand as threads request them. If all existing handles in the pool are in use,
+// it creates a new one. As threads terminate, they release handles back into the pool.
+// In this way, the handle pool never creates more handles than the high-water mark of
+// active threads, so it's efficient with DataParallel.
+ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace at::cuda { namespace { + +template +struct DeviceThreadHandlePool : public std::enable_shared_from_this> { + + struct Handle { + Handle_t handle; + Handle(bool create = false) : handle(nullptr) + { + if(create) Create(&handle); + } + // std::vector.emplace() and push_back() may route through temporaries and call + // copy/move constructors along the way. If this is the case, we don't want + // the destructors of temporaries to call cudnnDestroy on the handle. + // We can achieve safety (for the narrow case of stashing within std::vectors) + // by making Handle moveable but not copyable, and transferring handle ownership + // to the latest constructed object. This is not a substitute for full-blown + // reference counting, but reference counting may be overkill here. + // Another alternative is to wrap the saved Handles in unique_ptrs, i.e., + // unordered_map>> created_handles; + Handle(const Handle& rhs) = delete; + // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom + Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); } + // operator= takes argument by value + Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; } + ~Handle() { + if(handle) Destroy(handle); + } + }; + + std::mutex mutex; + + // Handles are lazily created as different threads request them, + // but are never destroyed until the end of the process. + // The maximum number of handles this process will create for each device is equal + // to the high-water mark of the number of concurrently active threads that request + // handles for that device. + // When threads terminate, they release their handles back into the pool for reuse. 
+ // Otherwise, new handles would be created every time new threads were spawned, + // resulting in poor performance for Python modules that repeatedly or frequently + // spawned new sets of threads (like DataParallel, which creates a new set of threads + // for each forward pass). + // + // To prevent potential deadlocks, we explicitly choose not to cap the number + // of handles that are created per device. + // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device, + // only 4 can make forward progress at any time. The other 4 will not release their + // handles until they exit, so the fifth cannot make progress until then. This is + // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an + // intermediate point (ie, before any of them have exited). We have no way to anticipate + // or enforce that user threads will not attempt such intermediate synchronization. + // The only way to ensure safety is to avoid imposing a cap on the number of handles. + std::unordered_map> created_handles; + std::unordered_map> available_handles; + + // PoolWindow lazily creates and caches the handles that a particular thread is using, + // so in the common case handle access doesn't incur either handle creation or a mutex lock. + class PoolWindow + { + public: + PoolWindow(std::shared_ptr parent): weak_parent(std::move(parent)) {} + ~PoolWindow(){ release(); } + + Handle_t reserve(int device) + { + // If this thread already has a handle for this device, return it + if(my_handles.find(device) != my_handles.end()) + return my_handles[device]; + + // otherwise, either grab a handle from the pool if one is available, + // or if not, create a new one. 
+ auto parent = weak_parent.lock(); + TORCH_CHECK(parent, "Cannot create handle during program termination"); + std::lock_guard guard(parent->mutex); + + if(parent->available_handles[device].size() > 0) + { + my_handles[device] = parent->available_handles[device].back(); + parent->available_handles[device].pop_back(); + } + else + { + // In local testing, I do observe that emplace_back sometimes routes through temporaries + // that incur move-constructor and destructor calls. See comments in Handle above. + parent->created_handles[device].emplace_back(true /*create*/); + my_handles[device] = parent->created_handles[device].back().handle; + } + + return my_handles[device]; + } + + private: + // Stores the per-device handles currently owned by this thread + std::unordered_map my_handles; + + std::weak_ptr weak_parent; + + // Called by the destructor. Releases this thread's handles back into the pool. + void release() { + if(my_handles.size() > 0) { + auto parent = weak_parent.lock(); + if (!parent) { + // If this thread exits after atexit handlers have completed, the + // cuda context itself may be invalid, so we must leak the handles. + return; + } + + std::lock_guard guard(parent->mutex); + for(auto d_h : my_handles) + parent->available_handles[d_h.first].push_back(d_h.second); + } + } + }; + + // Warning: + // If you want to change this function, be aware that this function will be called + // by multiple threads and there is no mutex guarding the call of this function, so + // make sure your implementation is thread-safe. + PoolWindow *newPoolWindow() { + // The returned pointer will be owned by a thread local variable + // so that different threads does not share the same PoolWindow. 
+ return new PoolWindow(this->shared_from_this()); + } +}; + +}} // namespace at::cuda::detail:: diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..367ab10d3d3bb1de30412afc047fc25154247701 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include + +namespace at::cuda::detail { + +TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t); +using at::native::canUse32BitIndexMath; + +template +TensorInfo +getTensorInfo(const at::TensorBase &t) { + IndexType sz[MAX_TENSORINFO_DIMS]; + IndexType st[MAX_TENSORINFO_DIMS]; + + int dims = t.dim(); + for (int i = 0; i < dims; ++i) { + sz[i] = t.size(i); + st[i] = t.stride(i); + } + + scalar* data_ptr = nullptr; + + if constexpr (std::is_const_v) { + data_ptr = t.const_data_ptr(); + } else { + data_ptr = t.mutable_data_ptr(); + } + + return TensorInfo( + data_ptr, dims, sz, st); +} + +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e8a26b5e06a6ffeeeaf5df26f09175a2c903aa01 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh @@ -0,0 +1,124 @@ +#pragma once + +#include +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#include +#endif + +namespace at::cuda::detail { + +// A utility class to implement integer division by multiplication, given a fixed +// divisor. +// +// WARNING: The fast divider algorithm is only implemented for unsigned int; +// otherwise we default to plain integer division. 
For unsigned int, +// we further assume that the dividend is at most INT32_MAX. Thus, +// IntDivider must NOT be used for general integer division. +// +// This reduced range is enough for our purpose, and it allows us to +// slightly simplify the computation. +// +// (NOTE: Below, "2^k" denotes exponentiation, i.e., 1< 0), we can find a "magic number" m (2^N +// <= m < 2^(N+1)) and shift s such that: +// +// \floor(n / d) = \floor((m * n) / 2^(N+s)). +// +// Given such m and s, the integer division can be then implemented as: +// +// let m' = m - 2^N // 0 <= m' < 2^N +// +// fast_integer_division(n): +// // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned +// // integer. Then take the higher N bits. +// t = (m' * n) >> N +// +// // Here we use the fact that n is less than 2^(N-1): otherwise the value +// // of (t + n) may not fit in an N-bit integer. +// return (t + n) >> s +// +// Finding such a magic number is surprisingly easy: +// +// s = \ceil(\log_2 d) +// m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic. +// +// See also: +// - Division by Invariant Integers Using Multiplication, +// Torbjörn Granlund and Peter L. Montgomery, 1994. +// +// - http://www.hackersdelight.org/magic.htm +// +// - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html + +// Result of div/mod operation stored together. +template +struct DivMod { + Value div, mod; + + C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { } +}; + +// Base case: we only have an implementation for uint32_t for now. For +// everything else, we use plain division. 
+template +struct IntDivider { + IntDivider() = default; + IntDivider(Value d) : divisor(d) { } + + C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; } + C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; } + C10_HOST_DEVICE inline DivMod divmod(Value n) const { + return DivMod(n / divisor, n % divisor); + } + + Value divisor; +}; + +// Implement fast integer division. +template <> +struct IntDivider { + static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int."); + + IntDivider() = default; + + IntDivider(unsigned int d) : divisor(d) { + assert(divisor >= 1 && divisor <= INT32_MAX); + + // TODO: gcc/clang has __builtin_clz() but it's not portable. + for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + m1 = magic; + assert(m1 > 0 && m1 == magic); // m1 must fit in 32 bits. + } + + C10_HOST_DEVICE inline unsigned int div(unsigned int n) const { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and + // 'm1'. + unsigned int t = __umulhi(n, m1); + return (t + n) >> shift; +#else + // Using uint64_t so that the addition does not overflow. + uint64_t t = ((uint64_t) n * m1) >> 32; + return (t + n) >> shift; +#endif + } + + C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const { + return n - div(n) * divisor; + } + + C10_HOST_DEVICE inline DivMod divmod(unsigned int n) const { + unsigned int q = div(n); + return DivMod(q, n - q * divisor); + } + + unsigned int divisor; // d above. + unsigned int m1; // Magic number: m' above. + unsigned int shift; // Shift amounts. 
+}; + +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..b20001569da7275b001ba631bbdd7e1b0ffac163 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +namespace at::cuda::detail { + +// CUDA: grid stride looping +// +// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment. +// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final +// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be +// greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no +// further iterations and the overflowed value in i=_i_n_d_e_x is not used. +#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \ + int64_t _i_n_d_e_x = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; \ + for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x) + +#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int) + + +// Use 1024 threads per block, which requires cuda sm_2x or above +constexpr int CUDA_NUM_THREADS = 1024; + +// CUDA: number of blocks for threads. 
+inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) { + TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N); + constexpr int64_t max_int = std::numeric_limits::max(); + + // Round up division for positive number that cannot cause integer overflow + auto block_num = (N - 1) / max_threads_per_block + 1; + TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device"); + + return static_cast(block_num); +} + +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h new file mode 100644 index 0000000000000000000000000000000000000000..95e52c94377bf568060c25f887454dbbaf2054b7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h @@ -0,0 +1,11 @@ +#pragma once +#include +namespace at::cuda { +// Forward-declares at::cuda::NVRTC +struct NVRTC; + +namespace detail { +extern NVRTC lazyNVRTC; +} // namespace detail + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..60e1a19c1aacfe43830295f75fa6d2496e3b7d52 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +// If element_sizes is nullptr, then the strides will be in bytes, otherwise +// the strides will be in # of elements. +// Operands that share the same shape, but may have different strides. 
+// OffsetCalculator iterates the tensor in a column-major order + +#if defined(USE_ROCM) +constexpr int MAX_DIMS = 16; +#else +constexpr int MAX_DIMS = 25; +#endif + +template +struct OffsetCalculator { + // We allow having negative strides to implement some operations like torch.flip + using stride_t = std::conditional_t, + index_t>; + // The offset for each argument. Wrapper around fixed-size array. + // On CUDA, zero sized array is not allowed, so when we are handling nullary + // operators, we need to create a size 1 offset to avoid compiler failure. + // This size 1 offset is just a placeholder, and we will not use it. + using offset_type = std::array(NARGS, 1)>; + + // if element_sizes is nullptr, then the strides will be in bytes, otherwise + // the strides will be in # of elements. + OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) { + TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims"); + for (int i=0; i < dims; i++){ + sizes_[i] = at::cuda::detail::IntDivider(sizes[i]); + for (int arg = 0; arg < NARGS; arg++) { + int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]); + strides_[i][arg] = strides[arg][i] / element_size; + } + } + } + + C10_HOST_DEVICE offset_type get(index_t linear_idx) const { + offset_type offsets; + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] = 0; + } + + #pragma unroll + for (int dim = 0; dim < MAX_DIMS; ++dim) { + if (dim == dims) { + break; + } + auto divmod = sizes_[dim].divmod(linear_idx); + linear_idx = divmod.div; + + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] += divmod.mod * strides_[dim][arg]; + } + + } + return offsets; + } + + int dims; + at::cuda::detail::IntDivider sizes_[MAX_DIMS]; + stride_t strides_[MAX_DIMS][std::max(NARGS, 1)]; +}; + +template +struct TrivialOffsetCalculator { + // The offset for each argument. 
Wrapper around fixed-size array. + // The offsets are in # of elements, not in bytes. + // On CUDA, zero sized array is not allowed, so when we are handling nullary + // operators, we need to create a size 1 offset to avoid compiler failure. + // This size 1 offset is just a placeholder, and we will not use it. + using offset_type = std::array(NARGS, 1)>; + + C10_HOST_DEVICE offset_type get(index_t linear_idx) const { + offset_type offsets; + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) { + offsets[arg] = linear_idx; + } + return offsets; + } +}; + +// Make an OffsetCalculator with byte offsets +template +static OffsetCalculator make_offset_calculator(const at::TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(N <= iter.ntensors()); + std::array strides; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i).data(); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data()); +} + +// Make an OffsetCalculator with element offsets +template +static OffsetCalculator make_element_offset_calculator( + const at::TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(N <= iter.ntensors()); + std::array strides; + std::array element_sizes; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i).data(); + element_sizes[i] = iter.element_size(i); + } + return OffsetCalculator( + iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data()); +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh new file mode 100644 index 0000000000000000000000000000000000000000..231cd167cacb4f9b4f2f48e159431b30f3d6dc28 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh @@ -0,0 +1,43 @@ +// No "#pragma once" because this is a raw definition that can be copied by jit codegen. 
+// Eager mode clients should not include this file directly, instead, +// they should #include , which has a #pragma once. + +// Stores RNG state values. Passed as a kernel argument. +// See Note [CUDA Graph-safe RNG states]. +// +// The raw definition lives in its own file so jit codegen can easily copy it. +namespace at { + +struct PhiloxCudaState { + PhiloxCudaState() = default; + // Called if graph capture is not underway + PhiloxCudaState(uint64_t seed, + uint64_t offset) { + seed_.val = seed; + offset_.val = offset; + } + // Called if graph capture is underway + PhiloxCudaState(int64_t* seed, + int64_t* offset_extragraph, + uint32_t offset_intragraph) { + seed_.ptr = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + // Public members, directly accessible by at::cuda::philox::unpack. + // If we made them private with getters/setters, the getters/setters + // would have to be __device__, and we can't declare __device__ in ATen. 
+ union Payload { + uint64_t val; + int64_t* ptr; + }; + + Payload seed_{}; + Payload offset_{}; + uint32_t offset_intragraph_ = 0; + bool captured_ = false; +}; + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a320000ae881faa7416bae6ed1f37793f357a73a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include + +namespace at::cuda::detail { + +#define MAX_TENSORINFO_DIMS 25 + +// CUDA kernel argument that defines tensor layout +template +struct TensorInfo { + TensorInfo(); + TensorInfo(T* p, + int dim, + IndexType sz[MAX_TENSORINFO_DIMS], + IndexType st[MAX_TENSORINFO_DIMS]); + + // Set the size of the given dimension to 1, as if it were a + // reduction dim (allows you to calculate offsets of the reduction + // slice) + void reduceDim(int dim); + + // See note on [collapse dims]. 
+ int collapseDims(const int excludeDim = -1); + + // Contiguous tensors of more than one dimension are collapsed down + // to one tensor + __host__ __device__ inline bool isContiguous() const { + return (dims == 1 && strides[0] == 1); + } + + T* data; + IndexType sizes[MAX_TENSORINFO_DIMS]; + IndexType strides[MAX_TENSORINFO_DIMS]; + int dims; +}; + +template +TensorInfo::TensorInfo() { + data = nullptr; + dims = 0; +} + +template +TensorInfo::TensorInfo(T* p, + int dim, + IndexType sz[MAX_TENSORINFO_DIMS], + IndexType st[MAX_TENSORINFO_DIMS]) { + data = p; + dims = dim; + TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions"); + + for (int i = 0; i < dim; ++i) { + sizes[i] = sz[i]; + strides[i] = st[i]; + } +} + +template +void +TensorInfo::reduceDim(int dim) { + TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1"); + sizes[dim] = 1; +} + +template +int +TensorInfo::collapseDims(const int excludeDim) { + auto result = at::collapse_dims(sizes, strides, dims, excludeDim); + dims = std::get<1>(result); + return std::get<0>(result); +} + +// Translate a linear index for the apply to a T* offset; +// specialized on `Dims` to reduce nvcc compilation time +template +struct IndexToOffset { + static __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + + IndexType offset = 0; + + // Uses static dims + for (int i = Dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + linearId /= info.sizes[i]; + } + + return offset + linearId * info.strides[0]; + } +}; + +// Uses dynamic (runtime) instead of static (compiletime) dims +template +struct IndexToOffset { + static inline __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + + IndexType offset = 0; + + for (int i = info.dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + 
IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + linearId /= info.sizes[i]; + } + + return offset + linearId * info.strides[0]; + } +}; + +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3a458c756daf96173dec9d26e7078d49ea975f72 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh @@ -0,0 +1,34 @@ +// No "#pragma once" because this is a raw definition that can be copied by jit codegen. +// Eager mode clients should not include this file directly, instead, +// they should #include , which has a #pragma once. + +namespace at::cuda::philox { + +// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether +// that instance was created with graph capture underway or not. +// See Note [CUDA Graph-safe RNG states]. +// +// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen. +// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable. +// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda. +// +// The raw definition lives in its own file so jit codegen can easily copy it. +__host__ __device__ __forceinline__ std::tuple +unpack(at::PhiloxCudaState arg) { + if (arg.captured_) { + // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". + // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel. + // For most threads' reads it will hit in cache, so it shouldn't hurt performance. 
+ return std::make_tuple(static_cast(*arg.seed_.ptr), static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +// Adapted from TE +// extract seed and offset from PhiloxCudaState +__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr); + +void unpack_cudnn_wrapper(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr, cudaStream_t stream); + +} // namespace at::cuda::philox diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/jiterator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/jiterator.h new file mode 100644 index 0000000000000000000000000000000000000000..a7a440cd7b614732d5ad16eb28c56e71bb20b6e1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/jiterator.h @@ -0,0 +1,40 @@ +#pragma once +#include + +#if AT_USE_JITERATOR() + +#include +#include +#include + +#include +#include + +namespace at::cuda { + +TORCH_CUDA_CPP_API c10::SmallVector CompileAndLaunchKernel( + const std::string& code_string, + const std::string& kernel_name, + const int num_outputs, + const c10::SmallVector& tensors, + const c10::SmallVector& extra_args, + bool return_by_ref); + +} // namespace at::cuda + +#else + +namespace at::cuda { + +TORCH_CUDA_CPP_API c10::SmallVector CompileAndLaunchKernel( + const std::string& code_string, + const std::string& kernel_name, + const int num_outputs, + const c10::SmallVector& tensors, + const c10::SmallVector& extra_args, + bool return_by_ref) { + TORCH_CHECK(false, "Jiterator is not supported"); + } +} // namespace at::cuda + +#endif // AT_USE_JITERATOR() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/jiterator_impl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/jiterator_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..334403b39f7a57373021eff71481d7067a56dfd5 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/jiterator_impl.h @@ -0,0 +1,250 @@ +#pragma once +#include + +#if AT_USE_JITERATOR() + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace at::native { + + +#define AT_FOR_8_CASES(_) \ + _(1) \ + _(2) \ + _(3) \ + _(4) \ + _(5) \ + _(6) \ + _(7) \ + _(8) + +#define AT_FOR_8_CASES_WITH_COMMA(_) \ + _(1) , \ + _(2) , \ + _(3) , \ + _(4) , \ + _(5) , \ + _(6) , \ + _(7) , \ + _(8) + +c10::SmallVector get_extra_args_typenames(const c10::SmallVector& extra_args) { + c10::SmallVector args_typenames(extra_args.size()); + for (const auto i : c10::irange(extra_args.size())) { + args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type()); + } + return args_typenames; +} + +int can_vectorize_up_to(at::ScalarType type, char* pointer) { + switch(type) { +#define DEFINE_CASE(ctype, scalartype) \ + case ScalarType::scalartype : return memory::can_vectorize_up_to(pointer); + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) +#undef DEFINE_CASE + + default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type); + } +} + +// jitted version of the above +// See Note [Jiterator], this relies on the assumptions enumerated there +int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) { + const at::ScalarType common_dtype = iter.common_dtype(); + const at::ScalarType result_dtype = common_dtype; + + // Deals with output + int result = can_vectorize_up_to(result_dtype, static_cast(iter.data_ptr(0))); + + // Incorporates input(s) + for (auto i = 1; i < iter.ntensors(); ++i) { + result = std::min(result, can_vectorize_up_to(common_dtype, static_cast(iter.data_ptr(i)))); + } + + return result; +} + +template +static std::unique_ptr> make_unique_offset_calculator( + const TensorIteratorBase& iter) { + // array size can not be 0, this happens when N == 0 + constexpr int array_size = std::max(N, 1); + TORCH_INTERNAL_ASSERT(N == (IS_INPUT ? 
iter.ninputs() : iter.noutputs())); + + std::array strides; + int64_t element_sizes[array_size]; + for (int i = 0; i < N; i++) { + int index = IS_INPUT ? i + iter.noutputs() : i; + strides[i] = iter.strides(index).data(); + element_sizes[i] = iter.element_size(index); + } + return std::make_unique>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +struct OffsetCalculatorVariant { +#define DEFINE_CASE(index) std::unique_ptr> + using OffsetCalculatorTypes = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + OffsetCalculatorVariant(const TensorIteratorBase& iter) { + int num = IS_INPUT ? iter.ninputs() : iter.noutputs(); + + switch(num) { +#define DEFINE_CASE(index) \ + case index : v = make_unique_offset_calculator(iter); break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + default: + TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for num_tensor = ", num); + } + } + + void* data_ptr() { + return std::visit([](auto & v){ return static_cast(v.get()); }, v); + } + + private: + OffsetCalculatorTypes v{}; +}; + +struct ArrayVariant { +// works for up to 8 input + 8 outputs +#define DEFINE_CASE(index) std::array, std::array + using ArrayTypes = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + ArrayVariant(const TensorIteratorBase& iter) { + int ntensors = iter.ntensors(); + switch(ntensors) { +#define DEFINE_CASE(index) \ + case index: array = std::array{}; break; \ + case index+8: array = std::array{}; break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors); + } + + std::visit([&](auto& a) { + for (auto i = 0; i < ntensors; ++i) { + a[i] = (char*)iter.data_ptr(i); + } + }, array); + } + + void* data_ptr() { + return std::visit([](auto & a){ return static_cast(&a); }, array); + } + +private: + ArrayTypes array; +}; + +struct TrivialOffsetCalculatorVariant 
{ +#define DEFINE_CASE(index) TrivialOffsetCalculator + using TrivialOffsetCalculatorTypes = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + TrivialOffsetCalculatorVariant(int num) { + switch(num) { +#define DEFINE_CASE(index) \ + case index: v = TrivialOffsetCalculator(); break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for num_tensors = ", num); + } + } + + void* data_ptr() { + return std::visit([](auto & v){ return static_cast(&v); }, v); + } + +private: + TrivialOffsetCalculatorTypes v{}; +}; + +struct LoadWithCastVariant { +#define DEFINE_CASE(index) std::unique_ptr> + using LoadWithCastPtr = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + LoadWithCastVariant(const TensorIteratorBase& iter) { + int arity = iter.ninputs(); + switch(arity) { +#define DEFINE_CASE(index) \ + case index: v = std::make_unique>(iter); break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity); + } + } + + void* data_ptr() { + return std::visit([](auto & v){ return static_cast(v.get()); }, v); + } + +private: + LoadWithCastPtr v{}; +}; + +struct StoreWithCastVariant { +#define DEFINE_CASE(index) std::unique_ptr> + using StoreWithCastPtr = std::variant< + AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE) + >; +#undef DEFINE_CASE + + StoreWithCastVariant(const TensorIteratorBase& iter) { + int num = iter.noutputs(); + switch(num) { +#define DEFINE_CASE(index) \ + case index: v = std::make_unique>(iter); break; + + AT_FOR_8_CASES(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_CHECK(false, "StoreWithCastVariant is not implemented for noutputs = ", num); + } + } + + void* data_ptr() { + return std::visit([](auto & v){ return static_cast(v.get()); }, v); + } + +private: + StoreWithCastPtr v{}; +}; + +} // namespace at::native + + 
+#endif // AT_USE_JITERATOR() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..aba40d4f42eac74ee9435d904ac3fb82d1e988c4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace at::cuda { + +TORCH_CUDA_CPP_API const std::string &get_traits_string(); +TORCH_CUDA_CPP_API const std::string &get_cmath_string(); +TORCH_CUDA_CPP_API const std::string &get_complex_body_string(); +TORCH_CUDA_CPP_API const std::string &get_complex_half_body_string(); +TORCH_CUDA_CPP_API const std::string &get_complex_math_string(); + +} // namespace at::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..ba5672352da115b5bb5629b932f862b4a8184b8f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h @@ -0,0 +1,694 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. 
+// +#pragma once + +#include +#include + +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif +#include +#include + +namespace at::cuda::tunable { + +enum class BlasOp { + N = 0, + T = 1 +}; + +inline char BlasOpToString(BlasOp op) { + switch (op) { + case BlasOp::N: + return 'N'; + case BlasOp::T: + return 'T'; + } + TORCH_CHECK(false, "unrecognized BlasOp"); + return 'N'; +} + +template +inline const char* BLASTypeName(T v) { + return "unknown"; +} + +template <> +inline const char* BLASTypeName(float v) { + return "f32_r"; +} + +template <> +inline const char* BLASTypeName(double v) { + return "f64_r"; +} + +template <> +inline const char* BLASTypeName(BFloat16 v) { + return "bf16_r"; +} + +template <> +inline const char* BLASTypeName(Half v) { + return "f16_r"; +} + +//https://github.com/ROCm/hipBLASLt/blob/develop/library/src/include/auxiliary.hpp#L175 +template <> +inline const char* BLASTypeName(Float8_e4m3fn v) { + return "f8_r"; +} + +template <> +inline const char* BLASTypeName(Float8_e5m2 v) { + return "bf8_r"; +} + +template <> +inline const char* BLASTypeName(Float8_e4m3fnuz v) { + return "f8_fnuz_r"; +} + +template <> +inline const char* BLASTypeName(Float8_e5m2fnuz v) { + return "bf8_fnuz_r"; +} + +template <> +inline const char* BLASTypeName(c10::complex v) { + return "f64_r"; +} + +template <> +inline const char* BLASTypeName(c10::complex v) { + return "f32_r"; +} + +inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) { + std::string BLASType; + switch (scalar_type) { + case c10::ScalarType::Float:{ + BLASType = "f32_r"; + break; + } + case c10::ScalarType::Double:{ + BLASType = "f64_r"; + break; + } + case c10::ScalarType::BFloat16:{ + BLASType = "bf16_r"; + break; + } + case c10::ScalarType::Half: { + BLASType = "f16_r"; + break; + } + case c10::ScalarType::Float8_e4m3fn: { + BLASType = "f8_r"; + break; + } + case c10::ScalarType::Float8_e5m2: { + 
BLASType = "bf8_r"; + break; + } + case c10::ScalarType::Float8_e4m3fnuz: { + BLASType = "f8_fnuz_r"; + break; + } + case c10::ScalarType::Float8_e5m2fnuz: { + BLASType = "bf8_fnuz_r"; + break; + } + case c10::ScalarType::ComplexFloat:{ + BLASType = "f32_c"; + break; + } + case c10::ScalarType::ComplexDouble:{ + BLASType = "f64_c"; + break; + } + default: + BLASType = "unknown"; + } + return BLASType; +} + +// Similar to Compute Type in GemmRocblas.h +template +inline std::string ComputeTypeFor() { + return "Unknown ComputeType"; +} + +// This is a union of the compute types for +// ROCBLAS and hipBLASLt. +template <> +inline std::string ComputeTypeFor() { + if (!at::globalContext().allowTF32CuBLAS()) { + return "f32_r"; + } else { + return "xf32_r"; + } +} + +template <> +inline std::string ComputeTypeFor() { + return "f64_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor>() { + return "f32_c"; +} + +template <> +inline std::string ComputeTypeFor>() { + return "f64_c"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +template <> +inline std::string ComputeTypeFor() { + return "f32_r"; +} + +// Convert opmath_type to string +template +inline std::string to_string_opmath(const at::opmath_type& value) { + if constexpr (std::is_same_v, c10::complex> || + std::is_same_v, c10::complex>) { + return fmt::format("({:.4f}, {:.4f})", value.real(), value.imag()); + } else { + return fmt::format("{:.4f}", value); + } +} + +// convert activation epilogue to string +inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivationEpilogue& value) { + switch (value) { + case at::cuda::blas::GEMMAndBiasActivationEpilogue::None: + 
return std::string("None"); + break; + case at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU: + return std::string("RELU"); + break; + case cuda::blas::GEMMAndBiasActivationEpilogue::GELU: + return std::string("GELU"); + break; + default: + return std::string("unknown"); + } +} + +namespace detail { + +static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) { + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA); + // comparison done as 1D tensor + at::Tensor ref = at::from_blob(c, {size}, options); + at::Tensor oth = at::from_blob(other_c, {size}, options); + at::Tensor ref_float = ref.to(at::kFloat); + at::Tensor oth_float = oth.to(at::kFloat); + std::vector atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + std::vector rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + double last_succeed_atol = 1; + double last_succeed_rtol = 1; + for (auto& atol : atols) { + for (auto& rtol : rtols) { + if (at::allclose(ref_float, oth_float, rtol, atol)) { + last_succeed_atol = atol; + last_succeed_rtol = rtol; + } + } + } + if (last_succeed_atol == 1) { + return false; + } + else { + TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol); + } + + return true; +} + +} + +// Note on GetSizeA et al. +// Tensors can be dense or arbitrarily strided. We only need our copies to be large enough. +// Our copies must be at least as large as the m n k shapes dictate, but could be larger +// depending on the lda ldb ldc values. Similarly for the batched case. 
+ +template +struct GemmParams : OpParams { + GemmParams() = default; + + std::string BLASSignature() const override { + std::string alpha_str = to_string_opmath(alpha); + std::string beta_str = to_string_opmath(beta); + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, " + "alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, bias_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, alpha_str, beta_str, transa, transb, + BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(T{}), ComputeTypeFor(), ComputeTypeFor(), ComputeTypeFor()); + } + + std::string Signature() const override { + return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc); + } + + size_t GetSizeA() const { + size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t size_dense = m * k; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeB() const { + size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k); + size_t size_dense = k * n; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeC() const { + size_t size_stride = ldc * n; + size_t size_dense = m * n; + return sizeof(T) * (size_stride > size_dense ? 
size_stride : size_dense); + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + GemmParams* DeepCopy(bool duplicate_inputs) const { + GemmParams* copy = new GemmParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? 
OK : FAIL; + } + + char transa{}; + char transb{}; + int64_t m{}; + int64_t n{}; + int64_t k{}; + at::opmath_type alpha; + const T* a{}; + int64_t lda{}; + const T* b{}; + int64_t ldb{}; + at::opmath_type beta; + T* c{}; + int64_t ldc{}; +private: + bool duplicate_inputs_{false}; +}; + +template +struct GemmAndBiasParams : OpParams { + std::string BLASSignature() const override { + std::string alpha_str = to_string_opmath(alpha); + std::string activation_str = to_string_epilogue(activation); + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, " + "alpha: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, activation: %s, bias_type: %s, scale_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, alpha_str, transa, transb, + BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(T{}), activation_str, BLASTypeName(T{}), ComputeTypeFor(), ComputeTypeFor(), ComputeTypeFor()); + } + + std::string Signature() const override { + return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc); + } + + size_t GetSizeA() const { + size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t size_dense = m * k; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeB() const { + size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k); + size_t size_dense = k * n; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeC() const { + size_t size_stride = ldc * n; + size_t size_dense = m * n; + return sizeof(T) * (size_stride > size_dense ? 
size_stride : size_dense); + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const { + GemmAndBiasParams* copy = new GemmAndBiasParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + // NOLINTNEXTLINE(*const-cast) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + // NOLINTNEXTLINE(*const-cast) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmAndBiasParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? 
OK : FAIL; + } + + char transa{}; + char transb{}; + int64_t m{}; + int64_t n{}; + int64_t k{}; + at::opmath_type alpha{}; + const T* a{}; + int64_t lda{}; + const T* b{}; + int64_t ldb{}; + T* c{}; + int64_t ldc{}; + const T* bias{}; + at::cuda::blas::GEMMAndBiasActivationEpilogue activation{}; +private: + bool duplicate_inputs_{false}; +}; + +template +struct GemmStridedBatchedParams : OpParams { + std::string BLASSignature() const override { + std::string alpha_str = to_string_opmath(alpha); + std::string beta_str = to_string_opmath(beta); + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: %ld, stride_b: %ld, stride_c: %ld, stride_d: %ld, " + "alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: %ld, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, stride_a, stride_b, stride_c, stride_c, alpha_str, beta_str, transa, transb, batch, + BLASTypeName(T{}), BLASTypeName(T{}), BLASTypeName(C_Dtype{}), BLASTypeName(T{}), ComputeTypeFor(), ComputeTypeFor()); + } + + std::string Signature() const override { + return fmt::sprintf("%c%c_%ld_%ld_%ld_B_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, batch, lda, ldb, ldc); + } + + size_t GetSizeA() const { + size_t size_stride = stride_a * batch; + size_t size_dense = m * k * batch; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeB() const { + size_t size_stride = stride_b * batch; + size_t size_dense = k * n * batch; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeC() const { + size_t size_stride = stride_c * batch; + size_t size_dense = m * n * batch; + return sizeof(T) * (size_stride > size_dense ? 
size_stride : size_dense); + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const { + GemmStridedBatchedParams* copy = new GemmStridedBatchedParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + // NOLINTNEXTLINE(*const-cast*) + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + // NOLINTNEXTLINE(*const-cast*) + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmStridedBatchedParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? 
OK : FAIL; + } + + char transa{}; + char transb{}; + int64_t m{}; + int64_t n{}; + int64_t k{}; + at::opmath_type alpha{}; + const T* a{}; + int64_t lda{}; + int64_t stride_a{}; + const T* b{}; + int64_t ldb{}; + int64_t stride_b{}; + at::opmath_type beta; + C_Dtype* c{}; + int64_t ldc{}; + int64_t stride_c{}; + int64_t batch{}; +private: + bool duplicate_inputs_{false}; +}; + +template +struct ScaledGemmParams : OpParams { + ScaledGemmParams() = default; + + std::string BLASSignature() const override { + // Excluding use_fast_accum and use_rowise booleans for now + if (bias_ptr == nullptr) { + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, " + "transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, transa, transb, + ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype), + ComputeTypeFor(), ComputeTypeFor()); + } + else { + return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, " + "transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, bias_type: %s, scale_type: %s, compute_type: %s }", + m, n, k, lda, ldb, ldc, ldc, transa, transb, + ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(bias_dtype), + ComputeTypeFor(), ComputeTypeFor()); + } + } + + std::string Signature() const override { + // In Blas.cpp, code defaults to a bias_dtype of Half even when there is no bias vector. + // Search for this line:: + // params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? 
at::ScalarType::Half : out_dtype_; + // + // In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector. + return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s", + transa, transb, m, n, k, lda, ldb, ldc, use_rowwise, + bias_ptr == nullptr ? "None" : at::toString(bias_dtype)); + } + + size_t GetSizeA() const { + size_t size_stride = lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t size_dense = m * k; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeB() const { + size_t size_stride = ldb * ((transb == 'n' || transb == 'N') ? n : k); + size_t size_dense = k * n; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSizeC() const { + size_t size_stride = ldc * n; + size_t size_dense = m * n; + return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense); + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + ScaledGemmParams* DeepCopy(bool duplicate_inputs) const { + ScaledGemmParams* copy = new ScaledGemmParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size); + copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + // NOLINTNEXTLINE(*const-cast*) + 
c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + // NOLINTNEXTLINE(*const-cast*) + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(ScaledGemmParams *other) { + return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL; + } + + char transa{}; + char transb{}; + int64_t m{}; + int64_t n{}; + int64_t k{}; + const void* a{}; + const void* a_scale_ptr{}; + int64_t lda{}; + ScalarType a_dtype{}; + ScalarType a_scale_dtype{}; + const void* b{}; + const void* b_scale_ptr{}; + int64_t ldb{}; + ScalarType b_dtype{}; + ScalarType b_scale_dtype{}; + const void* bias_ptr{}; + ScalarType bias_dtype{}; + void* c{}; + const void* c_scale_ptr{}; + int64_t ldc{}; + ScalarType c_dtype{}; + void* amax_ptr{}; + bool use_fast_accum{}; + bool use_rowwise{}; +private: + bool duplicate_inputs_{false}; +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h new file mode 100644 index 0000000000000000000000000000000000000000..a23a2d720c5c47bf1de777697741fe82a6600813 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h @@ -0,0 +1,685 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#define TORCH_HIPBLASLT_CHECK(EXPR) \ + do { \ + hipblasStatus_t __err = EXPR; \ + TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \ + "hipblaslt error: ", \ + hipblasStatusToString(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +namespace at::cuda::tunable { + +template +constexpr hipDataType HipDataTypeFor(); + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_32F; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_16F; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_16BF; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_64F; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_8F_E4M3_FNUZ; +} + +template <> +constexpr hipDataType HipDataTypeFor() { + return HIP_R_8F_E5M2_FNUZ; +} + +// This code is instantiated regardless of ROCm version. +// Prior to ROCm 6.3, we hard-code the known enum values. +template <> +constexpr hipDataType HipDataTypeFor() { +#if ROCM_VERSION >= 60300 + return HIP_R_8F_E4M3; +#else + return static_cast(28); +#endif +} + +template <> +constexpr hipDataType HipDataTypeFor() { +#if ROCM_VERSION >= 60300 + return HIP_R_8F_E5M2; +#else + return static_cast(29); +#endif +} + +// This type is not intended for matrix types but rather a scale factor. +// Return a dummy value to satisfy linker. 
+template <> +constexpr hipDataType HipDataTypeFor() { + return static_cast(500); +} + +template +int GetBatchFromParams(const GemmParams* params) { + return 1; +} + +template +int GetBatchFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetBatchFromParams(const GemmStridedBatchedParams* params) { + return params->batch; +} + +template +int GetBatchFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmParams* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmStridedBatchedParams* params) { + return params->stride_a; +} + +template +int GetStrideAFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmParams* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmStridedBatchedParams* params) { + return params->stride_b; +} + +template +int GetStrideBFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmParams* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmStridedBatchedParams* params) { + return params->stride_c; +} + +template +int GetStrideCFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +float GetAlphaFromParams(const GemmParams* params) { + return params->alpha; +} + +template +float GetAlphaFromParams(const GemmAndBiasParams* params) { + return params->alpha; +} + +template +float GetAlphaFromParams(const GemmStridedBatchedParams* params) { + return params->alpha; +} + +template +float GetAlphaFromParams(const ScaledGemmParams* params) { + return 1.0; +} + +template +float GetBetaFromParams(const GemmParams* params) { + 
return params->beta; +} + +template +float GetBetaFromParams(const GemmAndBiasParams* params) { + return 0.0; +} + +template +float GetBetaFromParams(const GemmStridedBatchedParams* params) { + return params->beta; +} + +template +float GetBetaFromParams(const ScaledGemmParams* params) { + return 0.0; +} + +template +bool GetUseRowwiseFromParams(const GemmParams* params) { + return false; +} + +template +bool GetUseRowwiseFromParams(const GemmAndBiasParams* params) { + return false; +} + +template +bool GetUseRowwiseFromParams(const GemmStridedBatchedParams* params) { + return false; +} + +template +bool GetUseRowwiseFromParams(const ScaledGemmParams* params) { + return params->use_rowwise; +} + +template +const void* GetAScalePointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetAScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + +template +const void* GetAScalePointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetAScalePointerFromParams(const ScaledGemmParams* params) { + return params->a_scale_ptr; +} + +template +const void* GetBScalePointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetBScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + +template +const void* GetBScalePointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetBScalePointerFromParams(const ScaledGemmParams* params) { + return params->b_scale_ptr; +} + +template +const void* GetDScalePointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetDScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + +template +const void* GetDScalePointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetDScalePointerFromParams(const ScaledGemmParams* params) { 
+ return params->c_scale_ptr; +} + +template +const void* GetBiasPointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetBiasPointerFromParams(const GemmAndBiasParams* params) { + return params->bias; +} + +template +const void* GetBiasPointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetBiasPointerFromParams(const ScaledGemmParams* params) { + return params->bias_ptr; +} + +template +hipDataType GetBiasTypeFromParams(const GemmParams* params) { + return HIP_R_32F; +} + +template +hipDataType GetBiasTypeFromParams(const GemmAndBiasParams* params) { + return HipDataTypeFor(); +} + +template +hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams* params) { + return HIP_R_32F; +} + +template +hipDataType GetBiasTypeFromParams(const ScaledGemmParams* params) { + return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype); +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams* params) { + return params->activation; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +static hipblasOperation_t _hipblasOpFromChar(char op) { + switch (op) { + case 'n': + case 'N': + return HIPBLAS_OP_N; + case 't': + case 'T': + return HIPBLAS_OP_T; + case 'c': + case 'C': + return HIPBLAS_OP_C; + } + TORCH_CHECK(false, + "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +static char 
_charFromhipblasOp(hipblasOperation_t op) { + switch (op) { + case HIPBLAS_OP_N: + return 'N'; + case HIPBLAS_OP_T: + return 'T'; + case HIPBLAS_OP_C: + return 'C'; + } + TORCH_CHECK(false, + "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`"); +} + +static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) { + if (layout == BlasOp::N) { + return HIPBLAS_OP_N; + } + return HIPBLAS_OP_T; +} + +static size_t GetHipblasltWorkspaceSize() { + static const auto env = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); + // 256MB is max workspace size allowed for hipblaslt + // hipblaslt-bench uses 32MB + // recommendation from hipblaslt author was 76MB + // TunableOp hipBLASLt workspace size is aligned with + // PyTorch's default in CUDABlas.cpp (_parseChosenWorkspaceSize) + size_t workspace_size = 76*1024; + if (env) { + try { + workspace_size = std::stoi(env.value()); + } catch(std::invalid_argument const& e) { + TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,", + " using default workspace size of ", workspace_size, " KiB."); + } catch(std::out_of_range const& e) { + TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,", + " using default workspace size of ", workspace_size, " KiB."); + } + } + return workspace_size * 1024; +} + +template +struct HipBlasLtDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDABLAS_CHECK(destructor(x)); + } + } +}; + +template +class HipBlasLtDescriptor { + public: + T* descriptor() const { + return descriptor_.get(); + } + T* descriptor() { + return descriptor_.get(); + } + + protected: + std::unique_ptr> descriptor_; +}; + +class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor< + hipblasLtMatmulDescOpaque_t, + &hipblasLtMatmulDescDestroy> { + public: + HipBlasLtMatmulDescriptor( + hipblasComputeType_t compute_type, + hipDataType scale_type) { + hipblasLtMatmulDesc_t raw_descriptor = nullptr; + TORCH_HIPBLASLT_CHECK( + hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type)); + 
descriptor_.reset(raw_descriptor); + } + template + inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) { + TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + +template +class HipblasltGemmOp : public Callable { + public: + HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {} + + TuningStatus Call(const ParamsT* params) override { + hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); + hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); + auto opa = _hipblasOpFromChar(params->transa); + auto opb = _hipblasOpFromChar(params->transb); + + TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen"); + + float alpha = GetAlphaFromParams(params); + float beta = GetBetaFromParams(params); + + hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; + if (opa == HIPBLAS_OP_N) { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda)); + } + else { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda)); + } + if (opb == HIPBLAS_OP_N) { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb)); + } + else { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb)); + } + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc)); + + // specific to batched gemmm + int batch = GetBatchFromParams(params); + if (batch > 1) { + int64_t stride_a = GetStrideAFromParams(params); + int64_t stride_b = GetStrideBFromParams(params); + int64_t stride_c = GetStrideCFromParams(params); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_a, 
HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); + } + + hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; + if (at::globalContext().allowTF32CuBLAS()) { + computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; + } + HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb); + + // specific to scaled gemm + const void* mat1_scale_ptr = GetAScalePointerFromParams(params); + const void* mat2_scale_ptr = GetBScalePointerFromParams(params); + const void* result_scale_ptr = GetDScalePointerFromParams(params); + if (mat1_scale_ptr && mat2_scale_ptr) { +#ifdef HIPBLASLT_VEC_EXT + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr); + } + else +#endif + { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); + } +#ifdef HIPBLASLT_OUTER_VEC + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + 
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + } +#endif + } + if (result_scale_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); + } + + const void* bias_ptr = GetBiasPointerFromParams(params); + auto bias_datatype = GetBiasTypeFromParams(params); + if (bias_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); + auto activation = GetActivationFromParams(params); + if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS); + } + else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS); + } + else { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); + } + } + + size_t workspace_size = GetHipblasltWorkspaceSize(); + + auto op_handle = at::cuda::getCurrentCUDABlasLtHandle(); + + size_t ret_workspace_size = 0; + auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle, + matmul.descriptor(), + &alpha, + mat_a, + mat_b, + &beta, + mat_c, + mat_c, + algo_, + ret_workspace_size); + + if (status == HIPBLAS_STATUS_SUCCESS) { + if (ret_workspace_size >= workspace_size) { + return FAIL; + } + } + else { + return FAIL; + } + + void* workspace_buffer = nullptr; + if (workspace_size > 0) { + workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size); + } + + TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle, + matmul.descriptor(), + &alpha, + params->a, + mat_a, + params->b, + mat_b, + &beta, + params->c, + mat_c, + params->c, + mat_c, + &algo_, + workspace_buffer, + workspace_size, + at::cuda::getCurrentCUDAStream())); + + //TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul)); + 
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c)); + if (workspace_size > 0) { + c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer); + } + return OK; + } + + private: + hipblasLtMatmulAlgo_t algo_; +}; + +template +auto GetHipBlasLtTypeStringAndOps() { + hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); + hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); + std::vector heuristic_result; +#if ROCM_VERSION == 60400 + // hipblaslt TT fp32 regression on ROCm 6.4, cannot use + if ((a_datatype == HIP_R_32F || b_datatype == HIP_R_32F || in_out_datatype == HIP_R_32F) + && (transa_outer == HIPBLAS_OP_T && transb_outer == HIPBLAS_OP_T)) { + std::vector>>> ignore; + return ignore; + } +#endif + + hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; + if (at::globalContext().allowTF32CuBLAS()) { + computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; + } + + hipblasLtHandle_t handle; + TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle)); + TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle, + hipblaslt_ext::GemmType::HIPBLASLT_GEMM, + transa_outer, + transb_outer, + a_datatype, + b_datatype, + in_out_datatype, + in_out_datatype, + computeType, + heuristic_result)); + TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle)); + + int returned_algo_count = heuristic_result.size(); + std::vector>>> ret; + for (int i = 0; i < returned_algo_count; i++) { + auto algo = heuristic_result[i].algo; + int algo_index = hipblaslt_ext::getIndexFromAlgo(algo); + auto callable = std::make_unique>(algo); + std::string type_string = fmt::sprintf("Gemm_Hipblaslt_%d", algo_index); + ret.emplace_back(type_string, std::move(callable)); + } + + return ret; +} + +template +auto GetHipBlasLtGemmTypeStringAndOps() { + return 
GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtGemmAndBiasTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtScaledGemmTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +#undef TORCH_HIPBLASLT_CHECK + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h new file mode 100644 index 0000000000000000000000000000000000000000..857eddd85d10aaf4d0d039596ced32c29bb24c41 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#define ROCBLAS_BETA_FEATURES_API +#include + +#define TORCH_ROCBLAS_CHECK(EXPR) \ + do { \ + rocblas_status __err = EXPR; \ + TORCH_CHECK(__err == rocblas_status_success, \ + "rocblas error: ", \ + rocblas_status_to_string(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +namespace at::cuda::tunable { + +template +constexpr rocblas_datatype RocBlasDataTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f64_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f16_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_bf16_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor>() { + return rocblas_datatype_f32_c; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor>() { + return rocblas_datatype_f64_c; +} + 
+template +constexpr rocblas_datatype RocBlasComputeTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + return rocblas_datatype_f64_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + // Note that we're returning the _compute_ type for a given datatype. + // As of 12/2022, using compute type FP16 for 16-bit floats was much + // slower than using compute type FP32. So we use FP32 compute even for + // FP16 datatypes. This is how GEMM is implemented even in the function + // rocblasGemmHelper (see fpgeneric.h) + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + // Note that we're returning the _compute_ type for a given datatype. + // As of 12/2022, using compute type FP16 for 16-bit floats was much + // slower than using compute type FP32. So we use FP32 compute even for + // BF16 datatypes. This is how GEMM is implemented even in the function + // rocblasGemmHelper (see fpgeneric.h) + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor>() { + return rocblas_datatype_f32_c; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor>() { + return rocblas_datatype_f64_c; +} + +template +auto DoCastForHalfOrBfloat16(const T fp) { + return fp; +} + +template <> +inline auto DoCastForHalfOrBfloat16(const Half fp) { + // alpha and beta should be the same as compute_type, in Half case it is float. + float h = fp; + return h; +} + +template <> +inline auto DoCastForHalfOrBfloat16(const BFloat16 fp) { + // alpha and beta should be the same as compute_type, in bfloat16 case it is float. 
+ float h = fp; + return h; +} + +static rocblas_operation _rocblasOpFromChar(char op) { + switch (op) { + case 'n': + case 'N': + return rocblas_operation_none; + case 't': + case 'T': + return rocblas_operation_transpose; + case 'c': + case 'C': + return rocblas_operation_conjugate_transpose; + } + TORCH_CHECK(false, + "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +template +class RocblasGemmOp : public Callable> { + public: + RocblasGemmOp(int solution) : solution_{solution} {} + + TuningStatus Call(const GemmParams* params) override { + auto input_output_type = RocBlasDataTypeFor(); + if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r) + return FAIL; // no support for TF32 in rocBLAS + auto compute_type = RocBlasComputeTypeFor(); + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_ex( + (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(), + _rocblasOpFromChar(params->transa), + _rocblasOpFromChar(params->transb), + params->m, params->n, params->k, + &h_a, + params->a, input_output_type, params->lda, + params->b, input_output_type, params->ldb, + &h_b, + params->c, input_output_type, params->ldc, + params->c, input_output_type, params->ldc, + compute_type, + rocblas_gemm_algo_solution_index, + solution_, + rocblas_gemm_flags_none); + if (status != rocblas_status_success) { + return FAIL; + } + return OK; + } + + private: + int solution_; +}; + +template +auto GetRocBlasGemmTypeStringAndOps() { + rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(); + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + // Get the number of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + 
std::vector solutions(solution_size); + // Get the list of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + std::vector>>>> ret; + for (size_t i = 0; i < solutions.size(); ++i) { + auto callable = std::make_unique>(solutions[i]); + ret.emplace_back(std::make_pair(fmt::sprintf("Gemm_Rocblas_%d", solutions[i]), std::move(callable))); + } + return ret; +} + +template +class RocblasGemmStridedBatchedOp : public Callable> { + public: + RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {} + + TuningStatus Call(const GemmStridedBatchedParams* params) override { + auto input_output_type = RocBlasDataTypeFor(); + if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r) + return FAIL; // no support for TF32 in rocBLAS + auto compute_type = RocBlasComputeTypeFor(); + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_strided_batched_ex( + (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(), + _rocblasOpFromChar(params->transa), + _rocblasOpFromChar(params->transb), + params->m, params->n, params->k, + &h_a, + params->a, input_output_type, params->lda, params->stride_a, + params->b, input_output_type, params->ldb, params->stride_b, + &h_b, + params->c, input_output_type, params->ldc, params->stride_c, + params->c, input_output_type, params->ldc, params->stride_c, + params->batch, + compute_type, + rocblas_gemm_algo_solution_index, + solution_, + rocblas_gemm_flags_none); + if (status != rocblas_status_success) { + return FAIL; + } + return OK; + } + + private: + int solution_; +}; + +template +auto GetRocBlasGemmStridedBatchedTypeStringAndOps() { + rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(); + int solution_size; + auto input_output_type = 
RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + // Get the number of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + std::vector solutions(solution_size); + // Get the list of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + // Sort the solutions in ascending order to make the solution vector deterministic across runs + std::sort(solutions.begin(), solutions.end()); + + std::vector>>>> ret; + for (size_t i = 0; i < solutions.size(); ++i) { + auto callable = std::make_unique>(solutions[i]); + ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable))); + } + return ret; +} + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h new file mode 100644 index 0000000000000000000000000000000000000000..15ed5e769975a9d1d14d365610a0d352d375249c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h @@ -0,0 +1,50 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. 
+// +#pragma once + +#include + +#include + +namespace at::cuda::tunable { + +class StreamTimer : public ITimer { + public: + StreamTimer(); + ~StreamTimer() override; + + void Start() override; + + void End() override; + + float Duration() override; + + private: + cudaEvent_t start_{}; + cudaEvent_t end_{}; +}; + +class StreamTimerNoSync : public ITimer { + public: + StreamTimerNoSync(); + ~StreamTimerNoSync() override; + + void Start() override; + + void End() override; + + float Duration() override; + + private: + cudaEvent_t start_{}; + cudaEvent_t end_{}; +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/Tunable.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/Tunable.h new file mode 100644 index 0000000000000000000000000000000000000000..5e885d4764d29224fb3b3b3957fef4b8f5ae761f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/Tunable.h @@ -0,0 +1,241 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define TUNABLE_LOGV(LEVEL, ...) getTuningContext()->Log(LEVEL, __VA_ARGS__) +#define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__) +#define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__) +#define TUNABLE_LOG3(...) 
TUNABLE_LOGV(3, __VA_ARGS__) + +namespace at::cuda::tunable { + +enum TORCH_CUDA_CPP_API TuningStatus { + OK = 0, + FAIL = 1, + UNSUPPORTED = 2, +}; + +// Mapping from params signature to kernel id +class TORCH_CUDA_CPP_API ResultEntry { + public: + explicit ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {} + explicit ResultEntry(std::string key, double time, std::string blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(std::move(blas_sig)) {} + bool operator==(const ResultEntry& other) const { return key_ == other.key_; } + bool operator!=(const ResultEntry& other) const { return key_ != other.key_; } + operator std::string () { return key_; } + std::string GetKey() const { return key_; } + double GetTime() const { return time_; } + friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry); + static ResultEntry Null() { return ResultEntry("Null", 0.0); } + static ResultEntry Default() { return ResultEntry("Default", 0.0); } + + private: + std::string key_; + double time_; + std::string blas_sig_; +}; + +typedef std::unordered_map KernelMap; +typedef std::unordered_map ResultsMap; +typedef std::unordered_map> UntunedMap; + +struct TORCH_CUDA_CPP_API TuningResults { + // Validates if these results are compatible with the libraries + std::unordered_map validators; + + // Mapping from Callable signature to Callable's tuning result + ResultsMap results; +}; + +class TORCH_CUDA_CPP_API TuningResultsManager { + public: + TuningResultsManager() = default; + ~TuningResultsManager() = default; + + KernelMap Lookup(const std::string& op_signature); + + ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature); + + void AddImpl(const std::string& op_signature, + const std::string& params_signature, + ResultEntry best, + KernelMap& kernel_map); + + void Add(const std::string& op_signature, + const std::string& params_signature, + ResultEntry best); + + void Delete(const std::string& 
op_signature, const std::string& params_signature); + + void DisjointMergeImpl( + const std::string& op_signature, + const KernelMap& kernel_map, + /*out*/ ResultsMap& results); + + void Load(const ResultsMap& results_to_load); + + ResultsMap Dump(); + + void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map); + + size_t GetSize(); + + void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature, + const std::string& params_signature, const std::string& blas_signature); + private: + std::mutex lock_; + ResultsMap results_; + UntunedMap untuned_results_; + +}; + +class TORCH_CUDA_CPP_API TuningResultsValidator { + public: + using GetFunc = std::function; + using ValidateFunc = std::function; + using GetValidateFuncs = std::unordered_map>; + + TuningResultsValidator(); + ~TuningResultsValidator() = default; + + std::unordered_map GetAllValidators() const; + TuningStatus ValidateAll(const std::unordered_map& to_validate) const; + void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf); + + protected: + static std::string GetPyTorchVersion() ; + TuningStatus ValidatePyTorchVersion(const std::string& value) const; + + public: + static constexpr const std::array mandatory_keys{"PT_VERSION"}; + + private: + GetValidateFuncs validators_; +}; + +class TORCH_CUDA_CPP_API TuningContext { + public: + TuningContext(); + ~TuningContext(); + TuningContext(TuningContext &) = delete; + TuningContext(TuningContext &&) = delete; + TuningContext &operator=(TuningContext &) = delete; + TuningContext &operator=(TuningContext &&) = delete; + + void EnableTunableOp(bool value); + bool IsTunableOpEnabled() const; + + void EnableTuning(bool value); + bool IsTuningEnabled() const; + + void EnableRecordUntuned(bool value); + bool IsRecordUntunedEnabled() const; + std::ofstream& GetUntunedFile(); + + void EnableNumericsCheck(bool value); + bool IsNumericsCheckEnabled() const; + + void SetMaxTuningDurationMs(int 
max_duration_ms); + int GetMaxTuningDurationMs() const; + + void SetMaxTuningIterations(int max_iter); + int GetMaxTuningIterations() const; + + void SetMaxWarmupDurationMs(int max_duration_ms); + int GetMaxWarmupDurationMs() const; + + void SetMaxWarmupIterations(int max_iter); + int GetMaxWarmupIterations() const; + + void EnableICacheFlush(bool value); + bool IsICacheFlushEnabled() const; + + void SetRotatingBufferSize(int size); + int GetRotatingBufferSize() const; + + TuningResultsManager& GetTuningResultsManager(); + + TuningResultsValidator& GetTuningResultsValidator(); + + TuningResults GetTuningResults(); + + TuningStatus LoadTuningResults(const TuningResults& tr); + + void SetFilename(const std::string& filename, bool insert_device_ordinal=false); + std::string GetFilename() const; + + void WriteFileOnExit(bool value); + + bool ReadFile(const std::string& filename={}); + bool WriteFile(const std::string& filename={}); + + template + void Log(int level, Types... args) { + if (GetLogOkay() && GetLogLevel() >= level) { + GetLog() << c10::str(args...) 
<< std::endl; + } + } + + private: + std::string GetLogFilename() const; + int GetLogLevel() const; + bool GetLogOkay() const; + std::ostream& GetLog() const; + + bool enable_; + bool tuning_enable_; + bool record_untuned_enable_; + bool manager_initialized_; + bool write_file_on_exit_; + bool numerics_check_enable_; + int max_tuning_duration_ms_; + int max_tuning_iterations_; + int max_warmup_duration_ms_; + int max_warmup_iterations_; + bool icache_flush_; + int rotating_buffer_size_; + mutable TuningResultsManager manager_; + mutable c10::once_flag manager_init_once_; + TuningResultsValidator validator_; + std::string filename_; + std::ofstream untuned_file_; + size_t results_count_from_input_file_; + bool is_shutting_down_; +}; + +TORCH_CUDA_CPP_API TuningContext* getTuningContext(); + +class ITimer { + public: + ITimer() = default; + virtual ~ITimer() = default; + + virtual void Start() = 0; + virtual void End() = 0; + + /// Computes the elapsed time in milliseconds between Start() and End() + virtual float Duration() = 0; +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h new file mode 100644 index 0000000000000000000000000000000000000000..d7e2835b1b10907defc4a5a2f350b3b21c296c47 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h @@ -0,0 +1,327 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. 
+// +#pragma once + +#include +#ifdef USE_ROCM +#include +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::cuda::tunable { + +template +class DefaultGemmOp : public Callable> { + public: + TuningStatus Call(const GemmParams* params) override { + at::cuda::blas::gemm_internal( + params->transa, params->transb, + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, + params->b, params->ldb, + params->beta, + params->c, params->ldc); + return OK; + } +}; + +static bool _transposeBoolFromChar(char op) { + return op == 't' || op == 'T'; +} + +template +class DefaultGemmAndBiasOp : public Callable> { + public: + TuningStatus Call(const GemmAndBiasParams* params) override { + at::cuda::blas::gemm_and_bias( + _transposeBoolFromChar(params->transa), + _transposeBoolFromChar(params->transb), + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, + params->b, params->ldb, + params->bias, + params->c, params->ldc, + params->activation); + return OK; + } +}; + +template +class DefaultGemmStridedBatchedOp : public Callable> { + public: + TuningStatus Call(const GemmStridedBatchedParams* params) override { + at::cuda::blas::bgemm_internal( + params->transa, params->transb, + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, params->stride_a, + params->b, params->ldb, params->stride_b, + params->beta, + params->c, params->ldc, params->stride_c, + params->batch); + return OK; + } +}; + +template +class DefaultScaledGemmOp : public Callable> { + public: + TuningStatus Call(const ScaledGemmParams* params) override { + at::cuda::blas::scaled_gemm( + params->transa, + params->transb, + params->m, + params->n, + params->k, + params->a, + params->a_scale_ptr, + params->lda, + params->a_dtype, + params->a_scale_dtype, + params->b, + params->b_scale_ptr, + params->ldb, + params->b_dtype, + params->b_scale_dtype, + params->bias_ptr, + 
params->bias_dtype, + params->c, + params->c_scale_ptr, + params->ldc, + params->c_dtype, + params->use_fast_accum, + params->use_rowwise); + return OK; + } +}; + +template +inline bool IsZero(T v) { + return v == 0.0f; +} + +template <> +inline bool IsZero(BFloat16 v) { + return v.x == 0; +} + +template <> +inline bool IsZero(Half v) { + return float(v) == 0.0f; +} + +template <> +inline bool IsZero(c10::complex v) { + return v == 0.0; +} + +template <> +inline bool IsZero(c10::complex v) { + return v == 0.0f; +} + +template +inline const char* TypeName(T v) { + return "unknown"; +} + +template <> +inline const char* TypeName(float v) { + if (at::globalContext().allowTF32CuBLAS()) { + return "tf32"; + } else { + return "float"; + } +} + +template <> +inline const char* TypeName(double v) { + return "double"; +} + +template <> +inline const char* TypeName(BFloat16 v) { + return "BFloat16"; +} + +template <> +inline const char* TypeName(Half v) { + return "Half"; +} + +template <> +inline const char* TypeName(Float8_e4m3fn v) { + return "Float8_e4m3fn"; +} + +template <> +inline const char* TypeName(Float8_e5m2 v) { + return "Float8_e5m2"; +} + +template <> +inline const char* TypeName(Float8_e4m3fnuz v) { + return "Float8_e4m3fnuz"; +} + +template <> +inline const char* TypeName(Float8_e5m2fnuz v) { + return "Float8_e5m2fnuz"; +} + +template <> +inline const char* TypeName(Float8_e8m0fnu v) { + return "Float8_e8m0fnu"; +} + +template <> +inline const char* TypeName(c10::complex v) { + return "c10::complex"; +} + +template <> +inline const char* TypeName(c10::complex v) { + return "c10::complex"; +} + +template +class GemmTunableOp : public TunableOp> { + public: + GemmTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (!env_rocblas.has_value() || env_rocblas.value()) { + for (auto&& [name, op] : 
GetRocBlasGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + + static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (!env_hipblaslt.has_value() || env_hipblaslt.value()) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + } +#endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); + } + + std::string Signature() override { + return fmt::sprintf("GemmTunableOp_%s_%c%c", TypeName(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class GemmAndBiasTunableOp : public TunableOp> { + public: + GemmAndBiasTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (!env_hipblaslt.has_value() || env_hipblaslt.value()) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + } +#endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); + } + + std::string Signature() override { + return fmt::sprintf("GemmAndBiasTunableOp_%s_%c%c", TypeName(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class GemmStridedBatchedTunableOp : public TunableOp> { + public: + GemmStridedBatchedTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (!env_rocblas.has_value() || env_rocblas.value()) { + for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) { 
+ this->RegisterOp(std::move(name), std::move(op)); + } + } + + static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (!env_hipblaslt.has_value() || env_hipblaslt.value()) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + } +#endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); + } + + std::string Signature() override { + return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class ScaledGemmTunableOp : public TunableOp> { + public: + ScaledGemmTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } +#endif + + this->RegisterOp(std::string("Default"), std::make_unique>()); + } + + std::string Signature() override { + return fmt::sprintf("ScaledGemmTunableOp_%s_%s_%s_%c%c", + TypeName(AT{}), + TypeName(BT{}), + TypeName(CT{}), + BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h new file mode 100644 index 0000000000000000000000000000000000000000..6ca9e213e1489d64b22dfe00d7633901117a2dad --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h @@ -0,0 +1,430 @@ +// Original TunableOp is from onnxruntime. 
+// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#endif + +#include +#include +#include +#include + +namespace at::cuda::tunable { + +template +class Callable { + public: + virtual ~Callable() = default; + virtual TuningStatus Call(const ParamsT*) { + return FAIL; + } + virtual TuningStatus IsSupported(const ParamsT* params) { + return Call(params); + } +}; + +namespace { + +/** http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance */ + +class Stats { + public: + Stats() { + _n = 0UL; + _mean = 0.0; + _M2 = 0.0; + _sum = 0.0; + _min = 0.0; + _max = 0.0; + } + + void sample_value(const double x) { + double delta = 0; + _sum = _sum + x; + if (0UL == _n) { + _min = x; + _max = x; + } + else { + _min = _min < x ? _min : x; + _max = _max > x ? 
_max : x; + } + _n = _n + 1UL; + delta = x - _mean; + _mean = _mean + delta/_n; + _M2 = _M2 + delta * (x - _mean); + } + + double variance() const { + return _M2/(_n-1); + } + + double stddev() const { + return std::sqrt(variance()); + } + + unsigned long _n; + double _mean; + double _M2; + double _sum; + double _min; + double _max; +}; + +class FixedSizeStack { + private: + std::deque stack; + const size_t max_size; + + public: + FixedSizeStack(size_t size) : max_size(size) {} + + void push(const std::string& value) { + if (stack.size() >= max_size) { + stack.pop_front(); // Remove the oldest entry + } + stack.push_back(value); // Add new entry + } + + auto rbegin() { return stack.rbegin(); } + auto rend() { return stack.rend(); } +}; + +} // anonymous namespace + +template +class TunableOp { + public: + virtual ~TunableOp() = default; + + TuningStatus operator()(const ParamsT* params) { + ResultEntry result = ResultEntry::Null(); + TuningContext* ctx = getTuningContext(); + if (ctx->IsTunableOpEnabled()) { + auto& mgr = ctx->GetTuningResultsManager(); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + auto blas_sig = params->BLASSignature(); + result = mgr.Lookup(op_sig, params_sig); + // If there is not previous tuning result been found, we do the tuning iff tuning is enabled + if (result == ResultEntry::Null()) { + if (ctx->IsTuningEnabled()) { + result = FindFastest(params); + mgr.Add(op_sig, params_sig, result); + } + else if (ctx->IsRecordUntunedEnabled()) { + // or record the gemm into file + mgr.RecordUntuned(ctx->GetUntunedFile(), op_sig, params_sig, blas_sig); + } + } + } + else { + result = ResultEntry::Default(); + } + if (result == ResultEntry::Null()) { + TUNABLE_LOG2("no result, using default"); + result = ResultEntry::Default(); + } + auto iter = ops_.find(result); + TORCH_CHECK(iter != ops_.end()); + return iter->second->Call(params); + } + + virtual std::string Signature() { + // According to C++17 standard 
https://wg21.link/n4659 section 15.7.4 + // > if the operand of typeid refers to the + // > object under construction or destruction, typeid yields the std::type_info object representing the constructor + // > or destructor’s class. + // So delay the op signature generation. + c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); }); + return signature_; + } + + protected: + void RegisterOp(const std::string& name, std::unique_ptr> op) { + this->op_names_.emplace_back(name); + this->ops_.emplace(name, std::move(op)); + } + + private: + static void WarmUp(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); + for (size_t i = 0; i < num_iter; i++) { + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + } + + static double ProfileSimple(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); + StreamTimerNoSync timer{}; + + // Small Mandatory Warmup + // Reduces outliers + for (size_t i = 0; i < 2; i++) { + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + + timer.Start(); + for (size_t i = 0; i < num_iter; i++) { + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + timer.End(); + return timer.Duration() / num_iter; + } + + static Stats ProfileStats(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); + std::vector timer(num_iter); + + // Small Mandatory Warmup + // Reduces outliers + for (size_t i = 0; i < 2; i++) { + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + + for (size_t i = 0; i < num_iter; i++) { + timer[i].Start(); + 
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + timer[i].End(); + if (do_flush) { + at::cuda::flush_icache(); + } + } + Stats s; + for (size_t i = 0; i < num_iter; i++) { + s.sample_value(timer[i].Duration()); + } + return s; + } + + protected: + virtual ResultEntry FindFastest(const ParamsT* params) { + TuningContext* ctx = getTuningContext(); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + auto blas_sig = params->BLASSignature(); + TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates"); + auto min_duration_ms = std::numeric_limits::infinity(); + std::string id_name = "Default"; + ParamsT* reference_params = nullptr; + auto top_solns = FixedSizeStack(5); + + // numeric check option is controlled by non-static env var, so check it once per tuned operator + bool do_numerics_check = ctx->IsNumericsCheckEnabled(); + + // calcaulte a reference answer for numerical check + if (do_numerics_check) { + reference_params = params->DeepCopy(false); + TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); + } + + // need copies of params to reuse + // make as many copies as will fill the requested rotating buffer size, if requested + // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int + size_t rotating_size = ctx->GetRotatingBufferSize(); + bool use_buffer_rotation = (rotating_size > 0); + size_t param_size = params->GetSize(use_buffer_rotation); + size_t param_count = (rotating_size / param_size) + 1; + constexpr size_t MB = 1024ull*1024; + if (use_buffer_rotation) { + TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ", + "Needed Size: ", param_size/MB, " MiB. 
", + "Needed number of param copies: ", param_count); + } + TORCH_CHECK(param_count > 0); + + std::vector reusable_params(param_count); + for (size_t i = 0; i < param_count; i++) { + reusable_params[i] = params->DeepCopy(use_buffer_rotation); + } + + // for rotating buffer + size_t offset = 0; + + for (size_t i = 0; i < op_names_.size(); i++) { + auto* candidate = ops_[op_names_[i]].get(); // borrow pointer + + if (do_numerics_check) { + ParamsT* numerical_params = params->DeepCopy(false); + auto status = candidate->Call(numerical_params); + if (status != OK) { + numerical_params->Delete(); + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + status = reference_params->NumericalCheck(numerical_params); + numerical_params->Delete(); + if (status != OK) { + TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + } + else { + auto status = candidate->Call(reusable_params[0]); + if (status != OK) { + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + } + + // collect a small profile + int approx_num_iter = 3; + auto s = ProfileStats(candidate, reusable_params, approx_num_iter, offset); + double approx_duration = s._mean; + // bail if too slow + if (approx_duration > 1.5 * min_duration_ms) { + TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + + // 2nd phase skip, more aggressive + approx_num_iter = 10; + s = ProfileStats(candidate, reusable_params, approx_num_iter, offset); + approx_duration = s._mean; + // bail if too slow + if (approx_duration > 1.15 * min_duration_ms) { + TUNABLE_LOG3("├──2nd skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + + // for warmup does user set max duration, max iters, or both? + // warmup is skipped by default, i.e. 
warmup_iter = 0 + // warmup will be set to the non-zero value of max_warmup_duration + // or max_warmup_iter + // if both are non-zero, we take the smaller of the two. + double max_warmup_duration = ctx->GetMaxWarmupDurationMs(); + int max_warmup_iter = ctx->GetMaxWarmupIterations(); + int warmup_iter = 0; // default + if (max_warmup_duration > 0) { + int duration_iters = max_warmup_duration / approx_duration; + if (max_warmup_iter > 0) { + warmup_iter = std::min(max_warmup_iter, duration_iters); + } + else { + warmup_iter = duration_iters; + } + } + else if (max_warmup_iter > 0) { + warmup_iter = max_warmup_iter; + } + + // for tuning does user set max duration, max iters, or both? + double max_tuning_duration = ctx->GetMaxTuningDurationMs(); + int max_tuning_iter = ctx->GetMaxTuningIterations(); + int tuning_iter = 100; // default + if (max_tuning_duration > 0) { + int duration_iters = max_tuning_duration / approx_duration; + if (max_tuning_iter > 0) { + tuning_iter = std::min(max_tuning_iter, duration_iters); + } + else { + tuning_iter = duration_iters; + } + } + else if (max_tuning_iter > 0) { + tuning_iter = max_tuning_iter; + } + // tuning must run at least 1 iteration + tuning_iter = std::max(1, tuning_iter); + + // do the full warmup followed by tuning + double warmup_ms = warmup_iter * approx_duration; + double tuning_ms = tuning_iter * approx_duration; + TUNABLE_LOG3("├──tuning using " + "warmup iters ", warmup_iter, " [", warmup_ms, " ms] " + "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ", + "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]); + TUNABLE_LOG3("├──offset at ", offset); + WarmUp(candidate, reusable_params, warmup_iter, offset); + s = ProfileStats(candidate, reusable_params, tuning_iter, offset); + auto s_stddev = s.stddev(); + // Assume normal distribution. + // Solution with smallest mean + 2*sigma will be a better solution? 
+ // if ((s._mean + 2*s_stddev) < (min_duration_ms + 2*min_stddev_ms)) { + if (s._mean < min_duration_ms) { + TUNABLE_LOG3("├──found better instance id=", i, ". " , s._mean, "ms. ", op_names_[i], + " min ", s._min, + " max ", s._max, + " mean ", s._mean, + " std ", s_stddev); + min_duration_ms = s._mean; + id_name = op_names_[i]; + std::string current_soln = std::to_string(s._mean) + " " + op_names_[i]; + top_solns.push(current_soln); + } + else { + TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i], + " min ", s._min, + " max ", s._max, + " mean ", s._mean, + " std ", s_stddev); + } + } + + for (size_t i = 0; i < reusable_params.size(); i++) { + reusable_params[i]->Delete(); + } + if (reference_params) { + reference_params->Delete(); + } + + TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name); + TUNABLE_LOG2("└──top five solutions for ", op_sig, '(', params_sig, ") "); + for (auto it = top_solns.rbegin(); it != top_solns.rend(); ++it) { + TUNABLE_LOG2(" ", *it); + } + return ResultEntry(id_name, min_duration_ms, blas_sig); + } + + private: + std::string CreateSignature() { +#ifndef _WIN32 + const auto* name = typeid(*this).name(); + // NOLINTNEXTLINE(*array*) + char buf[256]; + size_t buf_len = 256; + abi::__cxa_demangle(name, buf, &buf_len, nullptr); + buf[255] = '\0'; + return buf; +#else + return typeid(*this).name(); +#endif + } + + mutable c10::once_flag signature_init_once_; + std::string signature_; + + std::unordered_map>> ops_; + std::vector op_names_; +}; + +struct OpParams { + virtual ~OpParams() = default; + virtual std::string Signature() const = 0; + virtual std::string BLASSignature() const = 0; +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Descriptors.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Descriptors.h new file mode 100644 index 
0000000000000000000000000000000000000000..6c2492b12e6b9ba1cabf87502feafa751dc816fd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Descriptors.h @@ -0,0 +1,409 @@ +#pragma once + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8907 +#define USE_CUDNN_RNN_V8_API +#endif + +namespace at::native { + +std::string cudnnTypeToString(cudnnDataType_t dtype); + +// TODO: Add constructors for all of the descriptors + +inline int dataSize(cudnnDataType_t dataType) +{ + switch (dataType) { + case CUDNN_DATA_BFLOAT16: + case CUDNN_DATA_HALF: return 2; + case CUDNN_DATA_FLOAT: return 4; + default: return 8; + } +} + +// The stride for a size-1 dimensions is not uniquely determined; in +// fact, it can be anything you want, because the fact that the +// tensor is size 1 at this dimension means that you will never actually +// try advancing your pointer by this stride. +// +// However, CuDNN has a much more stringent requirement on strides: +// if you are passing a contiguous input, it better be the case +// that the stride for dim i is the product of the sizes of dims +// i+1 to the end. This stride is indeed uniquely determined. This +// function modifies 'stride' in place so this invariant holds. 
+template +static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) { + int64_t z = 1; + int index = 0; + std::vector permutation(dim); + + if (nhwc) { + permutation[index++] = 1; + } + for (int d = dim-1; d > 1; d--) { + permutation[index++] = d; + } + if (!nhwc) { + permutation[index++] = 1; + } + permutation[index++] = 0; + for (int d : permutation) { + if (size[d] == 1) { + stride[d] = z; + } else { + z *= size[d]; + } + } +} + +template +struct DescriptorDeleter { + void operator()(T* x) { + if (x != nullptr) { + AT_CUDNN_CHECK(dtor(x)); + } + } +}; + +// A generic class for wrapping cuDNN descriptor types. All you need +// is to give the underlying type the Descriptor_t points to (usually, +// if it's cudnnTensorDescriptor_t it points to cudnnTensorStruct), +// the constructor and the destructor. Subclasses are responsible +// for defining a set() function to actually set the descriptor. +// +// Descriptors default construct to a nullptr, and have a descriptor +// initialized the first time you call set() or any other initializing +// function. +template +// NOLINTNEXTLINE(bugprone-exception-escape) +class TORCH_CUDA_CPP_API Descriptor { + public: + // TODO: Figure out why const-correctness doesn't work here + + // Use desc() to access the underlying descriptor pointer in + // a read-only fashion. Most client code should use this. + // If the descriptor was never initialized, this will return + // nullptr. + T* desc() const { return desc_.get(); } + T* desc() { return desc_.get(); } + + // Use mut_desc() to access the underlying descriptor pointer + // if you intend to modify what it points to (e.g., using + // cudnnSetFooDescriptor). This will ensure that the descriptor + // is initialized. Code in this file will use this function. 
+ T* mut_desc() { init(); return desc_.get(); } +protected: + void init() { + if (desc_ == nullptr) { + T* raw_desc = nullptr; + AT_CUDNN_CHECK(ctor(&raw_desc)); + desc_.reset(raw_desc); + } + } +private: + std::unique_ptr> desc_; +}; + +class TORCH_CUDA_CPP_API RNNDataDescriptor : public Descriptor< + cudnnRNNDataStruct, + &cudnnCreateRNNDataDescriptor, + &cudnnDestroyRNNDataDescriptor> { +public: + void set(const at::Tensor &t, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray); +private: + void set(cudnnDataType_t dataType, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray) { + AT_CUDNN_CHECK(cudnnSetRNNDataDescriptor(mut_desc(), dataType, layout, maxSeqLength, batchSize, vectorSize, seqLengthArray, nullptr)); + } +}; + +class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< + cudnnTensorStruct, + &cudnnCreateTensorDescriptor, + &cudnnDestroyTensorDescriptor> { + public: + TensorDescriptor() = default; + explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) { + set(t, pad); + } + + // Note [CuDNN broadcast padding] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // pad specifies the minimum dimensionality of the tensor descriptor + // we produce (it doesn't have anything to do with, e.g., convolution + // padding). If 't' is lower-dimensional than 'pad', the remaining + // dimensions (on the right) are padded with ones. This doesn't + // affect the underlying data layout. This is particularly useful for + // dealing with a peculiarity of the CuDNN API, which is that broadcasting in CuDNN is + // done in two steps: first, the client code is expected to pad out + // (the dimensions) input tensors to be the same dimension as the + // target broadcast, and then second, CuDNN takes of actually + // broadcasting size 1 dimensions. 
+ + void set(const at::Tensor &t, size_t pad = 0); + void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0); + void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0); + + void print(); + +private: + void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc); + + void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) { + std::vector strides_copy(stride, stride + dim); + fixSizeOneDimStride(dim, size, strides_copy.data(), nhwc); + AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, strides_copy.data())); + } +}; + +std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d); + +class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor< + cudnnFilterStruct, + &cudnnCreateFilterDescriptor, + &cudnnDestroyFilterDescriptor> { + public: + void set(const at::Tensor &t, int64_t pad = 0) { + set(t, at::MemoryFormat::Contiguous, pad); + } + + void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0); + + void print(); +private: + void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) { + AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size)); + } +}; + +std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d); + +struct TORCH_CUDA_CPP_API ConvolutionDescriptor + : public Descriptor< + cudnnConvolutionStruct, + &cudnnCreateConvolutionDescriptor, + &cudnnDestroyConvolutionDescriptor> { + void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool allow_tf32) { + cudnnDataType_t mathType = dataType; + if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT; + AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, + CUDNN_CROSS_CORRELATION, mathType)); + AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), 
groups)); + // See Note [behavior of cudnnFind and cudnnGet] + AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH)); + if(dataType == CUDNN_DATA_HALF) { + AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH)); + } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) { + AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH)); + } + } +}; + +struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor + : public Descriptor< + cudnnSpatialTransformerStruct, + &cudnnCreateSpatialTransformerDescriptor, + &cudnnDestroySpatialTransformerDescriptor> { + void set(cudnnDataType_t dataType, int dim, int* size) { + AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size)); + } +}; + +// NOLINTNEXTLINE(bugprone-exception-escape) +struct TORCH_CUDA_CPP_API DropoutDescriptor + : public Descriptor< + cudnnDropoutStruct, + &cudnnCreateDropoutDescriptor, + &cudnnDestroyDropoutDescriptor> { + at::Tensor state; + + // Initialize a dropout descriptor's RNG state. + // WARNING: This function is very expensive, avoid calling this function! + void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) { + TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout"); + size_t state_size = 0; + AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size)); + AT_ASSERT(options.device().type() == kCUDA); + AT_ASSERT(options.dtype() == kByte); + state = at::empty({static_cast(state_size)}, options); + AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed)); + } + + // Restore a dropout descriptor given a dropout probability and existing RNG state. 
+ void set(cudnnHandle_t handle, float dropout, const at::Tensor& state) { + TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout"); + void *state_ptr = state.data_ptr(); + size_t state_size = state.size(0); + // NB: The seed doesn't actually matter, so we give a dummy value + AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */)); + } + + // Restore a dropout descriptor corresponding to no dropout + void set_no_dropout(cudnnHandle_t handle) { + // NB: seed doesn't matter when dropout = 0, because no random number + // initialization actually takes place when there is no dropout. + // NB: Empirically, cudnnSetDropoutDescriptor is cheap when + // dropout == 0 + AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */)); + } +}; + +struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor< + cudnnRNNStruct, + &cudnnCreateRNNDescriptor, + &cudnnDestroyRNNDescriptor> { + DropoutDescriptor dropout_desc_; + void set(cudnnHandle_t handle, +#ifdef USE_CUDNN_RNN_V8_API + int input_size, + bool packed, +#endif + int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc, + cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional, + cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) { + dropout_desc_ = std::move(dropout_desc); +#ifndef USE_CUDNN_RNN_V8_API + AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6( + handle, + mut_desc(), + hidden_size, + num_layers, + dropout_desc_.desc(), + input_mode, + bidirectional, + mode, + algo, + datatype)); + if (proj_size != 0) { + AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers( + handle, + /*rnnDesc=*/mut_desc(), + /*recProjSize=*/proj_size, + /*outProjSize=*/0)); + } + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major >= 7) { + if (input_type == CUDNN_DATA_HALF) { + 
cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH); + } + else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) { + cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH); + } + else { + // Technically, as the default it's not necessary to explicitly + // set this. + cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH); + } + } +#else + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + auto math_type = CUDNN_DEFAULT_MATH; + if (prop->major >= 7) { + if (input_type == CUDNN_DATA_HALF) { + math_type = CUDNN_TENSOR_OP_MATH; + } else if (!allow_tf32) { + math_type = CUDNN_FMA_MATH; + } + } + AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v8( + mut_desc(), + algo, + mode, + CUDNN_RNN_DOUBLE_BIAS, + bidirectional, + input_mode, + input_type, + datatype, + math_type, + input_size, + hidden_size, + proj_size ? proj_size : hidden_size, + num_layers, + dropout_desc_.desc(), + packed ? CUDNN_RNN_PADDED_IO_DISABLED : CUDNN_RNN_PADDED_IO_ENABLED)); +#endif + } +}; + +struct TORCH_CUDA_CPP_API CTCLossDescriptor + : public Descriptor< + cudnnCTCLossStruct, + &cudnnCreateCTCLossDescriptor, + &cudnnDestroyCTCLossDescriptor> { + void set(cudnnDataType_t datatype) { + AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype)); + } + void setEx( + cudnnDataType_t datatype, + cudnnLossNormalizationMode_t normMode, + cudnnNanPropagation_t gradMode) { + AT_CUDNN_CHECK( + cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode)); + } + void set_v8_v9( + cudnnDataType_t datatype, + cudnnLossNormalizationMode_t normMode, + cudnnNanPropagation_t gradMode, + int maxLabelLength) { +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 90000 + auto gradModev9 = CUDNN_CTC_ZERO_OOB_GRADIENTS; + if (gradMode == cudnnNanPropagation_t::CUDNN_PROPAGATE_NAN) { + gradModev9 = CUDNN_CTC_SKIP_OOB_GRADIENTS; + } + AT_CUDNN_CHECK( + cudnnSetCTCLossDescriptor_v9(mut_desc(), datatype, normMode, gradModev9, maxLabelLength)); +#else + AT_CUDNN_CHECK( + 
cudnnSetCTCLossDescriptor_v8(mut_desc(), datatype, normMode, gradMode, maxLabelLength)); +#endif + } + +}; + +struct TORCH_CUDA_CPP_API ActivationDescriptor + : public Descriptor< + cudnnActivationStruct, + &cudnnCreateActivationDescriptor, + &cudnnDestroyActivationDescriptor> { + void set(cudnnActivationMode_t mode) { + AT_ASSERT( + mode == CUDNN_ACTIVATION_RELU, + "TODO: support more cuDNN activation modes"); + AT_CUDNN_CHECK(cudnnSetActivationDescriptor( + mut_desc(), + mode, + cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, + std::numeric_limits::max())); + } +}; + +union Constant +{ + float f; + double d; + Constant(cudnnDataType_t dataType, double value) { + if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) { + f = static_cast(value); + } else { + d = value; + } + } +}; + +} // namespace diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Handle.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Handle.h new file mode 100644 index 0000000000000000000000000000000000000000..a049ee9ad8fa47a658cbb70337ee2c4c8461abda --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Handle.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API cudnnHandle_t getCudnnHandle(); +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Handles.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Handles.h new file mode 100644 index 0000000000000000000000000000000000000000..5b9a081f0c11b57b093891e0dd2adbd969a79f96 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Handles.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Types.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Types.h new file mode 100644 index 0000000000000000000000000000000000000000..202a72e4761aecb485780c9d6b1e2129b9ac31cb --- 
/dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Types.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API cudnnDataType_t +getCudnnDataTypeFromScalarType(const at::ScalarType dtype); +cudnnDataType_t getCudnnDataType(const at::Tensor& tensor); + +int64_t cudnn_version(); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a66df02c2001891382319a84ded65a42f85a4cea --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/Utils.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +// cuDNN has a buggy check for tensor being contiguous (that is, it does +// not ignore stride for dimension that is equal to 0). This function +// makes tensors which have zero stride contiguous, by setting the +// strides to 1 as cuDNN likes. +inline Tensor contiguousIfZeroInStrides(const Tensor& t) { + for (auto s : t.strides()) { + if (s == 0) + return t.contiguous(); + } + return t; +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/cudnn-wrapper.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/cudnn-wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..52ec1cddcb7dc1a0263a65da5ff2f33a36805af0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/cudnn/cudnn-wrapper.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +#define STRINGIFY(x) #x +#define STRING(x) STRINGIFY(x) + +#if CUDNN_MAJOR < 8 || (CUDNN_MAJOR == 8 && CUDNN_MINOR < 5) +#pragma message("CuDNN v" STRING( \ + CUDNN_MAJOR) " found, but need at least CuDNN v8. 
You can get the latest version of CuDNN from https://developer.nvidia.com/cudnn or disable CuDNN with USE_CUDNN=0") +#pragma message "We strongly encourage you to move to 8.5 and above." +#pragma message "This message is intended to annoy you enough to update." +#endif + +#undef STRINGIFY +#undef STRING diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/AcceleratorHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/AcceleratorHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..5aa38635430d0ceacec92ec4e79b1ff970cf20d8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/AcceleratorHooksInterface.h @@ -0,0 +1,96 @@ +#pragma once + +#include + +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") + +namespace at { + +// AcceleratorHooksInterface is a shared interface provided by all +// accelerators to allow generic code. +// This inferface is hook-based as it corresponds to all the functions +// that are going to be called in a generic way from the CPU code. + +struct TORCH_API AcceleratorHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + virtual ~AcceleratorHooksInterface() = default; + + // Whether this backend was enabled at compilation time. + // This function should NEVER throw. + virtual bool isBuilt() const { + return false; + } + + // Whether this backend can be used at runtime, meaning it was built, + // its runtime dependencies are available (driver) and at least one + // supported device can be used. + // This function should NEVER throw. This function should NOT initialize the context + // on any device (result of hasPrimaryContext below should not change). + // While it is acceptable for this function to poison fork, it is + // recommended to avoid doing so whenever possible. 
+ virtual bool isAvailable() const { + return false; + } + + // Whether the device at device_index is fully initialized or not. + virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0; + + virtual void init() const { + TORCH_CHECK(false, "Backend doesn`t support init()"); + } + + virtual DeviceIndex deviceCount() const { + return 0; + } + + virtual void setCurrentDevice(DeviceIndex device) const { + TORCH_CHECK(false, "Backend doesn't support setCurrentDevice()"); + } + + virtual DeviceIndex getCurrentDevice() const { + TORCH_CHECK(false, "Backend doesn't support getCurrentDevice()"); + return -1; + } + + virtual DeviceIndex exchangeDevice(DeviceIndex device) const { + TORCH_CHECK(false, "Backend doesn't support exchangeDevice()"); + return -1; + } + + virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const { + TORCH_CHECK(false, "Backend doesn't support maybeExchangeDevice()"); + return -1; + } + + virtual bool isPinnedPtr(const void* data) const { + return false; + } + + virtual Allocator* getPinnedMemoryAllocator() const { + TORCH_CHECK(false, "Backend doesn't support getPinnedMemoryAllocator()"); + return nullptr; + } + + virtual Device getDeviceFromPtr(void* data) const { + TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()"); + } + + virtual const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { + TORCH_CHECK(false, "Backend doesn`t support getDefaultGenerator()"); + } + + virtual Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { + TORCH_CHECK(false, "Backend doesn`t support getNewGenerator()"); + } +}; + +} // namespace at + +C10_DIAGNOSTIC_POP() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/CUDAHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/CUDAHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..f99e03d156c9b5364aa3a05d5c5d77627a837c01 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/CUDAHooksInterface.h @@ -0,0 +1,224 @@ +#pragma once + +#include +#include +#include + +#include + +// NB: Class must live in `at` due to limitations of Registry.h. +namespace at { + +// Forward-declares at::cuda::NVRTC +namespace cuda { +struct NVRTC; +} // namespace cuda + +#ifdef _MSC_VER +constexpr const char* CUDA_HELP = + "PyTorch splits its backend into two shared libraries: a CPU library " + "and a CUDA library; this error has occurred because you are trying " + "to use some CUDA functionality, but the CUDA library has not been " + "loaded by the dynamic linker for some reason. The CUDA library MUST " + "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! " + "One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ " + "in your link arguments; many dynamic linkers will delete dynamic library " + "dependencies if you don't depend on any of their symbols. You can check " + "if this has occurred by using link on your binary to see if there is a " + "dependency on *_cuda.dll library."; +#else +constexpr const char* CUDA_HELP = + "PyTorch splits its backend into two shared libraries: a CPU library " + "and a CUDA library; this error has occurred because you are trying " + "to use some CUDA functionality, but the CUDA library has not been " + "loaded by the dynamic linker for some reason. The CUDA library MUST " + "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! " + "One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many " + "dynamic linkers will delete dynamic library dependencies if you don't " + "depend on any of their symbols. 
You can check if this has occurred by " + "using ldd on your binary to see if there is a dependency on *_cuda.so " + "library."; +#endif + +// The CUDAHooksInterface is an omnibus interface for any CUDA functionality +// which we may want to call into from CPU code (and thus must be dynamically +// dispatched, to allow for separate compilation of CUDA code). How do I +// decide if a function should live in this class? There are two tests: +// +// 1. Does the *implementation* of this function require linking against +// CUDA libraries? +// +// 2. Is this function *called* from non-CUDA ATen code? +// +// (2) should filter out many ostensible use-cases, since many times a CUDA +// function provided by ATen is only really ever used by actual CUDA code. +// +// TODO: Consider putting the stub definitions in another class, so that one +// never forgets to implement each virtual function in the real implementation +// in CUDAHooks. This probably doesn't buy us much though. +struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + ~CUDAHooksInterface() override = default; + + // Initialize THCState and, transitively, the CUDA state + void init() const override { + TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP); + } + + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK( + false, + "Cannot get default CUDA generator without ATen_cuda library. ", + CUDA_HELP); + } + + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK( + false, + "Cannot get CUDA generator without ATen_cuda library. ", + CUDA_HELP); + } + + Device getDeviceFromPtr(void* /*data*/) const override { + TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. 
", CUDA_HELP); + } + + bool isPinnedPtr(const void* /*data*/) const override { + return false; + } + + virtual bool hasCUDA() const { + return false; + } + + virtual bool hasCUDART() const { + return false; + } + + virtual bool hasMAGMA() const { + return false; + } + + virtual bool hasCuDNN() const { + return false; + } + + virtual bool hasCuSOLVER() const { + return false; + } + + virtual bool hasCuBLASLt() const { + return false; + } + + virtual bool hasROCM() const { + return false; + } + + virtual const at::cuda::NVRTC& nvrtc() const { + TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP); + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP); + } + + virtual DeviceIndex current_device() const { + return -1; + } + + Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP); + } + + virtual Allocator* getCUDADeviceAllocator() const { + TORCH_CHECK(false, "CUDADeviceAllocator requires CUDA. ", CUDA_HELP); + } + + virtual bool compiledWithCuDNN() const { + return false; + } + + virtual bool compiledWithMIOpen() const { + return false; + } + + virtual bool supportsDilatedConvolutionWithCuDNN() const { + return false; + } + + virtual bool supportsDepthwiseConvolutionWithCuDNN() const { + return false; + } + + virtual bool supportsBFloat16ConvolutionWithCuDNNv8() const { + return false; + } + + virtual long versionCuDNN() const { + TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP); + } + + virtual long versionMIOpen() const { + TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP); + } + + virtual long versionCUDART() const { + TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. 
", CUDA_HELP); + } + + virtual std::string showConfig() const { + TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP); + } + + virtual double batchnormMinEpsilonCuDNN() const { + TORCH_CHECK(false, + "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP); + } + + virtual int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); + } + + virtual void cuFFTSetPlanCacheMaxSize(DeviceIndex /*device_index*/, int64_t /*max_size*/) const { + TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); + } + + virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); + } + + virtual void cuFFTClearPlanCache(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP); + } + + virtual int getNumGPUs() const { + return 0; + } + +#ifdef USE_ROCM + virtual bool isGPUArch(const std::vector& /*archs*/, DeviceIndex = -1 /*device_index*/) const { + TORCH_CHECK(false, "Cannot check GPU arch without ATen_cuda library. ", CUDA_HELP); + } +#endif + + virtual void deviceSynchronize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." 
in a variadic macro" +struct TORCH_API CUDAHooksArgs {}; + +TORCH_DECLARE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs); +#define REGISTER_CUDA_HOOKS(clsname) \ + C10_REGISTER_CLASS(CUDAHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const CUDAHooksInterface& getCUDAHooks(); +} // namespace detail +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/FunctionTraits.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/FunctionTraits.h new file mode 100644 index 0000000000000000000000000000000000000000..c9212cb75a4d48e2db725e2d510adf2fdbbb5d98 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/FunctionTraits.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include + +// Modified from https://stackoverflow.com/questions/7943525/is-it-possible-to-figure-out-the-parameter-type-and-return-type-of-a-lambda + +// Fallback, anything with an operator() +template +struct function_traits : public function_traits { +}; + +// Pointers to class members that are themselves functors. +// For example, in the following code: +// template +// struct S { +// func_t f; +// }; +// template +// S make_s(func_t f) { +// return S { .f = f }; +// } +// +// auto s = make_s([] (int, float) -> double { /* ... */ }); +// +// function_traits traits; +template +struct function_traits : public function_traits { +}; + +// Const class member functions +template +struct function_traits : public function_traits { +}; + +// Reference types +template +struct function_traits : public function_traits {}; +template +struct function_traits : public function_traits {}; + +// Free functions +template +struct function_traits { + // arity is the number of arguments. 
+ enum { arity = sizeof...(Args) }; + + using ArgsTuple = std::tuple; + using result_type = ReturnType; + + template + struct arg + { + using type = typename std::tuple_element>::type; + // the i-th argument is equivalent to the i-th tuple element of a tuple + // composed of those arguments. + }; +}; + +template +struct nullary_function_traits { + using traits = function_traits; + using result_type = typename traits::result_type; +}; + +template +struct unary_function_traits { + using traits = function_traits; + using result_type = typename traits::result_type; + using arg1_t = typename traits::template arg<0>::type; +}; + +template +struct binary_function_traits { + using traits = function_traits; + using result_type = typename traits::result_type; + using arg1_t = typename traits::template arg<0>::type; + using arg2_t = typename traits::template arg<1>::type; +}; + + +// Traits for calling with c10::guts::invoke, where member_functions have a first argument of ClassType +template +struct invoke_traits : public function_traits{ +}; + +template +struct invoke_traits : public invoke_traits{ +}; + +template +struct invoke_traits : public invoke_traits{ +}; + +template +struct invoke_traits : + public function_traits { +}; + +template +struct invoke_traits : + public function_traits { +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/HIPHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/HIPHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..0c68f96833f70cbc8ad4596a79b9cf0fb94fff3d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/HIPHooksInterface.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include + +#include + +// NB: Class must live in `at` due to limitations of Registry.h. 
+namespace at { + +// The HIPHooksInterface is an omnibus interface for any HIP functionality +// which we may want to call into from CPU code (and thus must be dynamically +// dispatched, to allow for separate compilation of HIP code). See +// CUDAHooksInterface for more detailed motivation. +struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + ~HIPHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library."); + } + + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library."); + } + + virtual bool hasHIP() const { + return false; + } + + virtual c10::DeviceIndex current_device() const { + return -1; + } + + bool isPinnedPtr(const void* /*data*/ ) const override { + return false; + } + + Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK(false, "Pinned memory requires HIP."); + } + + virtual int getNumGPUs() const { + return 0; + } + + bool hasPrimaryContext(DeviceIndex /*device_index*/ ) const override { + TORCH_CHECK(false, "Cannot check primary context without ATen_hip library."); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." 
in a variadic macro" +struct TORCH_API HIPHooksArgs {}; + +TORCH_DECLARE_REGISTRY(HIPHooksRegistry, HIPHooksInterface, HIPHooksArgs); +#define REGISTER_HIP_HOOKS(clsname) \ + C10_REGISTER_CLASS(HIPHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const HIPHooksInterface& getHIPHooks(); + +} // namespace detail +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/HPUHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/HPUHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..8cf9502a7e1b90c369329985d88d96c36f640582 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/HPUHooksInterface.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace at { + +struct TORCH_API HPUHooksInterface : AcceleratorHooksInterface { + ~HPUHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize HPU without HPU backend"); + } + + virtual bool hasHPU() const { + return false; + } + + Device getDeviceFromPtr(void* /*data*/) const override { + TORCH_CHECK( + false, "Cannot get device of pointer on HPU without HPU backend"); + } + + bool isPinnedPtr(const void*) const override { + return false; + } + + Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK( + false, + "You should register `HPUHooksInterface` for HPU before call `getPinnedMemoryAllocator`."); + } + + bool hasPrimaryContext( + [[maybe_unused]] DeviceIndex device_index) const override { + TORCH_CHECK( + false, + "You should register `HPUHooksInterface` for HPU before call `hasPrimaryContext`."); + } +}; + +struct TORCH_API HPUHooksArgs {}; + +TORCH_DECLARE_REGISTRY(HPUHooksRegistry, HPUHooksInterface, HPUHooksArgs); +#define REGISTER_HPU_HOOKS(clsname) \ + C10_REGISTER_CLASS(HPUHooksRegistry, clsname, clsname) + +namespace detail { + +TORCH_API const at::HPUHooksInterface& 
getHPUHooks(); + +} // namespace detail +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/IPUHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/IPUHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..1be3e0873d581d7c36c29f90a70bc10f7d55e71f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/IPUHooksInterface.h @@ -0,0 +1,43 @@ +#pragma once + +#include + +#include +#include +#include + +namespace at { + +struct TORCH_API IPUHooksInterface : AcceleratorHooksInterface { + ~IPUHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + } + + bool hasPrimaryContext(DeviceIndex /*device_index*/) const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + return false; + } + + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + } + + Generator getNewGenerator( + DeviceIndex /*device_index*/ = -1) const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + } +}; + +struct TORCH_API IPUHooksArgs {}; + +TORCH_DECLARE_REGISTRY(IPUHooksRegistry, IPUHooksInterface, IPUHooksArgs); +#define REGISTER_IPU_HOOKS(clsname) \ + C10_REGISTER_CLASS(IPUHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const IPUHooksInterface& getIPUHooks(); +} // namespace detail +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/MAIAHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/MAIAHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..d747bc6868445790518b958b1f0b11474690b985 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/MAIAHooksInterface.h @@ 
-0,0 +1,42 @@ +#pragma once + +#include +#include + +#include + +// NB: Class must live in `at` due to limitations of Registry.h. +namespace at { + +struct TORCH_API MAIAHooksInterface : AcceleratorHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + ~MAIAHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library."); + } + + bool hasPrimaryContext(DeviceIndex /*device_index*/) const override { + TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library."); + return false; + } + + virtual std::string showConfig() const { + TORCH_CHECK(false, "Cannot query detailed MAIA version information."); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." in a variadic macro" +struct TORCH_API MAIAHooksArgs {}; + +TORCH_DECLARE_REGISTRY(MAIAHooksRegistry, MAIAHooksInterface, MAIAHooksArgs); +#define REGISTER_MAIA_HOOKS(clsname) \ + C10_REGISTER_CLASS(MAIAHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const MAIAHooksInterface& getMAIAHooks(); +} // namespace detail + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/MPSHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/MPSHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..1a16072fd95d5e22c817c03a69be0aa303a8fa7b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/MPSHooksInterface.h @@ -0,0 +1,125 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include + +#include +#include +#include + +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") +namespace at { + +struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { + // this fails the implementation if MPSHooks functions are called, but + // MPS backend is not present. 
+ #define FAIL_MPSHOOKS_FUNC(func) \ + TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend."); + + ~MPSHooksInterface() override = default; + + // Initialize the MPS library state + void init() const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual bool hasMPS() const { + return false; + } + virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index) const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual Allocator* getMPSDeviceAllocator() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void deviceSynchronize() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void commitStream() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void* getCommandBuffer() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void* getDispatchQueue() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void emptyCache() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual size_t getCurrentAllocatedMemory() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual size_t getDriverAllocatedMemory() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual size_t getRecommendedMaxMemory() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void setMemoryFraction(double /*ratio*/) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void profilerStopTrace() const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual uint32_t acquireEvent(bool enable_timing) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + Device getDeviceFromPtr(void* data) const override { + TORCH_CHECK(false, "Cannot get device of pointer on MPS without ATen_mps library. 
"); + } + virtual void releaseEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void recordEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void waitForEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual void synchronizeEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual bool queryEvent(uint32_t event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const { + FAIL_MPSHOOKS_FUNC(__func__); + } + bool hasPrimaryContext(DeviceIndex device_index) const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + bool isPinnedPtr(const void* data) const override { + return false; + } + Allocator* getPinnedMemoryAllocator() const override { + FAIL_MPSHOOKS_FUNC(__func__); + } + #undef FAIL_MPSHOOKS_FUNC +}; + +struct TORCH_API MPSHooksArgs {}; + +TORCH_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs); +#define REGISTER_MPS_HOOKS(clsname) \ + C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const MPSHooksInterface& getMPSHooks(); + +} // namespace detail +} // namespace at +C10_DIAGNOSTIC_POP() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..fb8ed6fb232266fd68cc0a5b387ea826de60955f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h @@ -0,0 +1,164 @@ +#pragma once + +#include +#include + +#include +#include + +#include + +#include +#include + +#include +namespace at { +class Context; +} + +namespace at { +constexpr const char* MTIA_HELP = + "The MTIA backend requires MTIA extension for PyTorch;" + "this error has occurred because you are trying " + "to use some MTIA's functionality 
without MTIA extension included."; + +struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { +// this fails the implementation if MTIAHooks functions are called, but +// MTIA backend is not present. +#define FAIL_MTIAHOOKS_FUNC(func) \ + TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend."); + + ~MTIAHooksInterface() override = default; + + void init() const override { + // Avoid logging here, since MTIA needs init devices first then it will know + // how many devices are available. Make it as no-op if mtia extension is not + // dynamically loaded. + return; + } + + virtual bool hasMTIA() const { + return false; + } + + DeviceIndex deviceCount() const override { + return 0; + } + + virtual void deviceSynchronize(c10::DeviceIndex /*device_index*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual std::string showConfig() const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + bool hasPrimaryContext(DeviceIndex /*device_index*/) const override { + return false; + } + + void setCurrentDevice(DeviceIndex /*device*/) const override { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + DeviceIndex getCurrentDevice() const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + DeviceIndex exchangeDevice(DeviceIndex /*device*/) const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + DeviceIndex maybeExchangeDevice(DeviceIndex /*device*/) const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + virtual c10::Stream getCurrentStream(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA); + } + + virtual int64_t getCurrentRawStream(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + virtual c10::Stream getDefaultStream(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA); + } + + virtual void setCurrentStream(const c10::Stream& /*stream*/ ) const { 
+ FAIL_MTIAHOOKS_FUNC(__func__); + } + + bool isPinnedPtr(const void* /*data*/) const override { + return false; + } + + Allocator* getPinnedMemoryAllocator() const override { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual PyObject* memoryStats(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual PyObject* getDeviceCapability(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual PyObject* getDeviceProperties(DeviceIndex device) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual void emptyCache() const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + + virtual void recordMemoryHistory( + const std::optional& /*enabled*/, + const std::string& /*stacks*/, + size_t /*max_entries*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual PyObject* memorySnapshot(const std::optional& local_path) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } + + virtual DeviceIndex getDeviceCount() const { + FAIL_MTIAHOOKS_FUNC(__func__); + return 0; + } + + virtual void resetPeakMemoryStats(DeviceIndex /*device*/) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual void attachOutOfMemoryObserver(PyObject* observer) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return; + } +}; + +struct TORCH_API MTIAHooksArgs {}; + +TORCH_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs); +#define REGISTER_MTIA_HOOKS(clsname) \ + C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const MTIAHooksInterface& getMTIAHooks(); +TORCH_API bool isMTIAHooksBuilt(); +} // namespace detail +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..ddcbdae84b94db35b37c093a4e3c0172aacdbf63 --- 
/dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h @@ -0,0 +1,89 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") + +namespace at { + +struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { +#define FAIL_PRIVATEUSE1HOOKS_FUNC(func) \ + TORCH_CHECK_NOT_IMPLEMENTED( \ + false, \ + "You should register `PrivateUse1HooksInterface`", \ + "by `RegisterPrivateUse1HooksInterface` and implement `", \ + func, \ + "` at the same time for PrivateUse1."); + + ~PrivateUse1HooksInterface() override = default; + + bool isBuilt() const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + bool isAvailable() const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + // TODO(FFFrog): Perserved for BC and will be removed in the future. 
+ if (at::GetGeneratorPrivate().has_value()) + return at::GetGeneratorForPrivateuse1(device_index); + + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + at::Device getDeviceFromPtr(void* data) const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + bool isPinnedPtr(const void* data) const override { + return false; + } + + Allocator* getPinnedMemoryAllocator() const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + + void init() const override {} + virtual void resizePrivateUse1Bytes( + const c10::Storage& storage, + size_t newsize) const { + FAIL_PRIVATEUSE1HOOKS_FUNC(__func__); + } + +#undef FAIL_PRIVATEUSE1HOOKS_FUNC +}; + +struct TORCH_API PrivateUse1HooksArgs {}; + +TORCH_API void RegisterPrivateUse1HooksInterface( + at::PrivateUse1HooksInterface* hook_); + +TORCH_API bool isPrivateUse1HooksRegistered(); + +namespace detail { + +TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks(); + +} // namespace detail + +} // namespace at + +C10_DIAGNOSTIC_POP() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/XPUHooksInterface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/XPUHooksInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..5d31acbce170a61751901c91e6b1c346b611e83d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/detail/XPUHooksInterface.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include + +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") + +namespace at { + +struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{ + ~XPUHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize XPU without ATen_xpu library."); + } + + virtual bool hasXPU() const { + return false; + } + + virtual std::string showConfig() const { + TORCH_CHECK( + 
false, + "Cannot query detailed XPU version without ATen_xpu library."); + } + + virtual int32_t getGlobalIdxFromDevice(const Device& device) const { + TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library."); + } + + const Generator& getDefaultGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK( + false, "Cannot get default XPU generator without ATen_xpu library."); + } + + Generator getNewGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const override { + TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library."); + } + + virtual DeviceIndex getNumGPUs() const { + return 0; + } + + virtual DeviceIndex current_device() const { + TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library."); + } + + Device getDeviceFromPtr(void* /*data*/) const override { + TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library."); + } + + virtual void deviceSynchronize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot synchronize XPU device without ATen_xpu library."); + } + + Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK(false, "Cannot get XPU pinned memory allocator without ATen_xpu library."); + } + + bool isPinnedPtr(const void* data) const override { + return false; + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot query primary context without ATen_xpu library."); + } +}; + +struct TORCH_API XPUHooksArgs {}; + +TORCH_DECLARE_REGISTRY(XPUHooksRegistry, XPUHooksInterface, XPUHooksArgs); +#define REGISTER_XPU_HOOKS(clsname) \ + C10_REGISTER_CLASS(XPUHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const XPUHooksInterface& getXPUHooks(); +} // namespace detail +} // namespace at +C10_DIAGNOSTIC_POP() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h new file mode 100644 index 0000000000000000000000000000000000000000..ceca44832179dbd3102402696a972de11ca88016 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h @@ -0,0 +1,38 @@ +#pragma once +#include + +namespace at::functorch { + +// These are the interpreters for our AD transforms +// (grad, vjp and jvp). +// See NOTE: [functorch interpreter stack] for more details. + +struct TORCH_API GradInterpreterPtr { + explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + bool prevGradMode() const { + return std::get(base_->meta()).prevGradMode_; + } + Tensor lift(const Tensor& tensor) const; + private: + const Interpreter* base_; +}; + +struct TORCH_API JvpInterpreterPtr { + explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + bool prevFwdGradMode() const { + return std::get(base_->meta()).prevFwdGradMode_; + } + Tensor lift(const Tensor& tensor) const; + private: + const Interpreter* base_; +}; + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h new 
file mode 100644 index 0000000000000000000000000000000000000000..70fbf3135a3ca8e32b03ce6481dc1781d779cb0c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h @@ -0,0 +1,481 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +#pragma once + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// This file contains helper functions for batching rules. + +namespace at::functorch { + +TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x); +TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x); + +TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x); + +Tensor moveBatchDimToFront(Tensor tensor, std::optional maybe_batch_dim); +int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional maybe_batch_dim); +int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional maybe_batch_dim); +std::optional valIfNonempty(std::optional maybe_empty, int64_t new_val); +int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim); +VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims); + +void vmapIncompatibleInplaceError(const char* schema_name); + +Tensor maybePadToLogicalRank(const Tensor& tensor, std::optional has_bdim, int64_t logical_rank); + +void check_randomness(RandomnessType randomness); +void check_randomness(RandomnessType randomness, bool any_tensor_bdim); + +inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) { + if (has_bdim) { + return tensor; + } + const auto sizes = tensor.sym_sizes(); + SymDimVector expanded_shape; + expanded_shape.reserve(sizes.size()); + 
expanded_shape.emplace_back(std::move(batch_size)); + expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end()); + return tensor.expand_symint(expanded_shape); +} + +#define VMAP_SUPPORT(op, batch_rule) \ + m.impl(#op, op ## _generated_plumbing); + +#define VMAP_SUPPORT2(op, overload, batch_rule) \ + m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing); + +#define OP_DECOMPOSE(op) m.impl(#op, static_cast(native::op)); +#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast(native::op)); + +// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain +template +struct BasicUnaryBatchRuleHelper; + +template +struct BasicUnaryBatchRuleHelper> { + static std::tuple> apply( + const Tensor& tensor, + std::optional batch_dim, + T... extra_args) { + return std::make_tuple(Func(tensor, std::forward(extra_args)...), batch_dim); + } +}; + +// USAGE: BASIC_UNARY_BATCH_RULE(at::sin) +// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin) +// It is important that this macro is not passed a function pointer!! +#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\ + BasicUnaryBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + +#define UNARY_POINTWISE(op) \ + VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op))); + +template +struct VariadicBdimsBatchRuleHelper; + +template +struct VariadicBdimsBatchRuleHelper> { + static std::tuple> apply( + const Tensor& tensor, + std::optional batch_dim, + T... extra_args) { + auto tensor_ = moveBatchDimToFront(tensor, batch_dim); + return std::make_tuple(Func(tensor_, std::forward(extra_args)...), 0); + } +}; + +// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse) +// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse) +// It is important that this macro is not passed a function pointer!! 
+#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\ + VariadicBdimsBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + +#define VARIADIC_BDIMS(op) \ + VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op))); + +#define VARIADIC_BDIMS2(op, overload) \ + VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload))); + +template +void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = schema.arguments().size(); + + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule"); + + int64_t cur_level = maybe_layer->layerId(); + + auto orig_arguments = torch::jit::last(*stack, num_arguments); + if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) { + op.callBoxed(stack); + return; + } + + auto arguments = torch::jit::pop(*stack, num_arguments); + std::vector>> tensor_inputs; + std::vector tensor_pos; + for (const auto idx : c10::irange(0, num_arguments)) { + const auto& ivalue = arguments[idx]; + if (ivalue.isTensor()) { + auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level); + tensor_inputs.emplace_back(std::move(tensor_value), tensor_bdim); + tensor_pos.push_back(static_cast(idx)); + } + } + Func(tensor_inputs); + + size_t tensor_idx = 0; + TORCH_INTERNAL_ASSERT(!tensor_pos.empty()); + for (const auto arg_idx : c10::irange(0, num_arguments)) { + if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) { + torch::jit::push(stack, arguments[arg_idx]); + } else { + TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size()); + torch::jit::push(stack, tensor_inputs[tensor_idx].first); + tensor_idx++; + } + } + + 
op.callBoxed(stack); + const auto returns = torch::jit::pop(*stack, num_returns); + for (const auto& ret : returns) { + if (ret.isTensor()) { + torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level)); + } else { + TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values"); + } + } +} + +inline void handle_pointwise_ops(std::vector>> &tensor_inputs) { + int64_t out_logical_rank = 0; + for (auto& tensor_input : tensor_inputs) { + int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second); + out_logical_rank = std::max(out_logical_rank, cur_logical_rank); + } + for (auto& tensor_input: tensor_inputs) { + tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second); + tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank); + } +} + +#define POINTWISE_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +#define POINTWISE_BOXED2(op, overload) \ + m.impl(#op "." 
#overload, torch::CppFunction::makeFromBoxedFunction>()); + +inline void handle_variadic_bdims(std::vector>> &tensor_inputs) { + for (auto & tensor_input : tensor_inputs) { + tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second); + } +} + +#define VARIADIC_BDIMS_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +using UnpackedBatchedTensor = std::tuple>; + +inline void find_and_unpack_tensors( + const torch::jit::Stack* stack, + int64_t num_args, + int64_t cur_level, + SmallVector* tensors, + SmallVector* tensors_pos, + int64_t* batch_size) { + + int64_t computed_batch_size = -1; + int64_t args_begin = static_cast(stack->size()) - num_args; + + for (const auto idx : c10::irange(0, num_args)) { + const auto& ivalue = (*stack)[args_begin + idx]; + if (!ivalue.isTensor()) { + continue; + } + auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level); + const auto& [tensor_value, tensor_bdim] = unpacked; + if (tensor_bdim.has_value()) { + auto candidate_batch_size = tensor_value.size(*tensor_bdim); + if (computed_batch_size == -1) { + computed_batch_size = candidate_batch_size; + } + TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size); + } + + tensors->push_back(std::move(unpacked)); + tensors_pos->push_back(idx); + } + TORCH_INTERNAL_ASSERT(computed_batch_size > -1); + *batch_size = computed_batch_size; +} + +inline void boxed_existing_bdim_all_batch_rule( + const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = static_cast(schema.arguments().size()); + + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + const auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule"); + + const auto arguments = torch::jit::last(stack, num_arguments); + if (std::none_of(arguments.begin(), arguments.end(), 
ivalueParticipatesInCurrentLevel)) { + op.callBoxed(stack); + return; + } + + int64_t args_begin = static_cast(stack->size()) - num_arguments; + SmallVector tensor_inputs; + SmallVector tensor_pos; + int64_t batch_size = 0; + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + int64_t cur_level = maybe_layer->layerId(); + + find_and_unpack_tensors( + stack, num_arguments, cur_level, + &tensor_inputs, &tensor_pos, &batch_size); + + // for each tensor, ensure it has a bdim and reshape it. + for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) { + const auto& [value, bdim] = tensor_inputs[tensor_idx]; + auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size); + (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(bdim.value_or(0), 0, value_); + } + + op.callBoxed(stack); + + for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) { + const auto& ret = (*stack)[idx]; + TORCH_INTERNAL_ASSERT(ret.isTensor(), + "This boxed batching rule does not currently support ops that return non-tensor values"); + (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level); + } +} + +// Use when all tensors arguments accept one (normal) batch dim. +// This batching rule expands the batch dim on all Tensors, reshapes it into +// dim 0, calls the op, and then reshapes the batch dim out of dim 0. +// This is not the most efficient thing; if there are alternatives, plese try +// to use them. Use this only as a last resort. 
+#define EXISTING_BDIM_ALL_BOXED(op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction()); + +template +inline void boxed_all_tensors_have_optional_bdim( + const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + const auto num_arguments = schema.arguments().size(); + + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim"); + int64_t cur_level = maybe_layer->layerId(); + + const auto arguments = torch::jit::last(stack, num_arguments); + if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) { + op.callBoxed(stack); + return; + } + + int64_t args_begin = static_cast(stack->size() - num_arguments); + SmallVector tensor_inputs; + SmallVector tensor_pos; + int64_t batch_size = 0; + + find_and_unpack_tensors( + stack, static_cast(num_arguments), cur_level, + &tensor_inputs, &tensor_pos, &batch_size); + + std::optional is_no_batch_dim_case; + + for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) { + const auto& value = std::get<0>(tensor_inputs[tensor_idx]); + auto bdim = std::get<1>(tensor_inputs[tensor_idx]); + const auto logical_rank = rankWithoutBatchDim(value, bdim); + + if (!is_no_batch_dim_case.has_value()) { + is_no_batch_dim_case = (logical_rank == feature_rank); + } + auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size); + if (!bdim.has_value()) { + bdim = 0; + } + if (*is_no_batch_dim_case) { + TORCH_INTERNAL_ASSERT(logical_rank == feature_rank); + value_ = moveBatchDimToFront(value_, bdim); + if (tensor_idx == contig_tensor_index) { + value_ = value_.contiguous(); + } + (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_); + continue; + } + TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1); + value_ = reshape_dim_into(*bdim, 0, value_); + if 
(tensor_idx == contig_tensor_index) { + value_ = value_.contiguous(); + } + (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_); + } + + op.callBoxed(stack); + + for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) { + const auto& ret = (*stack)[idx]; + TORCH_INTERNAL_ASSERT(ret.isTensor(), + "This boxed batching rule does not currently support ops that return non-tensor values"); + if (*is_no_batch_dim_case) { + (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level); + } else { + (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level); + } + } +} + +// Useful for many NN operators. +// The operator must satisfy the following: +// - All arguments must accept an optional batch dim. +// - All arguments must be the same rank +#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \ + m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); + +#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \ + m.impl(#op, \ + torch::CppFunction::makeFromBoxedFunction<\ + boxed_all_tensors_have_optional_bdim<\ + feature_rank, \ + contig_tensor_index>\ + >()); + +template +struct ExistingBdimBatchRuleHelper; + +template +struct ExistingBdimBatchRuleHelper> { + static std::tuple> apply( + const Tensor& self, + std::optional self_bdim, + T... extra_args) { + auto self_ = reshape_dim_into(*self_bdim, 0, self); + auto out = Func(self_, std::forward(extra_args)...); + return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0); + } +}; + +// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse) +// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse) +// It is important that this macro is not passed a function pointer!! 
+#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\ + ExistingBdimBatchRuleHelper<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + + +#define EXISTING_BDIM(op) \ + VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op))); + +#define EXISTING_BDIM2(op, overload) \ + VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload))); + +#define INVOKE(object,ptrToMember) ((object).*(ptrToMember)) + + +template +Tensor& unary_inplace_batch_rule(Tensor& self, std::optional, ExtraArgs... extra_args) { + INVOKE(self, Method)(std::forward(extra_args)...); + return self; +} + +inline int64_t get_bdim_size4( + const Tensor& a_value, std::optional a_bdim, + const Tensor& b_value, std::optional b_bdim, + const Tensor& c_value, std::optional c_bdim, + const Tensor& d_value, std::optional d_bdim) { + if (a_bdim) + return a_value.size(*a_bdim); + if (b_bdim) + return b_value.size(*b_bdim); + if (c_bdim) + return c_value.size(*c_bdim); + if (d_bdim) + return d_value.size(*d_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +inline int64_t get_bdim_size3( + const Tensor& a_value, std::optional a_bdim, + const Tensor& b_value, std::optional b_bdim, + const Tensor& c_value, std::optional c_bdim) { + if (a_bdim) + return a_value.size(*a_bdim); + if (b_bdim) + return b_value.size(*b_bdim); + if (c_bdim) + return c_value.size(*c_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +inline int64_t get_bdim_size2( + const Tensor& a_value, std::optional a_bdim, + const Tensor& b_value, std::optional b_bdim) { + if (a_bdim) + return a_value.size(*a_bdim); + if (b_bdim) + return b_value.size(*b_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +inline c10::SymInt get_bdim_size2_symint( + const Tensor& a_value, std::optional a_bdim, + const Tensor& b_value, std::optional b_bdim) { + if (a_bdim) + return a_value.sym_size(*a_bdim); + if (b_bdim) + return b_value.sym_size(*b_bdim); + TORCH_INTERNAL_ASSERT(false); +} + +// [start, start + 1, ..., stop - 1] +inline 
VmapDimVector range(int64_t start, int64_t stop) { + TORCH_INTERNAL_ASSERT(stop >= start); + VmapDimVector dims; + dims.reserve(stop - start); + for (int64_t i = start; i < stop; i++) { + dims.emplace_back(i); + } + return dims; +} +std::tuple _binary_pointwise_helper( + const Tensor& tensor, std::optional tensor_batch_dim, const Tensor& other, std::optional other_batch_dim, + bool do_type_promotion=true); + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h new file mode 100644 index 0000000000000000000000000000000000000000..f76191221a1e2b1a4d055c8a3e68945791c259df --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h @@ -0,0 +1,81 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include + +namespace at::functorch { + +// This file contains code for the vmap fallback (also known as the +// BatchedTensor fallback or the Batched fallback). This code runs +// when an operation doesn't have a batching rule implemented. + +// If an operator doesn't have a batching rule implemented then we fallback +// to this implementation. The fallback doesn't work on out= variants or +// view operations; that is, it works for out-of-place operations and +// in-place non-view operations. +// +// For out-of-place operations, the fallback effectively takes all of the +// BatchedTensors in `stack`, slices them, and runs `op` on all of the +// corresponding slices to produce slices of the outputs. The output slices +// then get `torch.stack`ed to create the +// final returns. 
+// +// The performance of the fallback is not very good because it introduces an +// extra copy from stacking the sliced outputs. Because of this, we prefer to +// write batching rules for operators whenever possible. +void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); +void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +// The vmap fallback emits a warning by default, but it may be disabled if +// the user finds it to be too annoying. +TORCH_API bool isVmapFallbackWarningEnabled(); +TORCH_API void setVmapFallbackWarningEnabled(bool enabled); + +// Used for testing. The vmap fallback is enabled by default. When it is disabled, +// it raises an error. +TORCH_API bool isVmapFallbackEnabled(); +TORCH_API void setVmapFallbackEnabled(bool enabled); + +template A vector_to_result(const std::vector& buffer) { + return buffer[0].to(); +} +template std::tuple vector_to_result(const std::vector& buffer) { + return std::make_tuple(buffer[0].to(), buffer[1].to()); +} +template std::tuple vector_to_result(const std::vector& buffer) { + return std::make_tuple(buffer[0].to(), buffer[1].to(), buffer[2].to()); +} + +// slow_fallback is a way to call the vmap fallback inside some boxed kernel. +// There is probably some better way to metaprogram this. 
+template +Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef args) { + std::vector stack(args.begin(), args.end()); + batchedTensorForLoopFallback(op, &stack); + return vector_to_result(stack); +} + +template +std::tuple slow_fallback(const c10::OperatorHandle& op, ArrayRef args) { + std::vector stack(args.begin(), args.end()); + batchedTensorForLoopFallback(op, &stack); + return vector_to_result(stack); +} + +template +std::tuple slow_fallback(const c10::OperatorHandle& op, ArrayRef args) { + std::vector stack(args.begin(), args.end()); + batchedTensorForLoopFallback(op, &stack); + return vector_to_result(stack); +} + + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..ce3d290084117423118c24b660f9e324abadc10e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h @@ -0,0 +1,170 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#include +#include + +namespace at::functorch { + +using Tensor = at::Tensor; + +// We assume this in a few other places in the codebase, +// but there isn't a centralized definition. +constexpr int64_t kVmapMaxTensorDims = 64; + +// The valid vmap levels range from [0, 64). This effectively means that we +// support a maximum of 64 nested vmaps. +constexpr int64_t kVmapNumLevels = 64; + +// Store this number of elements of BatchDims on the stack. Most people will +// probably use <= 5 nested vmaps, but adjust this number as necessary. 
+constexpr int64_t kBatchDimsStackSize = 5; + +// A BatchedTensorImpl holds an underlying Tensor and a single batch dim +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +// +// The batch dimensions are treated as being "private"; they are not user-visible. +// For example, in the following Tensor, +// bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0) +// dimension 0 is batch dimension. +// +// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public) +// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor. +struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { + explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level); + + // Returns batch dimension of this tensor + int64_t bdim() const { return bdim_; } + + // Returns batch dimension of this tensor + int64_t level() const { return level_; } + + // BatchedTensorImpl wraps a Tensor + const Tensor& value() const { return value_; } + + // Given a public dimension index, return the dimension index in the underlying + // value() tensor. + // For example, if we have + // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0) + // bt.actualDim(0) -> 1 + // bt.actualDim(1) -> 2 + // bt.actualDim(2) -> 3 + // bt.actualDim(3) -> Error + int64_t actualDim(int64_t dim, bool wrap_dim = true) const; + + IntArrayRef sizes_custom() const override; + SymIntArrayRef sym_sizes_custom() const override; + int64_t size_custom(int64_t d) const override; + c10::SymInt sym_size_custom(int64_t d) const override; + // We have to override this because we opted into CustomStrides + IntArrayRef strides_custom() const override; + SymIntArrayRef sym_strides_custom() const override; + // Override a bunch of methods inherited from TensorImpl to return error messages. 
+ bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; +#ifdef DEBUG + bool has_storage() const override; +#endif + + void refreshTensorMetadata(); + + // Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it + // accomplishes this is a hack where it is able to modify the levels of + // BatchedTensor to match the level of the current vmap transform. + void _unsafe_set_level(int64_t level) { + level_ = level; + } + + // Used in batching rule for in-place view operations that can change + // the index of the bdim (think squeeze_, unsqueeze_) + void unsafe_set_bdim(int64_t bdim) { + // NB: you MUST call refreshTensorMetadata after doing this. + bdim_ = bdim; + } + private: + // see NOTE: [BatchedTensorImpl levels invariant] + void checkInvariants() const; + const char* tensorimpl_type_name() const override; + + Tensor value_; + + int64_t level_; + int64_t bdim_; +}; + +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +inline bool isBatchedTensor(const Tensor& tensor) { + return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) || + tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor); +} + +// It is unsafe to call this on a Tensor that is not backed by a +// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible. 
+inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) { + return static_cast(tensor.unsafeGetTensorImpl()); +} + +inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) { + if (!isBatchedTensor(tensor)) { + return nullptr; + } + return unsafeGetBatchedImpl(tensor); +} + +// Returns a bitset. If bit i is set, then that means dim i is a batchdim. +inline std::bitset createBatchDimBitset(int64_t dim) { + std::bitset is_bdim; + is_bdim.set(dim); + return is_bdim; +} + +// Creates a bitset for the given level +inline std::bitset createVmapLevelsBitset(int64_t level) { + std::bitset result; + result.set(level); + return result; +} + +// Use this to construct a BatchedTensor from a regular Tensor +TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level); + +// Adds a batch dim to `tensor`, returning a BatchedTensor +TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level); + +// Certain dispatch keys must be propagated to the BatchedTensor (or, in general, +// any wrapper Tensor subclasses). This is because there are methods on Tensor +// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()). +// TODO: should probably contain more (or all?) 
backend keys +constexpr DispatchKeySet kKeysToPropagateToWrapper({ + DispatchKey::Negative, + DispatchKey::Conjugate, + DispatchKey::XLA, + DispatchKey::CUDA, + DispatchKey::CPU, + DispatchKey::PrivateUse1, +}); + +inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { + auto key_set = tensor.unsafeGetTensorImpl()->key_set(); + return key_set & kKeysToPropagateToWrapper; +} + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h new file mode 100644 index 0000000000000000000000000000000000000000..7b9c2aa151e9f11d4771a0ce52774c756497b17f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h @@ -0,0 +1,126 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include + +// This file contains template metaprogramming things that are used for our +// batching rules. +// +// See NOTE: [vmap plumbing] for more details on why this is necessary. +// The plumbing has a bunch of metaprogramming hacks for determining the signature +// of a batching rule from the signature of the operator, many of which use the +// helper functions in this file. 
+ +namespace at::functorch { + +// Metaprogramming things +template using typelist = c10::guts::typelist::typelist; +template using head_t = c10::guts::typelist::head_t; +template using concat_t = c10::guts::typelist::concat_t; +template class debug_t; + +// tail operation +template +struct tail final { + static_assert(c10::guts::false_t::value, + "In typelist::tail, the T argument must be typelist<...>."); +}; +template +struct tail> final { + using type = typelist; +}; +template using tail_t = typename tail::type; + +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext { + using type = Next; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, std::optional, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, std::optional, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, std::optional, Next, Tail> { + using type = Tail; +}; +template +struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, std::optional, Next, Tail> { + using type = Tail; +}; +template struct RemoveBatchDimAfterTensor { + using first = head_t; + using next = tail_t; + using second = head_t; + using tail = tail_t; + + using type = concat_t< + typelist, + typename RemoveBatchDimAfterTensor< + typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext::type + >::type + >; +}; +template struct RemoveBatchDimAfterTensor> { + using type = typelist; +}; +template <> struct RemoveBatchDimAfterTensor> { + using type = typelist<>; +}; +template using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor::type; + 
+template struct UnpackSingleItemTuple { + using type = T; +}; +template struct UnpackSingleItemTuple> { + using type = T; +}; +template using unpack_single_item_tuple_t = typename UnpackSingleItemTuple::type; + +template struct BuildFunctionHelper; +template struct BuildFunctionHelper> { + using type = Return(Args...); +}; +template +struct BuildFunction { + using type = typename BuildFunctionHelper>::type; +}; +template using build_function_t = typename BuildFunction::type; + + +template struct ToOperatorType { + using batch_rule_return_type = typename c10::guts::function_traits::return_type; + using batch_rule_parameter_types = typename c10::guts::function_traits::parameter_types; + + using operator_parameter_types = remove_batch_dim_after_tensor_t; + using operator_return_type = + unpack_single_item_tuple_t< + c10::guts::typelist::to_tuple_t< + remove_batch_dim_after_tensor_t< + c10::guts::typelist::from_tuple_t>>>; + + using type = build_function_t; +}; +template using to_operator_t = typename ToOperatorType::type; + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..447df65550aa5fcda929354ca1a2ff5610232082 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h @@ -0,0 +1,124 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Forward declared +namespace c10 { struct AutogradMetaInterface; } + +namespace at::functorch { + +// This file contains the implementation of functorch's interpreter stack. 
+// See NOTE: [functorch interpreter stack] first before reading on. +// +// NB: the functorch interpreter stack is also referred to as: +// - the "dynamic layer stack" -- an older name for "interpreter" was +// "dynamic layer". +// - the "functorch mode stack". You can think of each functorch transform as a +// "mode" (in the same sense as torch_dispatch mode or torch_function mode), +// and functorch being an implementation of a "mode stack" where the modes +// may be arbitrary composed. + +// DynamicLayer is basically the same thing as an Interpreter. +// It represents a functorch transform and it holds an Interpreter, +// which contains metadata related to the transform and instructions on +// how to perform the transform. +// +// TODO: we can excise DynamicLayer in favor of Interpreter, +// But I am going to leave it for now as a compatiblity shim to avoid +// needing to refactor a lot of callsites... +struct TORCH_API DynamicLayer { + explicit DynamicLayer( + TransformType transform_type, + int64_t layerId, + std::optional batchSize = std::nullopt, + std::optional randomness = std::nullopt, + std::optional prev_grad_mode = std::nullopt, + std::optional pre_fwd_grad_mode = std::nullopt, + std::optional functionalize_add_back_views = std::nullopt); + + TransformType key() const; + int64_t layerId() const; + + const Interpreter& interpreter() const { return interpreter_; } + Interpreter& interpreter() { return interpreter_; } + + // Only valid for vmap + c10::SymInt batchSize() const; + RandomnessType randomness() const; + + private: + Interpreter interpreter_; +}; + +TORCH_API int64_t initAndPushDynamicLayer( + TransformType transform_type, + std::optional batch_size = std::nullopt, + std::optional randomness = std::nullopt, + std::optional prev_grad_mode = std::nullopt, + std::optional prev_fwd_grad_mode = std::nullopt, + std::optional functionalize_add_back_views = std::nullopt); +TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata(); +TORCH_API 
std::optional maybeCurrentDynamicLayer(); +TORCH_API const std::vector& getDynamicLayerStack(); +TORCH_API void setDynamicLayerStack(const std::vector& stack); +TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included); + +// NOTE: [Life handles and lexically scoped transforms] +// functorch transforms are lexically scoped. +// Given a level, we store a "life handle" that is a boolean that tells us if the +// transform with that level is active or not. +// +// functorch's TensorWrapper (for grad transforms) stores a life handle. +// If a TensorWrapper escapes from the scope of the transform, then somehow +// it must know it escaped; it can tell by querying the life handle. +TORCH_API const std::shared_ptr& getLifeHandleForLevel(int64_t level); + +// Returns if an operator is in-place. An operator is inplace if: +// 1. The first argument is a Tensor and it is being written to +// 2. The first argument is being returned +// 3. No other arguments are aliased +// Here is an example of an in-place operator: +// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) +TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema); + +// Given the indices of unwrapped inputs and the schema, this returns the indices of any outputs that should remain unwrapped +TORCH_API std::optional findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input); + +TORCH_API Tensor unwrapIfDead(const Tensor& tensor); +TORCH_API bool isDeadTensorWrapper(const Tensor& tensor); + +// Pretty printers +TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer); +TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector& dynamicLayerStack); + +// While a functorch transform is active, torch.autograd.function._SingleLevelFunction +// is disabled by default. The following two APIs are APIs for enabling +// it. These are not user-facing APIs. 
We can delete this in the future, but +// it is useful for debugging when something goes wrong with the +// autograd.Function <> functorch interaction, which uses _SingleLevelFunction, +// because it leads to loud errors if something is incorrect. +TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed); +TORCH_API bool getSingleLevelAutogradFunctionAllowed(); + +// While a functorch grad transform is active, Tensor.requires_grad_() gets +// disabled. These two functions are the mechanism to controlling that. +TORCH_API void setInplaceRequiresGradAllowed(bool allowed); +TORCH_API bool getInplaceRequiresGradAllowed(); + +TORCH_API DynamicLayer popDynamicLayer(); +TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer); + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..f07c1f4fdc012938702b3ef3cab43e1e2ea5c05c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h @@ -0,0 +1,22 @@ +#pragma once +#include + +namespace at::functorch { + +// This is the interpreter that handles the functionalize() transform. +// See NOTE: [functorch interpreter stack] for more details. 
+ +struct FunctionalizeInterpreterPtr { + explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + bool functionalizeAddBackViews() const { + return std::get(base_->meta()).functionalizeAddBackViews_; + } + private: + const Interpreter* base_; +}; + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..1c76230fb45566a0cc757a8daf3affe4c91f1ad1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h @@ -0,0 +1,351 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::functorch { + +// NOTE: [functorch interpreter stack] +// +// functorch's dispatching system uses a stack of interpreters. +// Historically we've referred to this as the "DynamicLayerStack". +// +// An interpreter is something that reads in the code it is passed +// and then executes it. We have a different interpreter per-transform: +// the "VmapInterpreter" is responsible for reading in operators (like aten::mv) +// and executing the batched version of it (the batching rule for aten::mv). +// +// Concretely, each interpreter is responsible for two things: +// +// 1) process(ophandle, stack) +// Given an operator handle and a stack of arguments, the interpreter is +// responsible for figuring out how to execute the operation under the semantics +// of the interpreter. For e.g. 
VmapInterpreter, this is figuring out how to call +// the batching rule. +// +// The batching rules are stored as kernels on the FuncTorchBatched key, so the way +// VmapInterpreter calls the batching rule is roughly: (A) exclude all +// dispatch keys aside from the Batched key, (B) redispatch so we get to the +// Batched key. +// +// 2) sendToNextInterpreter(ophandle, stack) +// The VmapInterpreter, when it sees aten::mv, will process it into a call to +// aten::mm. It then needs to send the call to aten::mm to the next interpreter +// in the interpreter stack. +// +// The VmapInterpreter just does this via a call to ophandle.callBoxed(stack) +// and most Interpreters will implement it this way. + +enum class RandomnessType { + Error, // always errors when calling a random function + Same, // randomness appears the same across batches + Different, // randomness appears different across batches + END +}; + +enum class TransformType { + Torch, // Unused + Vmap, + Grad, // reverse-mode AD, aka vjp + Jvp, // forward-mode AD + Functionalize, +}; + +std::ostream& operator<<(std::ostream& os, const TransformType& t); + +// NOTE: [Interpreter "subclassing" design] +// +// How are various Interpreters for different transforms (vmap, grad, ...) +// implemented? +// +// Accessing interpreters is in the hot-path of functorch so we have a constraint +// that this code must be as fast as possible. +// +// As a result, we stay away from virtual methods and this causes our code +// to look a little funny. +// +// `Interpreter` is the struct for Interpreters. It holds ALL of the +// relevant information (what type of interpreter it is and the metadata). +// Metadata for each interpreter is represented as a Union (std::variant) +// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...). +// +// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this +// if you want to access the metadata fields (like batchSize and randomness). 
+// +// Each type of interpreter (e.g. Vmap) has a convenience struct +// (e.g. VmapInterpreterPtr) associated with it. +// +// Construct the convenience struct with VmapInterpreterPtr(Interpreter*), +// and then one can access methods on VmapInterpreterPtr like so: +// >>> VmapInterpreterPtr(&interpreter).batchSize() +// +// Finally, Interpreter::process switches on the type of the interpreter +// and calls one of {Transform}Intepreter::processImpl under the hood. +// Same for Interpreter::sendToNextInterpreter :) + +struct VmapInterpreterMeta { + explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) : + batchSize_(std::move(batchSize)), randomness_(randomness) {} + + c10::SymInt batchSize_; + RandomnessType randomness_; + + VmapInterpreterMeta() = default; + VmapInterpreterMeta(const VmapInterpreterMeta&) = default; + VmapInterpreterMeta(VmapInterpreterMeta&&) = default; + VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default; + VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default; + ~VmapInterpreterMeta() = default; + + template + friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) { + if (json_t.batchSize_.is_heap_allocated()) { + throw std::runtime_error("Serialization for heap-allocated SymInt is not implemented yet"); + } + json_j["batchSize"] = json_t.batchSize_.as_int_unchecked(); + json_j["randomness"] = static_cast(json_t.randomness_); + } + + template + friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) { + json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]); + json_t.randomness_ = static_cast(json_j["randomness"]); + } +}; + +struct GradInterpreterMeta { + explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {} + GradInterpreterMeta() = default; + GradInterpreterMeta(const GradInterpreterMeta&) = default; + GradInterpreterMeta(GradInterpreterMeta&&) = default; + GradInterpreterMeta& operator=(const 
GradInterpreterMeta&) = default; + GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default; + ~GradInterpreterMeta() = default; + + bool prevGradMode_; + template + friend void to_json(T& json_j, const GradInterpreterMeta& json_t) { + json_j["prevGradMode"] = json_t.prevGradMode_; + } + + template + friend void from_json(const T& json_j, GradInterpreterMeta& json_t) { + json_t.prevGradMode_ = json_j["prevGradMode"]; + } +}; + +struct JvpInterpreterMeta { + explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {} + JvpInterpreterMeta() = default; + JvpInterpreterMeta(const JvpInterpreterMeta&) = default; + JvpInterpreterMeta(JvpInterpreterMeta&&) = default; + JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default; + JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default; + ~JvpInterpreterMeta() = default; + + bool prevFwdGradMode_; + template + friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) { + json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_; + } + + template + friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) { + json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"]; + } +}; + +struct FunctionalizeInterpreterMeta { + explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) : + functionalizeAddBackViews_(functionalizeAddBackViews) {} + FunctionalizeInterpreterMeta() = default; + FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default; + FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default; + FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default; + FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default; + ~FunctionalizeInterpreterMeta() = default; + + bool functionalizeAddBackViews_; + template + friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) { + json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_; + } + + template + 
friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) { + json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"]; + } +}; + +typedef std::variant< + int64_t, + GradInterpreterMeta, + JvpInterpreterMeta, + VmapInterpreterMeta, + FunctionalizeInterpreterMeta +> InterpreterMeta; + + +struct Interpreter { + // factory functions + static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) { + return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness)); + } + static Interpreter Grad(int64_t level, bool prevGradMode) { + return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode)); + } + static Interpreter Jvp(int64_t level, bool prevFwdGradMode) { + return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode)); + } + static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) { + return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews)); + } + + // methods + TransformType key() const { return type_; } + int64_t level() const { return level_; } + const InterpreterMeta& meta() const { return meta_; } + + void process(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + + void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) { + TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value()); + savedLocalDispatchKeySet_ = keyset; + } + void clearSavedLocalDispatchKeySet() { + TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); + savedLocalDispatchKeySet_ = std::nullopt; + } + c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const { + TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); + return *savedLocalDispatchKeySet_; + } + + // An Interpreter is alive if we are currently inside the 
ongoing transform + // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's + // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack. + bool is_alive() const { + return *is_alive_; + } + const std::shared_ptr& is_alive_ptr() const { + return is_alive_; + } + void set_is_alive(bool alive) { + *is_alive_ = alive; + } + + // Please don't use this + explicit Interpreter() = default; + + template + friend void to_json(T& json_j, const Interpreter& json_t) { + json_j["type"] = static_cast(json_t.type_); + json_j["level"] = json_t.level_; + if (json_t.savedLocalDispatchKeySet_) { + json_j["savedLocalDispatchKeySet"] = { + {"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()}, + {"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()} + }; + } else { + json_j["savedLocalDispatchKeySet"] = nlohmann::json(); + } + json_j["is_alive"] = *json_t.is_alive_; + std::visit([&](auto&& arg) { + using V = std::decay_t; + if constexpr (std::is_same_v) { + json_j["meta"] = {{"Torch", arg}}; + } else if constexpr (std::is_same_v) { + json_j["meta"] = {{"Grad", arg}}; + } else if constexpr (std::is_same_v) { + json_j["meta"] = {{"Jvp", arg}}; + } else if constexpr (std::is_same_v) { + json_j["meta"] = {{"Vmap", arg}}; + } else if constexpr (std::is_same_v) { + json_j["meta"] = {{"Functionalize", arg}}; + } else { + static_assert(false && sizeof(V), "unknown variant case"); + } + }, json_t.meta_); + } + + template + friend void from_json(const T& json_j, Interpreter& json_t) { + json_t.type_ = static_cast(json_j["type"]); + json_t.level_ = json_j["level"]; + auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"]; + if (savedLocalDispatchKeySet.is_null()) { + json_t.savedLocalDispatchKeySet_ = std::nullopt; + } else { + c10::impl::PODLocalDispatchKeySet pod; + pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get())); + 
pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get())); + json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod); + } + json_t.is_alive_ = std::make_shared(json_j["is_alive"]); + auto meta = json_j["meta"]; + if (meta.contains("Torch")) { + json_t.meta_.emplace(meta["Torch"].template get()); + } else if (meta.contains("Grad")) { + json_t.meta_.emplace(meta["Grad"].template get()); + } else if (meta.contains("Jvp")) { + json_t.meta_.emplace(meta["Jvp"].template get()); + } else if (meta.contains("Vmap")) { + json_t.meta_.emplace(meta["Vmap"].template get()); + } else if (meta.contains("Functionalize")) { + json_t.meta_.emplace(meta["Functionalize"].template get()); + } else { + throw std::runtime_error("unknown interpreter metadata type"); + } + } + + std::string serialize() const { + return nlohmann::json(*this).dump(); + } + + static Interpreter deserialize(const std::string& serialized) { + return nlohmann::json::parse(serialized).get(); + } + + private: + explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta): + type_(type), level_(level), is_alive_(std::make_shared(false)), meta_(std::move(meta)) {} + + // fields + TransformType type_{}; + int64_t level_{}; + std::optional savedLocalDispatchKeySet_; + std::shared_ptr is_alive_; + InterpreterMeta meta_; +}; + +// Applies the following for-loop: +// for i in range(begin, end): +// args[i] = func(args[i]) +void foreachTensorInplace(std::vector& args, int64_t begin, int64_t end, + std::function func); + +// Applies the following for-loop: +// for i in range(begin, end): +// if use_flag_relative[i] == 1: <-- treats use_flag_relative as a bitset +// args[i] = func(args[i], i - begin, true) +// args[i] = func(args[i], i - begin) +void foreachTensorInplaceWithFlag(std::vector& args, int64_t begin, int64_t end, + const std::bitset<64> use_flag_relative, const std::function& func); + +std::vector findUnwrappedInputs(std::vector& args, 
int64_t begin, int64_t end); + +DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key); + +void setup_dispatch_key_tls(TransformType key, DispatchKeySet include); + +void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h new file mode 100644 index 0000000000000000000000000000000000000000..390989d45bf73fa6080d5ab4aa69963b10970759 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h @@ -0,0 +1,187 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace at::functorch { + +// This file contains the legacy (now-deprecated) batching rule API. +// Please try to use the new-style batching rule API (see writing_batch_rules.md) + +// This file contains abstractions used for transforming *logical* vmap arguments +// into *physical* arguments. (Keep reading for definitions of these terms). + +// NOTE: [Logical vs physical args] +// Consider the following vmap. +// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4)) +// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4], +// with batch dims 0 and 2: +// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)]) +// +// We say the *logical* view of the tensor has size [3] -- tensors inside +// `func` appear to have size [3]. +// However, the *physical* underlying tensor (the one passed to vmap) has size +// [2, 3, 4]. +// +// This notion of logical vs physical also extends to non-tensor arguments. 
+// Consider the previous tensor; let's assume the user called +// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical +// dimension they are reducing over is dim 0 but the physical dim is dim 1 +// (the first non-batch dimension) + +// Forward declared; see NOTE: [What is a VmapPhysicalView?] +struct VmapPhysicalView; + +// Most PyTorch operators take 4 or fewer inputs. +constexpr int64_t kVmapTransformStaticInputSize = 4; +using VmapPhysicalViewVec = SmallVector; + +// Pytorch generally advertises good performance for <= 5 dims. +// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap +// dimensions to get 8. Adjust this number as necessary +constexpr int64_t kVmapStaticDimVecSize = 8; +using VmapDimVector = SmallVector; +using VmapSymDimVector = SmallVector; + +// NOTE: [What is an VmapTransform?] +// An *VmapTransform* converts logical views of tensors to physical views. +// +// Batching rules use VmapTransforms to convert logical arguments to +// physical arguments, then call one or more at:: operator that handles the +// physical arguments, and then converts the physical result back to a logical +// argument. + +// VmapTransform for operators that take tensors with multiple batch dims. +// Given one or more logical views on Tensors, `logicalToPhysical` +// permutes all of the batch dims to the front of the tensor, aligns +// and expands the batch dims to match each other (according to their `level`), +// and returns a VmapPhysicalView on the tensor(s). +struct TORCH_API MultiBatchVmapTransform { + static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor); + static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors); +}; + +// VmapTransform for operators that broadcast all inputs. +// Given some logical views on Tensors, `logicalToPhysical`: +// - permutes all of the batch dims to the front of the tensors +// - aligns all the batch dims to the collective levels of all of the tensors. 
+// If a tensor does not have a batch dim for a vmap level, then it receives +// a size-one dimension for said level. +// - aligns the non-batch dims to have the same dimensionality, adding extra +// size-1 dimensions in between the batch dimensions and the non-batch dimensions +// so that the batch dimensions are lined up from the right. +// +// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch +// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors +// of size (B, 1, 2) and (B, 3, 2). +// +// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns +// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't +// actually *need* to return a tensor of size (1, 2) for the second tensor +// because the broadcasting operation takes care of that for us, but we do +// it anyways to keep things simple. +struct TORCH_API BroadcastingVmapTransform { + static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors); +}; + +// Forward declared, if you're reading this file head to toe, don't worry about +// it yet. +struct VmapPhysicalToLogicalMap; + +// NOTE: [What is a VmapPhysicalView?] +// VmapPhysicalView represents a physical view on a Tensor. +// +// One can use it to further convert logical dimension indices, logical shapes, +// and more to their physical variants, or convert a new (physical) tensor into +// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented). +// +// VmapPhysicalView stores a physical tensor with all of its batch dimensions at +// the front and some levels that correspond to said batch dimensions. +// +// The levels bitset specifies which vmap levels correspond to the batch +// dimensions at the front of the tensor. In particular, the number of set bits +// corresponds to the number of batch dimensions on `tensor` and the rightmost +// bit of `levels` specifies the maximum number of nested vmaps we are in at +// this point in time. 
+// For example, given: +// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3}) +// +// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less +// than or equal to 3. +// bitset: 010100 +// ^ +// | +// levels: 012345 +struct TORCH_API VmapPhysicalView { + VmapPhysicalView(Tensor&& tensor, std::bitset levels) + : levels_(levels), tensor_(std::move(tensor)) { + // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor)); + } + + Tensor& tensor() { return tensor_; } + const Tensor& tensor() const { return tensor_; } + + // Maps logical dim indices to physical dim indices. Also does dim wrapping. + // + // For example, given: + // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3}) + // + // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}. + // This is because the size of levels tell us that the first two dimensions + // of `tensor_` are batch dimensions, so a logical dim of `n` is actually + // a physical dim of `n + 2`. + VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const; + int64_t getPhysicalDim(int64_t logical_dim) const; + + // Returns a VmapPhysicalToLogicalMap object. This can be used for + // mapping a physical tensor to a new logical tensor (BatchedTensor) + VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; + + // Maps a logical shape to a physical shape by pre-pending the batch + // sizes to the logical shape. + VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; + SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const; + + int64_t numBatchDims() const; + + private: + int64_t numLogicalDims() const; + + std::bitset levels_; + Tensor tensor_; +}; + +// Convenience struct used for mapping a physical tensor (a non-BatchedTensor) +// to a logical one (BatchedTensor). It holds some levels that are used to do the +// mapping and assumes that the batch dimensions in the physical tensor all +// occur at the front of the tensor. 
+struct TORCH_API VmapPhysicalToLogicalMap { + VmapPhysicalToLogicalMap(std::bitset levels): levels_(levels) {} + + // Maps a physical tensor to a new logical tensor (BatchedTensor). + // Assumes that all of the "batch dimensions" are at the front + // of the physical tensor. For example, given: + // - x = rank-4 Tensor with size 2, 3, 5, 7 + // - levels = (2, 4) + // Returns: + // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)]) + Tensor apply(const Tensor& physical_tensor) const; + + // Given a vector of physical tensors, + // 1. maps each tensor to a new logical tensor. Assumes that all of the + // "batch dimensions" are at the front of the physical tensors. + // 2. stores the new logical tensors back into the passed-in vector. This is + // to avoid additional dynamic allocations. + void applyInplace(std::vector& physical_tensors) const; + + std::bitset levels_; +}; + + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Macros.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Macros.h new file mode 100644 index 0000000000000000000000000000000000000000..eb0a763261bf051a814c2bfc128f4edd07732bdf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Macros.h @@ -0,0 +1,3 @@ +#pragma once + +#define SINGLE_ARG(...) __VA_ARGS__ diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/PlumbingHelper.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/PlumbingHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..9caa52e9652c88762fd66ef0f19cd96a6061927f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/PlumbingHelper.h @@ -0,0 +1,63 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+#pragma once +#include +#include +#include + +// NOTE: [vmap plumbing] +// +// Here's how "batching rules" work. +// - we register kernels to the Batched key +// - these kernels have the same signatures as the original operators. +// For example, at::sin(Tensor self) accepts a Tensor, and the batched kernel +// must also accept a Tensor +// - However, it is more natural for users to write a batching rule like the +// following: sin_batch_rule(Tensor self, std::optional self_bdim) +// - There is some codegenerated layer (the "plumbing") that wraps the user +// defined batching rule (e.g. sin_batch_rule) in a kernel that can be +// registered to the Batched key. +// +// The plumbing is responsible for wrapping a batching rule into a form that may +// be registered as the kernel for the batched key. + +namespace at::functorch { + +void vmap_check_escaped(const std::optional &layer, const char* what); + +// Create a BatchedTensor given a tensor, bdim, and level +TORCH_API Tensor makeBatched(Tensor tensor, std::optional bdim, int64_t level); + +// Given a Tensor that may or may not be a BatchedTensor, unwrap it. +// If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level +// doesn't match, then this returns (tensor, std::nullopt). +// Otherwise, it returns (unwrap(tensor), bdim). +TORCH_API std::tuple> unwrapTensorAtLevel(const Tensor& tensor, int64_t level); + +// Creates a vector of BatchedTensor +TORCH_API std::vector makeBatchedVector(std::vector tensors, std::optional bdim, int64_t level); + +// Returns True if ANY tensor in tensors is batched at level +TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level); +TORCH_API bool isBatchedAtLevel(const c10::List>& maybe_tensors, int64_t level); +TORCH_API bool isBatchedAtLevel(const Tensor& tensor, int64_t level); +TORCH_API bool isBatchedAtLevel(const std::optional& maybe_tensor, int64_t level); + +// Convenience helper. 
Returns true if any tensor is batched at level +TORCH_API bool areAnyBatchedAtLevel(ArrayRef> maybe_tensors, int64_t level); + +inline bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) { + if (ivalue.isTensor()) { + auto maybe_level = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_level.has_value()); + auto current_level = maybe_level->layerId(); + return isBatchedAtLevel(ivalue.toTensor(), current_level); + } + // TODO: should really check this + return false; +} + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/TensorWrapper.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/TensorWrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..bf7b14fd4168997af0329b2d458775314212f70a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/TensorWrapper.h @@ -0,0 +1,103 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +namespace at::functorch { + +// NOTE: [functorch's TensorWrapper] +// +// Taking better suggestions for a name. TensorWrapper is the wrapper Tensor +// Subclass for functorch's grad-based transforms (grad, vjp, jvp). It is +// analogous to how vmap uses BatchedTensor as the wrapper Tensor subclass. +// +// If you're familiar with the Tensor-Variable merge, TensorWrapper is effectively +// another Variable. +// +// Consider grad(grad(torch.sin))(x). This wraps `x` as TensorWrapper(TensorWrapper(x)). +// The reason why is so that each TensorWrapper can hold its own AutogradMeta and +// participate in a **separate** autograd graph. +// +// There are alternative designs we could have chosen (e.g. 
each grad transform +// stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper +// design is that we can re-use existing VariableType kernels (i.e. Autograd kernels) +// without much modification. Since a TensorWrapper looks like a regular Tensor, +// the VariableType kernel can pull out the AutogradMeta struct from where it +// expects and extend the autograd graph + +struct TORCH_API TensorWrapper : public c10::TensorImpl { + explicit TensorWrapper( + c10::DispatchKeySet key_set, + Tensor value, + int64_t level, + std::shared_ptr is_alive, + bool is_immutable = false, // if true, this came from an operation that aliases an immutable tensor + bool use_value_sizes_strides = true); + + void refreshMetadata(); + + const Tensor& value() const { + return value_; + } + std::optional level() const { + if (is_alive()) { + return level_; + } + return {}; + } + bool is_immutable() const { + return is_immutable_; + } + bool is_alive() const; + + // Overrides necessary for autograd + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; + + private: + const char* tensorimpl_type_name() const override; + Tensor value_; + int64_t level_; + bool is_immutable_; + + // TensorWrapper receives a boolean flag on whether or not the Grad Interpreter + // that created it is still alive or not. + // If the Grad Interpreter is no longer alive then it attempts to behave like + // a regular Tensor. + // + // When we exit the level, this wrapper may be marked as "not alive". 
+ // Wrappers that are not alive: + // 1) May still have autograd metadata on them + // 2) Forward dispatches to the underlying value() + std::shared_ptr is_alive_; +}; + +// There are two variants of makeTensorWrapper: one that accepts a level +// and one that accepts an Interpreter. +// +// The one that accepts a level tries to automatically get the life handle from the +// interpreter on the DynamicLayerStack. +// It needs to be used with caution: if the interpreter is not on the +// DynamicLayerStack, then we won't be able to find the life handle. +// +// In practice this isn't a problem: when we're constructing TensorWrapper in +// Python, the corresponding interpreter is on the stack. +TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false); +TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false); +TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor); +TORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor); +TORCH_API void dumpTensorCout(const Tensor& tensor); + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/VmapInterpreter.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/VmapInterpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..917c671e3ddc95836660f5f0be6365dcdde7e0b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/VmapInterpreter.h @@ -0,0 +1,25 @@ +#pragma once +#include + +namespace at::functorch { + +// This is the interpreter that handles the vmap() transform. +// See NOTE: [functorch interpreter stack] for more details. 
+ +struct VmapInterpreterPtr { + explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); } + TransformType key() const { return base_->key(); } + int64_t level() const { return base_->level(); } + void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); + void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); + c10::SymInt batchSize() const { + return std::get(base_->meta()).batchSize_; + } + RandomnessType randomness() const { + return std::get(base_->meta()).randomness_; + } + private: + const Interpreter* base_; +}; + +} // namespace at::functorch diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h new file mode 100644 index 0000000000000000000000000000000000000000..39ab441478e8f54cd0a8f5f61e9575ce6ca0dc24 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! +namespace c10::hip { + +// Takes a valid HIPAllocator (of any sort) and turns it into +// an allocator pretending to be a CUDA allocator. 
See +// Note [Masquerading as CUDA] +class HIPAllocatorMasqueradingAsCUDA final : public Allocator { + Allocator* allocator_; +public: + explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator) + : allocator_(allocator) {} + DataPtr allocate(size_t size) override { + DataPtr r = allocator_->allocate(size); + r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index())); + return r; + } + DeleterFnPtr raw_deleter() const override { + return allocator_->raw_deleter(); + } + void copy_data(void* dest, const void* src, std::size_t count) const final { + allocator_->copy_data(dest, src, count); + } +}; + +} // namespace c10::hip diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h new file mode 100644 index 0000000000000000000000000000000000000000..3aaa9d06c5e91f562382ecd56ea1d3c8a25d41af --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include +#include + +namespace c10 { +// forward declaration +class DataPtr; +namespace hip { +namespace HIPCachingAllocatorMasqueradingAsCUDA { + +C10_HIP_API Allocator* get(); +C10_HIP_API void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream); + +} // namespace HIPCachingAllocatorMasqueradingAsCUDA +} // namespace hip +} // namespace c10 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h new file mode 100644 index 0000000000000000000000000000000000000000..93b998a8f7fd6efcca58e255fa47f76a0e1e652b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -0,0 +1,383 @@ +#pragma once + 
+#include + +// The includes of HIPGuard.h +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! +namespace c10 { namespace hip { + +// Note [Masquerading as CUDA] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// c10_hip is very easy to understand: it is HIPified from c10_cuda, +// and anywhere you said CUDA, the source code now says HIP. HIPified +// PyTorch is much harder to understand: it is HIPified from regular +// PyTorch, yes, but NO source-to-source translation from CUDA to +// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP". +// For example, when you use HIPified PyTorch, you say x.cuda() to +// move a tensor onto ROCm device. We call this situation "HIP +// masquerading as CUDA". +// +// This leads to a very awkward situation when we want to call c10_hip +// code from PyTorch, since c10_hip is expecting things to be called +// HIP, but PyTorch is calling them CUDA (masquerading as HIP). To +// fix this impedance mismatch, we have MasqueradingAsCUDA variants +// for all c10_hip classes. These translate between the "HIP" and "CUDA +// masquerading as HIP" worlds. For example, +// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a +// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type() +// returns CUDA, getDevice() reports the current HIP device as a CUDA +// device.) +// +// We should be able to delete all of these classes entirely once +// we switch PyTorch to calling a HIP a HIP. +// +// When you add a new MasqueradingAsCUDA class/function, you need to +// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py +// +// +// +// By the way, note that the cpp file associated with this also +// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with +// this HIP implementation. 
+ +struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface { + static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA; + HIPGuardImplMasqueradingAsCUDA() {} + HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) { + TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA); + } + c10::DeviceType type() const override { + return c10::DeviceType::CUDA; + } + Device exchangeDevice(Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_cuda()); + Device old_device = getDevice(); + if (old_device.index() != d.index()) { + C10_HIP_CHECK(hipSetDevice(d.index())); + } + return old_device; + } + Device getDevice() const override { + int device; + C10_HIP_CHECK(hipGetDevice(&device)); + return Device(c10::DeviceType::CUDA, device); + } + void setDevice(Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_cuda()); + C10_HIP_CHECK(hipSetDevice(d.index())); + } + void uncheckedSetDevice(Device d) const noexcept override { + C10_HIP_CHECK_WARN(hipSetDevice(d.index())); + } + Stream getStream(Device d) const override { + return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap(); + } + Stream getDefaultStream(Device d) const override { + return getDefaultHIPStreamMasqueradingAsCUDA(d.index()); + } + Stream getNewStream(Device d, int priority = 0) const override { + return getStreamFromPoolMasqueradingAsCUDA(priority, d.index()); + } + Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override { + return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index()); + } + Stream exchangeStream(Stream s) const override { + HIPStreamMasqueradingAsCUDA cs(s); + auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index()); + setCurrentHIPStreamMasqueradingAsCUDA(cs); + return old_stream.unwrap(); + } + DeviceIndex deviceCount() const noexcept override { + int deviceCnt; + hipError_t _err; + _err = hipGetDeviceCount(&deviceCnt); + if(_err != hipErrorNoDevice && _err != hipSuccess) + 
C10_HIP_CHECK(_err); + return deviceCnt; + } + + // Event-related functions + // Note: hipEventCreateWithFlags should be called on the same device as + // the recording stream's device. + void createEvent( + hipEvent_t* hip_event, + const EventFlag flag) const { + // Maps PyTorch's Event::Flag to HIP flag + auto hip_flag = hipEventDefault; + switch (flag) { + case EventFlag::PYTORCH_DEFAULT: + hip_flag = hipEventDisableTiming; + break; + case EventFlag::BACKEND_DEFAULT: + hip_flag = hipEventDefault; + break; + default: + TORCH_CHECK(false, "HIP event received unknown flag"); + } + + C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag)); + } + + void destroyEvent( + void* event, + const DeviceIndex device_index) const noexcept override { + if (!event) return; + auto hip_event = static_cast(event); + int orig_device; + C10_HIP_CHECK_WARN(hipGetDevice(&orig_device)); + C10_HIP_CHECK_WARN(hipSetDevice(device_index)); + C10_HIP_CHECK_WARN(hipEventDestroy(hip_event)); + C10_HIP_CHECK_WARN(hipSetDevice(orig_device)); + } + + void record(void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + TORCH_CHECK(device_index == -1 || device_index == stream.device_index(), + "Event device index ", + device_index, + " does not match recording stream's device index ", + stream.device_index(), + "."); + + hipEvent_t hip_event = static_cast(*event); + HIPStreamMasqueradingAsCUDA hip_stream{stream}; + + // Moves to stream's device to record + const auto orig_device = getDevice(); + setDevice(stream.device()); + + // Creates the event (lazily) + if (!hip_event) createEvent(&hip_event, flag); + C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream)); + // Makes the void* point to the (possibly just allocated) HIP event + *event = hip_event; + + // Resets device + setDevice(orig_device); + } + + void block( + void* event, + const Stream& stream) const override { + if (!event) return; + hipEvent_t hip_event = static_cast(event); 
+ HIPStreamMasqueradingAsCUDA hip_stream{stream}; + const auto orig_device = getDevice(); + setDevice(stream.device()); + C10_HIP_CHECK(hipStreamWaitEvent( + hip_stream, + hip_event, + /*flags (must be zero)=*/ 0)); + setDevice(orig_device); + } + + bool queryEvent(void* event) const override { + if (!event) return true; + hipEvent_t hip_event = static_cast(event); + const hipError_t err = hipEventQuery(hip_event); + if (err != hipErrorNotReady) C10_HIP_CHECK(err); + else { + // ignore and clear the error if not ready + (void)hipGetLastError(); + } + return (err == hipSuccess); + } + + // Stream-related functions + bool queryStream(const Stream& stream) const override { + HIPStreamMasqueradingAsCUDA hip_stream{stream}; + return hip_stream.query(); + } + + void synchronizeStream(const Stream& stream) const override { + HIPStreamMasqueradingAsCUDA hip_stream{stream}; + hip_stream.synchronize(); + } + + void synchronizeEvent(void* event) const override { + if (!event) + return; + hipEvent_t hip_event = static_cast(event); + C10_HIP_CHECK(hipEventSynchronize(hip_event)); + } + + // Note: synchronizeDevice can be safely called from any device + void synchronizeDevice(const c10::DeviceIndex device_index) const override { + int orig_device{-1}; + C10_HIP_CHECK(hipGetDevice(&orig_device)); + C10_HIP_CHECK(hipSetDevice(device_index)); + C10_HIP_CHECK(hipDeviceSynchronize()); + C10_HIP_CHECK(hipSetDevice(orig_device)); + } + + void recordDataPtrOnStream( + const c10::DataPtr& data_ptr, + const Stream& stream) const override { + HIPStreamMasqueradingAsCUDA hip_stream{stream}; + HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream); + } + + double elapsedTime(void* event1, void* event2, const DeviceIndex device_index) + const override { + TORCH_CHECK( + event1 && event2, + "Both events must be recorded before calculating elapsed time."); + int orig_device; + C10_HIP_CHECK(hipGetDevice(&orig_device)); + 
C10_HIP_CHECK(hipSetDevice(device_index)); + hipEvent_t hip_event1 = static_cast(event1); + hipEvent_t hip_event2 = static_cast(event2); + float time_ms = 0; + // raise hipErrorNotReady if either event is recorded but not yet completed + C10_HIP_CHECK(hipEventElapsedTime(&time_ms, hip_event1, hip_event2)); + C10_HIP_CHECK(hipSetDevice(orig_device)); + return static_cast(time_ms); + } +}; + +// All of the guards which have HIPGuardImpl burned in need to also have +// variants using HIPGuardImplMasqueradingAsCUDA. + +/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with +/// the correct InlineDeviceGuard burned in. Sorry about the +/// copy-pasting. + +struct HIPGuardMasqueradingAsCUDA { + explicit HIPGuardMasqueradingAsCUDA() = delete; + explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {} + explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {} + + HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete; + HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete; + HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete; + HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete; + + void set_device(Device device) { guard_.set_device(device); } + void reset_device(Device device) { guard_.reset_device(device); } + void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } + Device original_device() const { return guard_.original_device(); } + Device current_device() const { return guard_.current_device(); } + + private: + c10::impl::InlineDeviceGuard guard_; +}; + +struct OptionalHIPGuardMasqueradingAsCUDA { + explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {} + explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional device_opt) : guard_(device_opt) {} + explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional device_index_opt) : guard_(device_index_opt) {} + + 
OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete; + OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete; + OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete; + OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete; + + void set_device(Device device) { guard_.set_device(device); } + void reset_device(Device device) { guard_.reset_device(device); } + void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } + std::optional original_device() const { return guard_.original_device(); } + std::optional current_device() const { return guard_.current_device(); } + void reset() { guard_.reset(); } + +private: + c10::impl::InlineOptionalDeviceGuard guard_; +}; + +struct HIPStreamGuardMasqueradingAsCUDA { + explicit HIPStreamGuardMasqueradingAsCUDA() = delete; + explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {} + HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete; + HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete; + HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete; + HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete; + + void reset_stream(Stream stream) { guard_.reset_stream(stream); } + + HIPStreamMasqueradingAsCUDA original_stream() const { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream()); + } + HIPStreamMasqueradingAsCUDA current_stream() const { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream()); + } + + Device current_device() const { return guard_.current_device(); } + Device original_device() const { return guard_.original_device(); } + +private: + c10::impl::InlineStreamGuard guard_; +}; + +struct 
OptionalHIPStreamGuardMasqueradingAsCUDA { + explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {} + explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {} + explicit OptionalHIPStreamGuardMasqueradingAsCUDA(std::optional stream_opt) : guard_(stream_opt) {} + + OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete; + OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete; + OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete; + OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete; + + void reset_stream(Stream stream) { guard_.reset_stream(stream); } + + std::optional original_stream() const { + auto r = guard_.original_stream(); + if (r.has_value()) { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()); + } else { + return std::nullopt; + } + } + + std::optional current_stream() const { + auto r = guard_.current_stream(); + if (r.has_value()) { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()); + } else { + return std::nullopt; + } + } + + void reset() { guard_.reset(); } + +private: + c10::impl::InlineOptionalStreamGuard guard_; +}; + +struct HIPMultiStreamGuardMasqueradingAsCUDA { + explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef streams) + : guard_(unwrapStreams(streams)) {} + + HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete; + HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete; + HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete; + HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete; + +private: + c10::impl::InlineMultiStreamGuard guard_; + + 
static std::vector unwrapStreams(ArrayRef hipStreams) { + std::vector streams; + streams.reserve(hipStreams.size()); + for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) { + streams.push_back(hipStream); + } + return streams; + } +}; + +}} // namespace c10::hip diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h new file mode 100644 index 0000000000000000000000000000000000000000..fb13ada5ad88e18a02225add10f54eef31ca83f3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h @@ -0,0 +1,135 @@ +#pragma once + +#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! +namespace c10 { namespace hip { + +// See Note [Masquerading as CUDA] for motivation + +class HIPStreamMasqueradingAsCUDA { +public: + + enum Unchecked { UNCHECKED }; + + explicit HIPStreamMasqueradingAsCUDA(Stream stream) + : HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) { + // We did the coercion unchecked; check that it was right. + TORCH_CHECK(stream.device().is_cuda() /* !!! */); + } + + explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream) + // Unsafely coerce the "CUDA" stream into a HIP stream + : stream_( + HIPStream( + Stream( + Stream::UNSAFE, + Device(c10::DeviceType::HIP, stream.device_index()), + stream.id()) + ) + ) {} + + // New constructor, just for this. Does NOT coerce. 
+ explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {} + + bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept { + return stream_ == other.stream_; + } + + bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept { + return stream_ != other.stream_; + } + + operator hipStream_t() const { return stream_.stream(); } + + operator Stream() const { + // Unsafely coerce HIP stream into a "CUDA" stream + return Stream(Stream::UNSAFE, device(), id()); + } + + DeviceIndex device_index() const { return stream_.device_index(); } + + // Unsafely coerce HIP device into CUDA device + c10::DeviceType device_type() const { return c10::DeviceType::CUDA; } + + Device device() const { + // Unsafely coerce HIP device into CUDA device + return Device(c10::DeviceType::CUDA, stream_.device_index()); + } + + StreamId id() const { return stream_.id(); } + bool query() const { return stream_.query(); } + void synchronize() const { stream_.synchronize(); } + int priority() const { return stream_.priority(); } + hipStream_t stream() const { return stream_.stream(); } + + Stream unwrap() const { + // Unsafely coerce HIP stream into "CUDA" stream + return Stream(Stream::UNSAFE, device(), id()); + } + + c10::StreamData3 pack3() const noexcept { + // Unsafely coerce HIP stream into "CUDA" stream before packing + return unwrap().pack3(); + } + + static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id, + DeviceIndex device_index, + c10::DeviceType device_type) { + // NB: constructor manages CUDA->HIP translation for us + return HIPStreamMasqueradingAsCUDA(Stream::unpack3( + stream_id, device_index, device_type)); + } + + static std::tuple priority_range() { return HIPStream::priority_range(); } + + // New method, gets the underlying HIPStream + HIPStream hip_stream() const { return stream_; } + +private: + HIPStream stream_; +}; + +HIPStreamMasqueradingAsCUDA +inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = 
false, DeviceIndex device = -1) { + return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device)); +} + +HIPStreamMasqueradingAsCUDA +inline getStreamFromPoolMasqueradingAsCUDA(const int priority, DeviceIndex device = -1) { + return HIPStreamMasqueradingAsCUDA(getStreamFromPool(priority, device)); +} + +HIPStreamMasqueradingAsCUDA +inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) { + return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device)); +} + +inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) { + return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index)); +} + +inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) { + return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index)); +} + +inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) { + setCurrentHIPStream(stream.hip_stream()); +} + +inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) { + stream << s.hip_stream() << " (masquerading as CUDA)"; + return stream; +} + +}} // namespace c10::hip + +namespace std { + template <> + struct hash { + size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept { + return std::hash{}(s.unwrap()); + } + }; +} // namespace std diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Descriptors.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Descriptors.h new file mode 100644 index 0000000000000000000000000000000000000000..a0ad4a4e1098a04b85f41b3b5e01c86f2f2007ee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Descriptors.h @@ -0,0 +1,169 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace at { namespace native { + +inline int dataSize(miopenDataType_t dataType) +{ + switch (dataType) { + 
case miopenHalf: return 2; + case miopenFloat: return 4; + case miopenBFloat16: return 2; + default: return 8; + } +} + +template +struct DescriptorDeleter { + void operator()(T* x) { + if (x != nullptr) { + MIOPEN_CHECK(dtor(x)); + } + } +}; + +// A generic class for wrapping MIOpen descriptor types. All you need +// is to give the underlying type the Descriptor_t points to (usually, +// if it's miopenTensorDescriptor_t it points to miopenTensorStruct), +// the constructor and the destructor. Subclasses are responsible +// for defining a set() function to actually set the descriptor. +// +// Descriptors default construct to a nullptr, and have a descriptor +// initialized the first time you call set() or any other initializing +// function. +template +// NOLINTNEXTLINE(bugprone-exception-escape) +class TORCH_CUDA_CPP_API Descriptor { + public: + // Use desc() to access the underlying descriptor pointer in + // a read-only fashion. Most client code should use this. + // If the descriptor was never initialized, this will return + // nullptr. + T* desc() const { return desc_.get(); } + T* desc() { return desc_.get(); } + + // Use mut_desc() to access the underlying descriptor pointer + // if you intend to modify what it points to (e.g., using + // miopenSetFooDescriptor). This will ensure that the descriptor + // is initialized. Code in this file will use this function. 
+ T* mut_desc() { init(); return desc_.get(); } +protected: + void init() { + if (desc_ == nullptr) { + T* raw_desc = nullptr; + MIOPEN_CHECK(ctor(&raw_desc)); + desc_.reset(raw_desc); + } + } +private: + std::unique_ptr> desc_; +}; + +class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { + public: + TensorDescriptor() = default; + explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) { + set(t, pad); + } + + void set(const at::Tensor &t, size_t pad = 0); + void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0); + + void print(); + +private: + void set(miopenDataType_t dataType, int dim, int* size, int* stride) { + MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride)); + } +}; + +std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d); + +class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { + public: + void set(const at::Tensor &t, int64_t pad = 0) { + set(t, at::MemoryFormat::Contiguous, pad); + } + + void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0); + +private: + void set(miopenDataType_t dataType, int dim, int* size, int* stride) { + MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride)); + } +}; + +struct TORCH_CUDA_CPP_API ConvolutionDescriptor + : public Descriptor< + miopenConvolutionDescriptor, + &miopenCreateConvolutionDescriptor, + &miopenDestroyConvolutionDescriptor> { + void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool benchmark, bool deterministic) { + MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode)); + 
MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups)); + MIOPEN_CHECK(miopenSetConvolutionAttribute(mut_desc(), MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC, deterministic ? 1 : 0)); + if (benchmark) { + MIOPEN_CHECK(miopenSetConvolutionFindMode(mut_desc(), miopenConvolutionFindModeNormal)); + } + } +}; + +// NOLINTNEXTLINE(bugprone-exception-escape) +struct TORCH_CUDA_CPP_API DropoutDescriptor + : public Descriptor< + miopenDropoutDescriptor, + &miopenCreateDropoutDescriptor, + &miopenDestroyDropoutDescriptor> { + void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } + + void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } +}; + +struct TORCH_CUDA_CPP_API RNNDescriptor + : public Descriptor +{ + void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode, + miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) { + MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype)); + } + + void setWithDropout(DropoutDescriptor& dropout_desc, int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, + miopenRNNMode_t rnn_mode, miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) { + MIOPEN_CHECK(miopenSetRNNDescriptor_V2(mut_desc(), hidden_size, num_layers, 
dropout_desc.mut_desc(), input_mode, direction, rnn_mode, bias_mode, algorithm, datatype)); + } +}; + +union Constant +{ + float f; + double d; + Constant(miopenDataType_t dataType, double value) { + if (dataType == miopenHalf || dataType == miopenFloat || dataType == miopenBFloat16) { + f = static_cast(value); + } else { + d = value; + } + } +}; + +}} // namespace diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Exceptions.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Exceptions.h new file mode 100644 index 0000000000000000000000000000000000000000..f5f0a4785b1cf2d43fcefb03c7039e3be5643636 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Exceptions.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { namespace native { + +class miopen_exception : public std::runtime_error { +public: + miopenStatus_t status; + miopen_exception(miopenStatus_t status, const char* msg) + : std::runtime_error(msg) + , status(status) {} + miopen_exception(miopenStatus_t status, const std::string& msg) + : std::runtime_error(msg) + , status(status) {} +}; + +inline void MIOPEN_CHECK(miopenStatus_t status) +{ + if (status != miopenStatusSuccess) { + if (status == miopenStatusNotImplemented) { + throw miopen_exception(status, std::string(miopenGetErrorString(status)) + + ". 
This error may appear if you passed in a non-contiguous input."); + } + throw miopen_exception(status, miopenGetErrorString(status)); + } +} + +inline void HIP_CHECK(hipError_t error) +{ + if (error != hipSuccess) { + std::string msg("HIP error: "); + msg += hipGetErrorString(error); + throw std::runtime_error(msg); + } +} + +}} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Handle.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Handle.h new file mode 100644 index 0000000000000000000000000000000000000000..4c80c3aea65bf6b1317c358d03176e2a1e7ba6db --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Handle.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle(); +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Types.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Types.h new file mode 100644 index 0000000000000000000000000000000000000000..0a8a1a952e2e2812647247108b740373b4be0245 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Types.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor); + +int64_t miopen_version(); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a0ec83d976bc89b91a23f9c54bcd0a03bae729a4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Utils.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include +#include + +namespace at { namespace native { + +// This function makes tensors which have zero stride contiguous, by +// setting the 
strides to 1. +inline Tensor contiguousIfZeroInStrides(const Tensor& t) { + for (auto s : t.strides()) { + if (s == 0) return t.contiguous(); + } + return t; +} + +}} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/miopen-wrapper.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/miopen-wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..d1976da873ed7263e7bb4efe638fb18862d0379f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/miopen-wrapper.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +#if MIOPEN_VERSION_MAJOR > 3 || (MIOPEN_VERSION_MAJOR == 3 && MIOPEN_VERSION_MINOR >= 4) +// miopen 3.4 moved find mode from private header to public header +#else +// from miopen_internal.h +extern "C" { + +typedef enum +{ + miopenConvolutionFindModeNormal = 1, /*!< Normal mode */ +} miopenConvolutionFindMode_t; + +miopenStatus_t miopenSetConvolutionFindMode( + miopenConvolutionDescriptor_t convDesc, + miopenConvolutionFindMode_t findMode); +} +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..1d0215fbfc5dcb1888b641f5fea1caeadd13317b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include + +#ifdef USE_FBGEMM +#include +#include +#include + + +namespace ao::sparse { + +struct TORCH_API PackedLinearWeight + : public LinearPackedParamsBase { + PackedLinearWeight(std::unique_ptr> w, + std::optional bias, + std::vector col_offsets, + std::vector w_scale, + std::vector w_zp, + c10::QScheme q_scheme, + const int64_t out_features_block_size /* block sparsity size across output_features */, + const int64_t in_features_block_size 
/* block sparsity size across input_features */) + : LinearPackedParamsBase( + out_features_block_size, + in_features_block_size), + w(std::move(w)), + bias_(std::move(bias)), + col_offsets(std::move(col_offsets)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + q_scheme(q_scheme) {} + std::unique_ptr> w; + std::optional bias_; + std::vector col_offsets; + std::vector w_scale; + std::vector w_zp; + c10::QScheme q_scheme; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(const at::Tensor& input) override { + TORCH_INTERNAL_ASSERT( + false, + "Sparse quantized dynamic linear with fused relu is not yet " + "supported on qnnpack backend."); + return at::Tensor(); + } + at::Tensor apply_dynamic_relu(const at::Tensor& input) override { + TORCH_INTERNAL_ASSERT( + false, + "Sparse quantized dynamic linear with fused relu is not yet " + "supported on qnnpack backend."); + return at::Tensor(); + } + + LinearPackedSerializationType unpack() override; + + BCSRSerializationType serialize() override; + + static c10::intrusive_ptr deserialize( + const BCSRSerializationType& serialized); + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + const at::Tensor& weight, + const std::optional& bias, + const int64_t out_features_block_size, + const int64_t in_features_block_size); + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +}; + +} // namespace ao::sparse + +#endif // USE_FBGEMM + +namespace ao::sparse { +int register_linear_params(); +} // namespace ao::sparse diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h new file mode 100644 index 0000000000000000000000000000000000000000..14f98b5a49782ba608d2ab160a92e8185969e8df --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h @@ -0,0 +1,74 @@ +#pragma once + +#include + +#include + +namespace ao::sparse { + +// +using LinearPackedSerializationType = + std::tuple, std::vector>; + +#define SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION 2 + +using BCSRSerializationType = + std::tuple< + int64_t, // Serialization Version + std::optional, // Bias + int64_t, // Out Features (Row) Block Size + int64_t, // In Features (Column) Block Size + at::Tensor, // Weight Scales (single element vector if per-tensor) (float) + at::Tensor, // Wrapper for Weight Zero Points (single element vector if per-tensor) (int8_t) + bool, // Quantization Scheme (true: per tensor, false: per channel) + at::Tensor, // Wrapper for Row Block Indices (int8_t, int16_t, or int32_t) + at::Tensor, // Wrapper for Column Block Indices (int8_t, int16_t, or int32_t) + at::Tensor, // Wrapper for Non-Zero Weight Values, each +128 (uint8_t) + int64_t, // Number of Output Channels + int64_t // Number of Input Channels + >; + +using BCSR = + std::tuple< + std::vector, // Non-Zero Weight Values + std::vector, // Compressed Row Block Indices + std::vector // Column Block Indices + >; + +struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { + public: + LinearPackedParamsBase( + const int64_t out_features_block_size, + const int64_t in_features_block_size) + : out_features_block_size_(out_features_block_size), + in_features_block_size_(in_features_block_size) {} + + virtual at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + + 
virtual at::Tensor apply_dynamic(const at::Tensor& input) = 0; + virtual at::Tensor apply_dynamic_relu(const at::Tensor& input) = 0; + + virtual LinearPackedSerializationType unpack() = 0; + + virtual BCSRSerializationType serialize() = 0; + + virtual std::optional bias() = 0; + + virtual void set_bias(const std::optional& bias) { + throw std::runtime_error( + "set_bias is not implemented for this packed " + "parameter type"); + } + + protected: + const int64_t out_features_block_size_, in_features_block_size_; +}; + +} // namespace ao::sparse diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..1a3515d5a7ae1b6039dff466394b6afe103b0fb2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include + +#ifdef USE_PYTORCH_QNNPACK +// TODO: Refacto QnnpackUtils.h so as to separate code +// needed for quantized op from the generic qnnpack specific +// quantization utilities. +#include +#include +#include + +namespace ao::sparse { + +struct TORCH_API PackedLinearWeightQnnp : public LinearPackedParamsBase { + PackedLinearWeightQnnp(const at::Tensor& weight, const std::optional& bias, const int64_t out_features_block_size /* block sparsity size across output_features */, const int64_t in_features_block_size /* block sparsity size across input_features */); + explicit PackedLinearWeightQnnp(const BCSRSerializationType& serialized); + std::optional orig_bias_; + // Separate copy of bias exist so that we can fill in zeros when + // optional bias does not exist. This is to compy with qnnpack operator that + // expects bias to be present. 
+ // In case bias is present bias_ is just a reference to orig_bias_ + at::Tensor bias_; + c10::QScheme q_scheme_; + double input_scale_{}; + std::unique_ptr bcsr_matrix_; + at::Tensor w_scales_; + std::vector w_zero_points_; + std::vector requantization_scales_; + std::unique_ptr + sparse_linear_op_{nullptr}; + int64_t output_channels_; + int64_t input_channels_; + // Deserialized Tensors are stored to maintain the lifetime of underlying + // BCSR data. + // These are left empty if PackedLinearWeightQnnp is created via prepacking + // rather than deserializing. + at::Tensor deserialized_bcsr_row_block_indices_; + at::Tensor deserialized_bcsr_col_block_indices_; + at::Tensor deserialized_bcsr_weight_values_; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override { + TORCH_CHECK( + false, "Static quantized sparse linear unimplemented on QNNPACK"); + } + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override { + TORCH_CHECK( + false, "Static quantized sparse linear unimplemented on QNNPACK"); + } + + at::Tensor apply_dynamic(const at::Tensor& input) override; + at::Tensor apply_dynamic_relu(const at::Tensor& input) override; + + LinearPackedSerializationType unpack() override; + + BCSRSerializationType serialize() override; + + static c10::intrusive_ptr deserialize( + const BCSRSerializationType& serialized); + + std::optional bias() override { + return orig_bias_; + } + + static c10::intrusive_ptr prepack( + const at::Tensor& weight, + const std::optional& bias, + const int64_t out_features_block_size, + const int64_t in_features_block_size); + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + template + at::Tensor apply_dynamic_impl(const at::Tensor& input); +}; + +} // namespace ao::sparse + +#endif // USE_PYTORCH_QNNPACK diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h new file mode 100644 index 0000000000000000000000000000000000000000..5b24ee4821c45baab25f37a3bfa3399eff8a1716 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h @@ -0,0 +1,37 @@ +#ifndef ATOMIC_ADD_FLOAT +#define ATOMIC_ADD_FLOAT + +#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__)) +#include +#else +#define _mm_pause() +#endif + +#include + +static inline void cpu_atomic_add_float(float* dst, float fvalue) +{ + typedef union { + unsigned intV; + float floatV; + } uf32_t; + + uf32_t new_value, old_value; + std::atomic* dst_intV = (std::atomic*)(dst); + + old_value.floatV = *dst; + new_value.floatV = old_value.floatV + fvalue; + + unsigned* old_intV = (unsigned*)(&old_value.intV); + while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) { +#ifdef __aarch64__ + __asm__ __volatile__("yield;" : : : "memory"); +#else + _mm_pause(); +#endif + old_value.floatV = *dst; + new_value.floatV = old_value.floatV + fvalue; + } +} + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/CatKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/CatKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..29fc21eb1ccf9ef432026feac82f013070637194 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/CatKernel.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +using cat_serial_fn = void(*)(const Tensor &, const MaterializedITensorListRef&, int64_t); +DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub) + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c3d8990220831ec61b22cf221ad26478dcaf0b64 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t); +DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel) + +} // at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/CopyKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/CopyKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3378e16f93d23e6c317b98f4469e660086b0082a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/CopyKernel.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace at { +struct TensorIteratorBase; + +namespace native { +inline namespace CPU_CAPABILITY { + +void direct_copy_kernel(TensorIteratorBase &iter); +void copy_kernel(TensorIterator& iter, bool /*non_blocking*/); + +}}} // namespace at::native::CPU_CAPABILITY diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ac2a0423af113dbc583597d12993d4402d682194 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +/* + Depthwise 3x3 Winograd convolution operator +*/ + +namespace at { +class Tensor; + +namespace native { + +using convolution_depthwise3x3_winograd_fn = + Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t); + 
+DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub) + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h new file mode 100644 index 0000000000000000000000000000000000000000..8171ae8e79ad2a1311f4a8600decd202c66649d5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h @@ -0,0 +1,425 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPU_CAPABILITY_AVX2 +#include +#include +#endif + + + + +namespace at::native::templates::cpu { +namespace { + +// ==================================================== Random ======================================================== + +template +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) { + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] { + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t { + uniform_int_from_to_distribution random(range, base); + return random(generator); + }); + }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +// This is the special kernel to handle single specific case: +// from(inclusive) = std::numeric_limits::lowest() +// to(exclusive) = None (= std::numeric_limits::max() + 1) +template +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [generator]() -> scalar_t { + 
uniform_int_full_range_distribution random; + return random(generator); + }); + } else { + TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16"); + } + }); +} + +template +struct RandomFromToKernel { + void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen) { + random_from_to_kernel(iter, range, base, check_generator(gen)); + } + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_full_64_bits_range_kernel(iter, check_generator(gen)); + } +}; + +template +void random_kernel(TensorIteratorBase& iter, RNG generator) { + std::lock_guard lock(generator->mutex_); + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] { + cpu_serial_kernel(iter, [generator]() -> scalar_t { + uniform_int_distribution random; + return random(generator); + }); + }); +} + +template +struct RandomKernel { + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_kernel(iter, check_generator(gen)); + } +}; + +// ==================================================== Normal ======================================================== + +#ifdef CPU_CAPABILITY_AVX2 +static void normal_fill_16_AVX2(float *data, + const __m256* two_pi, + const __m256* one, + const __m256* minus_two, + const __m256* mean, + const __m256* std_v) { + const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data)); + const __m256 u2 = _mm256_loadu_ps(data + 8); + // sincos256_ps and log256_ps are from avx_mathfun.h + const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1))); + const __m256 theta = _mm256_mul_ps(*two_pi, u2); + __m256 sintheta, costheta; + sincos256_ps(theta, &sintheta, &costheta); + const __m256 n1 = _mm256_mul_ps(radius, costheta); + const __m256 n2 = _mm256_mul_ps(radius, sintheta); + _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean)); + _mm256_storeu_ps(data + 8, 
_mm256_fmadd_ps(n2, *std_v, *mean)); +} + +template +void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) { + float *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi); + const __m256 one = _mm256_set1_ps(1.0f); + const __m256 minus_two = _mm256_set1_ps(-2.0f); + const __m256 mean_v = _mm256_set1_ps(mean); + const __m256 std_v = _mm256_set1_ps(std); + + for (int64_t i = 0; i < size - 15; i += 16) { + normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v); + } + + if (size % 16 != 0) { + // Recompute the last 16 values. + data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v); + } +} +#endif + +template +static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) { + for (const auto j : c10::irange(8)) { + const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log. 
+ const scalar_t u2 = data[j + 8]; + const scalar_t radius = std::sqrt(-2 * std::log(u1)); + const scalar_t theta = 2.0f * c10::pi * u2; + data[j] = radius * std::cos(theta) * std + mean; + data[j + 8] = radius * std::sin(theta) * std + mean; + } +} + +#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX) +static void normal_fill_16_VSX(float *data,const Vectorized &two_pi,const Vectorized &one,const Vectorized &minus_two,const Vectorized &mean,const Vectorized &std) { + using Vec = Vectorized; + Vec u1=one-Vec::loadu(data); + Vec u2=Vec::loadu(data+8); + Vec radius=(minus_two * u1.log()); + radius=radius.sqrt(); + Vec theta=two_pi * u2; + Vec output_vec=radius * theta.cos() * std + mean; + Vec output_vec2=radius * theta.sin() * std + mean; + output_vec.store(data); + output_vec2.store(data+8); +} + +template +void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) { + float *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + + using Vec = Vectorized; + const Vec two_pi = Vec(2.0f * c10::pi); + const Vec one = Vec(1.0f); + const Vec minus_two = Vec(-2.0f); + const Vec var_vec = Vec(std); + const Vec mean_vec = Vec(mean); + + for (int64_t i = 0; i < size - 15; i += 16) { + if(Vec::size()==8) { + normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, var_vec); + } + else{ + normal_fill_16(data + i, mean, std); + } + } + if (size % 16 != 0) { + // Recompute the last 16 values. 
+ data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + if(Vec::size()==8){ + normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, var_vec); + } + else{ + normal_fill_16(data, mean, std); + } + } +} +#endif //VSX + +template +void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) { + scalar_t *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + + for (int64_t i = 0; i < size - 15; i += 16) { + normal_fill_16(data + i, mean, std); + } + if (size % 16 != 0) { + // Recompute the last 16 values. + data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + normal_fill_16(data, mean, std); + } +} + +template +void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) { + auto size = self.numel(); + if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) { +#ifdef CPU_CAPABILITY_AVX2 + normal_fill_AVX2(self, static_cast(mean), static_cast(std), generator); +#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX) + normal_fill_VSX(self, static_cast(mean), static_cast(std), generator); +#else + normal_fill(self, static_cast(mean), static_cast(std), generator); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] { + if (size >= 16 && self.is_contiguous()) { + normal_fill(self, static_cast(mean), static_cast(std), generator); + } else { + auto iter = TensorIterator::borrowing_nullary_op(self); + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t { + at::normal_distribution normal(mean, std); + return 
static_cast(normal(generator)); + }); + } + }); + } +} + +template +struct NormalKernel { + void operator()(Tensor& self, double mean, double std, std::optional gen) { + normal_kernel(self, mean, std, check_generator(gen)); + } +}; + +// ==================================================== Uniform ======================================================= + +template +void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + auto from = static_cast(from_); + auto to = static_cast(to_); + at::uniform_real_distribution uniform(from, to); + cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t { + return static_cast(uniform(generator)); + }); + }); +} + +template +struct UniformKernel { + void operator()(TensorIteratorBase& iter, double from, double to, std::optional gen) { + uniform_kernel(iter, from, to, check_generator(gen)); + } +}; + +// ==================================================== Cauchy ======================================================== + +template +void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::cauchy_distribution cauchy(median, sigma); + cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t { + return static_cast(cauchy(generator)); + }); + }); +} + +template +struct CauchyKernel { + void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional gen) { + cauchy_kernel(iter, median, sigma, check_generator(gen)); + } +}; + +// ================================================== LogNormal ======================================================= + +template +void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) { + 
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::lognormal_distribution logNormal(mean, std); + cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t { + return static_cast(logNormal(generator)); + }); + }); +} + +template +struct LogNormalKernel { + void operator()(TensorIteratorBase& iter, double mean, double std, std::optional gen) { + log_normal_kernel(iter, mean, std, check_generator(gen)); + } +}; + +// =================================================== Geometric ====================================================== + +template +void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::geometric_distribution geometric(p); + cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t { + return static_cast(geometric(generator)); + }); + }); +} + +template +struct GeometricKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + geometric_kernel(iter, p, check_generator(gen)); + } +}; + +// ================================================== Exponential ===================================================== + +template +void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) { + TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. 
dtype must be a floating point but you specified ", iter.dtype()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::exponential_distribution exponential(lambda); + cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t { + return static_cast(exponential(generator)); + }); + }); +} + +template +struct ExponentialKernel { + void operator()(TensorIteratorBase& iter, double lambda, std::optional gen) { + exponential_kernel(iter, lambda, check_generator(gen)); + } +}; + +// ================================================== Bernoulli ======================================================= + +template +void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(generator->mutex_); + using self_t = scalar_t; + auto p_cpu = p_.to(kCPU); + auto p = expand_inplace(self, p_cpu); + auto iter = TensorIteratorConfig() + .add_output(self) + .add_const_input(*p) + .check_all_same_dtype(false) + .build(); + if (p->scalar_type() == kDouble) { + cpu_serial_kernel(iter, [&](const double p_val) -> self_t { + at::bernoulli_distribution bernoulli(p_val); + return static_cast(bernoulli(generator)); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, + p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] { + using p_t = scalar_t; + cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t { + at::bernoulli_distribution bernoulli(p_val); + return static_cast(bernoulli(generator)); + }); + }); + } + }); +} + +template +void bernoulli_kernel(const TensorBase &self, double p, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, 
at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "bernoulli_scalar_cpu_", [&] { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(generator->mutex_); + auto iter = TensorIterator::borrowing_nullary_op(self); + cpu_serial_kernel(iter, [p, generator]() -> scalar_t { + at::bernoulli_distribution bernoulli(p); + return static_cast(bernoulli(generator)); + }); + }); +} + +template +struct BernoulliKernel { + void operator()(const TensorBase &self, double p, std::optional gen) { + bernoulli_kernel(self, p, check_generator(gen)); + } + void operator()(const TensorBase &self, const TensorBase &p_, std::optional gen) { + bernoulli_kernel(self, p_, check_generator(gen)); + } +}; + +}} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Elu.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Elu.h new file mode 100644 index 0000000000000000000000000000000000000000..d438d4303709b1b89a45f83e118135499d299ded --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Elu.h @@ -0,0 +1,74 @@ +#pragma once + +// On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to +// access constants such as M_SQRT2 and M_2_SQRTPI. +#ifdef _WIN32 +#define _USE_MATH_DEFINES +#include +#endif // _WIN32 + +#include +#include // For c10::is_reduced_floating_point_v. + +namespace at::native { +inline namespace CPU_CAPABILITY { +/** + * Return a function object that calculates ELU with the given + * parameters on its input element. ParamT is the type of the input + * and output to the ELU, and MathT is the type (possibly + * higher-precision, e.g. float if ParamT is reduced-precision float) + * in which to do intermediate calculations. 
+ */ +template +auto get_scalar_elu_elementwise_func(MathT alpha, MathT scale, MathT input_scale) { + const auto negcoef = alpha * scale; + const auto poscoef = scale; + const auto negiptcoef = input_scale; + return [negcoef, negiptcoef, poscoef](ParamT a) -> ParamT { + return MathT(a) < MathT(0) + ? std::expm1(MathT(a) * negiptcoef) * negcoef + : MathT(a) * poscoef; + }; +} + +/** + * Return a function object that calculates ELU with the given + * parameters on its input element. The function object takes and + * returns Vectorized. + */ +template , bool> = true> +auto get_vectorized_elu_elementwise_func(T alpha, T scale, T input_scale) { + const vec::Vectorized negcoef_vec(alpha * scale); + const vec::Vectorized poscoef_vec(scale); + const vec::Vectorized negiptcoef_vec(input_scale); + const vec::Vectorized zero_vec(static_cast(0)); + return [negcoef_vec, poscoef_vec, negiptcoef_vec, zero_vec](vec::Vectorized a) -> vec::Vectorized { + const auto cmp = a >= zero_vec; + if (!cmp.zero_mask()) { + return a * poscoef_vec; + } else { + return vec::Vectorized::blendv((a * negiptcoef_vec).expm1() * negcoef_vec, a * poscoef_vec, cmp); + } + }; +} + +/** + * Return a function object that calculates ELU with the given + * parameters on its input element. The function object takes and + * returns Vectorized, and Vectorized is the type + * (possibly higher-precision) in which to do intermediate + * calculations. + */ +template , bool> = true> +auto get_vectorized_elu_elementwise_func(float alpha, float scale, float input_scale) { + // Takes float->float. 
+ const auto float_func = get_vectorized_elu_elementwise_func(alpha, scale, input_scale); + return [float_func](vec::Vectorized a) -> vec::Vectorized { + auto [a0, a1] = vec::convert_to_float(a); + auto res0 = float_func(a0); + auto res1 = float_func(a1); + return vec::convert_from_float(res0, res1); + }; +} +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Gelu.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Gelu.h new file mode 100644 index 0000000000000000000000000000000000000000..613c69e225864eb9ef82c5b8a8b76c95e2168eae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Gelu.h @@ -0,0 +1,83 @@ +#pragma once + +// On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to +// access constants such as M_SQRT2 and M_2_SQRTPI. +#ifdef _WIN32 +#define _USE_MATH_DEFINES +#include +#include +#endif // _WIN32 + +#include +#include // For c10::is_reduced_floating_point_v. 
+ +namespace at::native { +inline namespace CPU_CAPABILITY { +constexpr double kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5; +constexpr double kGeluKappa = 0.044715; + +template +using reduced_fp_to_float_t = std::conditional_t, float, T>; + +template , bool> = true> +float reduced_fp_to_float(T x) { + return float(x); +} + +template , bool> = true> +T reduced_fp_to_float(T x) { + return x; +} + +template +T scalar_gelu_approximated_with_tanh(T x) { + using opmath_t = reduced_fp_to_float_t; + auto x_float = reduced_fp_to_float(x); + auto x_cube = x_float * x_float * x_float; + auto inner = opmath_t(kGeluBeta) * (x_float + opmath_t(kGeluKappa) * x_cube); + return opmath_t(0.5) * x_float * (opmath_t(1) + std::tanh(inner)); +} + +template , bool> = true> +vec::Vectorized vectorized_gelu_approximated_with_tanh(vec::Vectorized x) { + const vec::Vectorized kPointFiveVec(T(0.5)); + const vec::Vectorized kOneVec(T(1)); + const vec::Vectorized kGeluBetaVec((T(kGeluBeta))); + const vec::Vectorized kGeluKappaVec((T(kGeluKappa))); + auto x_cube = x * x * x; + vec::Vectorized inner_vec = kGeluBetaVec * (x + kGeluKappaVec * x_cube); + return kPointFiveVec * x * (kOneVec + inner_vec.tanh()); +} + +template , bool> = true> +vec::Vectorized vectorized_gelu_approximated_with_tanh(vec::Vectorized x) { + auto [x0, x1] = at::vec::convert_to_float(x); + return at::vec::convert_from_float( + vectorized_gelu_approximated_with_tanh(x0), + vectorized_gelu_approximated_with_tanh(x1)); +} + + +template +T scalar_gelu(T x) { + using opmath_t = reduced_fp_to_float_t; + const auto kAlpha = opmath_t(M_SQRT1_2); + return reduced_fp_to_float(x) * opmath_t(0.5) * (opmath_t(1) + std::erf(reduced_fp_to_float(x) * kAlpha)); +} + +template, bool> = true> +vec::Vectorized vectorized_gelu(vec::Vectorized x) { + const vec::Vectorized kAlphaVec(T(M_SQRT1_2)); + const vec::Vectorized kOneVec(T(1)); + const vec::Vectorized kPointFiveVec(T(0.5)); + return x * kPointFiveVec * (kOneVec + (x * kAlphaVec).erf()); +} + 
+template, bool> = true> +vec::Vectorized vectorized_gelu(vec::Vectorized x) { + auto [x0, x1] = at::vec::convert_to_float(x); + return at::vec::convert_from_float(vectorized_gelu(x0), vectorized_gelu(x1)); +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..743bbfdb7e800e6f8b0770787bb10d7aa24000e3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using forward_2d_fn = void (*) ( + const TensorBase &output, + const TensorBase &input, + const TensorBase &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners); +using backward_2d_fn = void (*) ( + const TensorBase &grad_input, + const TensorBase &grad_grid, + const TensorBase &grad_output, + const TensorBase &input, + const TensorBase &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + std::array output_mask); +DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel) +DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel) + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..c513d128e23421a6c828d2f710732b04b2d1800a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h @@ -0,0 +1,85 @@ +#pragma once +#include +#include + +namespace at::native { + +inline bool is_constant_index(int ntensor, const int64_t* strides) { + 
AT_ASSERT(ntensor >= 3); + for (const auto arg : c10::irange(2, ntensor)) { + if (strides[arg] != 0) { + return false; + } + } + return true; +} + + +struct Indexer { + Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides, + IntArrayRef original_sizes, IntArrayRef original_strides) + : num_indexers(num_indexers) + , indexers(indexers) + , indexer_strides(indexer_strides) + , original_strides(original_strides.data()) + , original_sizes(original_sizes.data()) { + AT_ASSERT(static_cast(original_strides.size()) == num_indexers); + AT_ASSERT(static_cast(original_sizes.size()) == num_indexers); + } + + int64_t num_indexers; + char** indexers; + const int64_t* indexer_strides; + const int64_t* original_strides; + const int64_t* original_sizes; + + int64_t get(int64_t idx) { + int64_t offset = 0; + for (const auto j : c10::irange(num_indexers)) { + int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]]; + int64_t size = original_sizes[j]; + TORCH_CHECK_INDEX(value >= -size && value < size, + "index ", value, " is out of bounds for dimension ", j, " with size ", size); + if (value < 0) { + value += size; + } + offset += value * original_strides[j]; + } + return offset; + } +}; + +template +void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, + const func_t& f, bool serial_execution=false) +{ + int ntensor = iter.ntensors(); + // When launch the index parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE + // to make the whole available thread numbers get more balanced work load and a better cache location. 
+ // The grain size here is chosen by the op benchmark to overcome the thread launch overhead + const int index_parallel_grain_size = 3000; + auto loop = [&](char** data, const int64_t* strides, int64_t n) { + auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride); + char* dst = data[0]; + char* src = data[1]; + if (is_constant_index(ntensor, strides)) { + // specialization for when every element uses the same index + int64_t offset = indexer.get(0); + for (const auto i : c10::irange(n)) { + f(dst + strides[0] * i, src + strides[1] * i, offset); + } + } else { + for (const auto i : c10::irange(n)) { + int64_t offset = indexer.get(i); + f(dst + strides[0] * i, src + strides[1] * i, offset); + } + } + }; + if (serial_execution) { + iter.serial_for_each(loop, {0, iter.numel()}); + } else { + iter.for_each(loop, index_parallel_grain_size); + } +} +} // at +// native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Intrinsics.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..f3b35328f1882729a9158eaed7eb2abf77097484 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Intrinsics.h @@ -0,0 +1,33 @@ +#pragma once + +#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__)) +/* Clang-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(_MSC_VER) +/* Microsoft C/C++-compatible compiler */ +#include +#if _MSC_VER <= 1900 +#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y]) +#endif +#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) +/* GCC-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(__GNUC__) && defined(__ARM_NEON__) +/* GCC-compatible compiler, targeting ARM with NEON */ +#include +#elif defined(__GNUC__) && defined(__IWMMXT__) +/* GCC-compatible compiler, targeting ARM with WMMX */ +#include +#elif 
(defined(__GNUC__) || defined(__xlC__)) && \ + (defined(__VEC__) || defined(__ALTIVEC__)) +/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ +#include +/* We need to undef those tokens defined by to avoid conflicts + with the C++ types. => Can still use __bool/__vector */ +#undef bool +#undef vector +#undef pixel +#elif defined(__GNUC__) && defined(__SPE__) +/* GCC-compatible compiler, targeting PowerPC with SPE */ +#include +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IsContiguous.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IsContiguous.h new file mode 100644 index 0000000000000000000000000000000000000000..02d8f5dd78e40f2628031d68719a5059db869302 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IsContiguous.h @@ -0,0 +1,64 @@ +#pragma once + +namespace at::native { inline namespace CPU_CAPABILITY { + +// n: number of function arguments (arity) +// traits: function_traits (see FunctionTraits.h) +// s: index of scalar argument or -1 +template +struct IsContiguous { + static bool eval(const int64_t* strides) { + using type = typename traits::template arg::type; + return strides[stride_index] == (s == n ? 
0 : sizeof(type)) && + IsContiguous::eval(strides); + } +}; + +// will be called when there is an output exists +template +struct IsContiguous<0, 0, traits, s> { + static bool eval(const int64_t* strides) { + return strides[0] == sizeof(typename traits::result_type); + } +}; + +// will be called when there is no output +template +struct IsContiguous<0, -1, traits, s> { + static bool eval(const int64_t* /*strides*/) { + return true; + } +}; + +// output and all inputs are contiguous +template < + typename traits, + std::enable_if_t>* = + nullptr> +static inline bool is_contiguous(const int64_t* strides) { + return IsContiguous::eval(strides); +} + +template >* = nullptr> +static inline bool is_contiguous(const int64_t* strides) { + return IsContiguous::eval(strides); +} + +// input at `s` is scalar (stride 0); output and other inputs are contiguous +// NB: output is typically at strides[0] so first input corresponds to s=1 +template >* = nullptr> +static inline bool is_contiguous_scalar(const int64_t* strides) { + static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); + return IsContiguous::eval(strides); +} + +template >* = nullptr> +static inline bool is_contiguous_scalar(const int64_t* strides) { + static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); + return IsContiguous::eval(strides); +} + +}} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogAddExp.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogAddExp.h new file mode 100644 index 0000000000000000000000000000000000000000..e2b80a648df6b11a99ceadff5488dd597af4f9ac --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogAddExp.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +// custom min and max to be used in logcumsumexp for complex arguments +template +std::pair, c10::complex> 
_logcumsumexp_minmax(c10::complex x, c10::complex y) { + if (at::_isnan(y)) { // either real is nan or imag is nan + return std::make_pair(y, y); + } else if (at::_isnan(x)) { // either real is nan or imag is nan + return std::make_pair(x, x); + } else { + return (x.real() < y.real()) ? std::make_pair(x, y) : std::make_pair(y, x); + } +} + +template +scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) { + // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp + scalar_t min = at::_isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan + scalar_t max = at::_isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan + if (min != max || std::isfinite(min)) { + // nan will be propagated here + return std::log1p(std::exp(min - max)) + max; + } else { + // special case to correctly handle infinite cases + return x; + } +} + +template +c10::complex _log_add_exp_helper(const c10::complex& x, const c10::complex& y) { + auto [min, max] = _logcumsumexp_minmax(x, y); + auto min_real = std::real(min); + auto max_real = std::real(max); + + if (at::_isnan(min)) { // either real is nan or imag is nan + // handling the "infectious" NaNs + return {std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}; + } else if (!std::isfinite(min_real) && (min_real == max_real)) { + if (min_real < 0) { + // handle the -inf case, the imaginary part here does not really matter as the exp(value) + // will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined. 
+ // It does not matter if we're taking the exp of this value + return min; + } else { + // handle the +inf case, we don't need the special precision for log1p for small values + // and to avoid producing nan in case of real(max) == real(min) == +inf + return std::log(std::exp(min) + std::exp(max)); + } + } else { + return std::log1p(std::exp(min - max)) + max; + } +} + +} // end namespace +} //end at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..b8af353e8866cf41561bfe9f888ce7fcacd85e2a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h @@ -0,0 +1,337 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { +template +int64_t vec_log_softmax_lastdim_chunk_size(int64_t grain_size, int64_t outer_size, int64_t dim_size) { + // Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the + // size of L1D cache on many processors. Some processors have 48 KB L1D cache + // nowadays, so maybe in the future, we can leverage the knowledge of a + // machine's L1D cache size. 
+ int64_t MAX_CHUNK_SIZE = std::max( + 1, + grain_size / (sizeof(scalar_t) * dim_size)); + return std::min(MAX_CHUNK_SIZE, outer_size); +} + +template +void serial_vec_log_softmax_lastdim_range( + const scalar_t* input_data_base, + scalar_t* output_data_base, + int64_t dim_size, + int64_t chunk_size, + int64_t begin, + int64_t end) { + if (end <= begin) { + return; + } + using Vec = vec::Vectorized>; + // MSVC requires such a declaration of dynamic arrays + // Source: https://stackoverflow.com/a/33423538 + auto tmp_sum_scalar = std::make_unique(chunk_size); + auto max_input_arr = std::make_unique(chunk_size); + for (int64_t ii = begin; ii < end; ii += chunk_size) { + int64_t loop_end = chunk_size; + if (ii + chunk_size > end) { + loop_end = end - ii; + } + for (const auto j : c10::irange(loop_end)) { + int64_t i = ii + j; + const scalar_t* input_data = input_data_base + i * dim_size; + max_input_arr[j] = vec::reduce_all( + [](Vec& x, Vec& y) { return vec::maximum(x, y); }, + input_data, + dim_size); + } + for (const auto j : c10::irange(loop_end)) { + int64_t i = ii + j; + const scalar_t* input_data = input_data_base + i * dim_size; + scalar_t max_input = max_input_arr[j]; + tmp_sum_scalar[j] = vec::map_reduce_all( + [max_input](Vec x) { return (x - Vec(max_input)).exp(); }, + [](Vec x, Vec y) { return x + y; }, + input_data, + dim_size); + } + // See [Note AVX-SSE transitions] for why this should call the + // vectorized version (aside from perf improvements). + vec::map( + [](Vec x) { return x.log(); }, + tmp_sum_scalar.get(), + tmp_sum_scalar.get(), + loop_end); + for (const auto j : c10::irange(loop_end)) { + int64_t i = ii + j; + const scalar_t* input_data = input_data_base + i * dim_size; + scalar_t* output_data = output_data_base + i * dim_size; + scalar_t tmp_sum = tmp_sum_scalar[j]; + scalar_t max_input = max_input_arr[j]; + + // It's necessary to keep the order of the operations below. 
+ // In some cases that input is large digits and the difference + // is small, if we compute `max_input` plus `tmp_sum` before, + // there would be a numerical problem. See an example in + // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379 + vec::map( + [tmp_sum, max_input](Vec x) { + return x - Vec(max_input) - Vec(tmp_sum); + }, + output_data, + input_data, + dim_size); + } + } +} + +// Can't include ATen/Parallel.h. +// TODO: find a way to have only one copy of divup. +inline int64_t divup(int64_t x, int64_t y) { + return (x + y - 1) / y; +} + +template +std::pair vec_logsoftmax_chunk_size_and_num_chunks(int64_t inner_size, int64_t dim_size) { + using Vec = vec::Vectorized; + int64_t MAX_CHUNK_SIZE = std::max(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size()); + MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size(); + int64_t CHUNK_SIZE = std::min(MAX_CHUNK_SIZE, inner_size); + int64_t num_chunks = divup(inner_size, CHUNK_SIZE); + return {CHUNK_SIZE, num_chunks}; +} + +template +std::enable_if_t>, void> +serial_vec_logsoftmax_range( + const scalar_t* input_data_base, + scalar_t* output_data_base, + int64_t inner_size, + int64_t chunk_size, + int64_t num_chunks, + int64_t dim_size, + int64_t begin, + int64_t end) { + using Vec = vec::Vectorized; + // thread local temp buffer which holds vertical reduction result: max and sum. 
+ auto buffer = std::make_unique(chunk_size * 2); + scalar_t* input_max_data = buffer.get(); + scalar_t* tmp_sum_data = buffer.get() + chunk_size; + + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * chunk_size; + int64_t size = std::min(chunk_size, inner_size - inner_idx_begin); + + // init + Vec zero_vec = Vec(scalar_t(0)); + Vec min_vec = Vec(-std::numeric_limits::infinity()); + int64_t d0 = 0; + for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { + min_vec.store(input_max_data + d0); + zero_vec.store(tmp_sum_data + d0); + } + for (; d0 < size; d0++) { + input_max_data[d0] = -std::numeric_limits::infinity(); + tmp_sum_data[d0] = scalar_t(0); + } + + // compute max + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d1); + Vec max_vec = Vec::loadu(input_max_data + d1); + max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec); + max_vec.store(input_max_data + d1); + } + for (; d1 < size; d1++) { + scalar_t data_val = input_ptr[d1]; + scalar_t max_val = input_max_data[d1]; + input_max_data[d1] = data_val > max_val ? 
data_val : max_val; + } + } + + // compute sum of (x - max).exp() + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d2); + Vec sum_vec = Vec::loadu(tmp_sum_data + d2); + Vec max_vec = Vec::loadu(input_max_data + d2); + sum_vec += (data_vec - max_vec).exp(); + sum_vec.store(tmp_sum_data + d2); + } + for (; d2 < size; d2++) { + scalar_t data_val = input_ptr[d2]; + scalar_t max_val = input_max_data[d2]; + tmp_sum_data[d2] += std::exp(data_val - max_val); + } + } + + // apply log + vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size); + + // compute x - max - sum + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin; + const scalar_t* input_ptr = input_data_base + offset; + scalar_t* output_ptr = output_data_base + offset; + + int64_t d3 = 0; + for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d3); + Vec max_vec = Vec::loadu(input_max_data + d3); + Vec sum_vec = Vec::loadu(tmp_sum_data + d3); + Vec out_vec = data_vec - max_vec - sum_vec; + out_vec.store(output_ptr + d3); + } + for (; d3 < size; d3++) { + output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]; + } + } + } +} + +template +std::enable_if_t>, void> +serial_vec_logsoftmax_range( + const scalar_t* input_data_base, + scalar_t* output_data_base, + int64_t inner_size, + int64_t chunk_size, + int64_t num_chunks, + int64_t dim_size, + int64_t begin, + int64_t end) { + using Vec = vec::Vectorized; + using fVec = vec::Vectorized; + auto buffer = std::make_unique(chunk_size * 2); + float* input_max_data = buffer.get(); + float* tmp_sum_data = buffer.get() + chunk_size; + + 
// thread local buffer that holds input data in float32 to save next 2 dtype conversion + auto input_buffer = std::make_unique(dim_size * chunk_size); + float* input_buffer_data = input_buffer.get(); + + // init + for (int64_t i = begin; i < end; i++) { + int64_t outer_idx = i / num_chunks; + int64_t k = i % num_chunks; + int64_t inner_idx_begin = k * chunk_size; + int64_t size = std::min(chunk_size, inner_size - inner_idx_begin); + + fVec zero_fvec = fVec(float(0)); + fVec min_fvec = fVec(-std::numeric_limits::infinity()); + int64_t d0 = 0; + for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) { + min_fvec.store(input_max_data + d0); + min_fvec.store(input_max_data + d0 + fVec::size()); + zero_fvec.store(tmp_sum_data + d0); + zero_fvec.store(tmp_sum_data + d0 + fVec::size()); + } + for (; d0 < size; d0++) { + input_max_data[d0] = -std::numeric_limits::infinity(); + tmp_sum_data[d0] = float(0); + } + + // compute max + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size; + + int64_t d1 = 0; + for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) { + Vec data_vec = Vec::loadu(input_ptr + d1); + auto [data_fvec0, data_fvec1] = vec::convert_to_float(data_vec); + fVec max_fvec0 = fVec::loadu(input_max_data + d1); + fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size()); + max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0); + max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1); + max_fvec0.store(input_max_data + d1); + max_fvec1.store(input_max_data + d1 + fVec::size()); + + // cache the 'converted' float input + data_fvec0.store(input_buffer_ptr + d1); + data_fvec1.store(input_buffer_ptr + d1 + fVec::size()); + } + for (; d1 < size; d1++) { + float data_val = float(input_ptr[d1]); + float max_val = 
input_max_data[d1]; + input_max_data[d1] = data_val > max_val ? data_val : max_val; + input_buffer_ptr[d1] = data_val; + } + } + + // compute sum of (x - max).exp() + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size; + + int64_t d2 = 0; + for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) { + fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2); + fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size()); + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size()); + fVec max_fvec0 = fVec::loadu(input_max_data + d2); + fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size()); + sum_fvec0 += (data_fvec0 - max_fvec0).exp(); + sum_fvec1 += (data_fvec1 - max_fvec1).exp(); + sum_fvec0.store(tmp_sum_data + d2); + sum_fvec1.store(tmp_sum_data + d2 + fVec::size()); + } + for (; d2 < size; d2++) { + float data_val = input_buffer_ptr[d2]; + float max_val = input_max_data[d2]; + tmp_sum_data[d2] += std::exp(data_val - max_val); + } + } + + // apply log + vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size); + + // compute x - max - sum + for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) { + float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size; + scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size + + dim_idx * inner_size + inner_idx_begin; + + int64_t d3 = 0; + for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) { + fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3); + fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size()); + fVec max_fvec0 = fVec::loadu(input_max_data + d3); + fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size()); + fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3); + fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size()); + fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0; + fVec out_fvec1 
= data_fvec1 - max_fvec1 - sum_fvec1; + Vec out_vec = vec::convert_from_float(out_fvec0, out_fvec1); + out_vec.store(output_ptr + d3); + } + for (; d3 < size; d3++) { + output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]); + } + } + } +} // namespace CPU_CAPABILITY +}} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Loops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Loops.h new file mode 100644 index 0000000000000000000000000000000000000000..5715fd8f047f29430ddf8becccb91666d1d0242a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Loops.h @@ -0,0 +1,395 @@ +#pragma once + +// This file provides two functions to help write elementwise kernels: +// +// cpu_kernel(TensorIterator iter, ) +// cpu_kernel_vec(TensorIterator iter, , ) +// +// Both functions may generate vectorized code. The cpu_kernel implementation +// relies on the compiler's auto-vectorization. The cpu_kernel_vec +// implementation uses x86 SIMD intrinsics when available. These functions +// are only intended to be used in the ATen/native/cpu subdirectory, since files +// in other directories are not compiled with AVX/AVX2 enabled. See README.md +// for more details. 
+// +// For example, to write a multiplication kernel for float: +// +// cpu_kernel(iter, [](float a, float b) { return a * b; }); +// +// Or you may write: +// +// cpu_kernel_vec(iter, +// [](float a, float b) { return a * b; }, +// [](Vectorized a, Vectorized b) { return a * b; }); +// +// See BinaryOpsKernel.cpp for the complete implementation +// +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at::native { inline namespace CPU_CAPABILITY { + +using namespace vec; + +template +typename traits::ArgsTuple +dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, + std::index_sequence) { + return std::make_tuple( + c10::load::type>( + data[INDEX] + i * strides[INDEX])...); +} + +template +typename traits::ArgsTuple +dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) { + using Indices = std::make_index_sequence; + return dereference_impl(data, strides, i, Indices{}); +} + +template +typename traits::ArgsTuple +dereference_vec_impl(char* C10_RESTRICT data[], + const typename traits::result_type& opt_scalar, + size_t S, + int64_t i, + std::index_sequence) { + using Vec = typename traits::result_type; + using scalar_t = typename Vec::value_type; + return std::make_tuple( + S == INDEX + 1 ? 
+ opt_scalar : + Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...); +} + +template +typename traits::ArgsTuple +dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) { + using Indices = std::make_index_sequence; + return dereference_vec_impl(data, opt_scalar, S, i, Indices{}); +} + +template ::result_type>>* = nullptr> +inline void +execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + using result_type = typename traits::result_type; + for (; i < n; i++) { + result_type* out_ptr = (result_type*)(data[0] + i * strides[0]); + *out_ptr = c10::guts::apply(op, dereference( + &data[1], + &strides[1], + i)); + } +} + +template ::result_type>>* = nullptr> +inline void +execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + for (; i < n; i++) { + c10::guts::apply(op, dereference( + &data[0], + &strides[0], + i)); + } +} + +// Basic loop operation (one output, N inputs). May be auto-vectorized +// by the compiler. Supports inputs and outputs of different types. +template +inline void +basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + constexpr int ntensors = traits::arity + 1; + + // Copying strides to temporary array helps auto vectorization in older GCC + // versions. 
+ int64_t strides[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + strides[arg] = strides_[arg]; + } + + execute_op(data, strides, i, n, std::forward(op)); +} + +// the recursive variadic template for iterating over the returned tuple +template +struct TupleOutput { + static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i, + const T &tuple) { + TupleOutput::handle(data, strides, i, tuple); + + auto output = std::get(tuple); + using output_type = decltype(output); + output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]); + *out_ptr = output; + } +}; + +// Base case for the above recursive template +template +struct TupleOutput { + static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i, + const T &tuple) { + auto output = std::get<0>(tuple); + using output_type = decltype(output); + output_type* out_ptr = (output_type *)(data[0] + i * strides[0]); + *out_ptr = output; + } +}; + +template +void handle_tuple_outputs(char* C10_RESTRICT data[], + const int64_t* strides, + int64_t i, + const std::tuple &tuple) { + TupleOutput::handle(data, strides, i, tuple); +} + +// Loop operation for `cpu_kernel_multiple_outputs`. +// 1. Use `c10::guts::apply` to make dynamic method invocation +// for the lambda passed in `cpu_kernel_multiple_outputs`. +// 2. Iterate over the members of the returned tuple, set the corresponding +// output tensor by the tuple member in `handle_tuple_outputs` function. +template +inline void +multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { + using traits = function_traits; + + using result_type = typename traits::result_type; + constexpr int num_outputs = std::tuple_size_v; + constexpr int ntensors = traits::arity + num_outputs; + + // Copying strides to temporary array helps auto vectorization in older GCC + // versions. 
+ int64_t strides[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + strides[arg] = strides_[arg]; + } + + for (; i < n; i++) { + auto output = c10::guts::apply(op, dereference( + &data[num_outputs], + &strides[num_outputs], + i)); + handle_tuple_outputs(data, strides, i, output); + } +} + +// Explicitly vectorized loop implementation. All inputs and outputs must be +// the same type and contiguous with one exception: a single input may be +// a scalar (stride 0). It's position is indicated by the argument `S`. If `S` +// is 0, then there are no scalar inputs. +template +inline void +vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) { + using traits = function_traits; + using scalar_t = typename function_traits::result_type; + using Vec = Vectorized; + constexpr int ntensors = traits::arity + 1; + + char* C10_RESTRICT data[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + data[arg] = data_[arg]; + } + + Vec opt_scalar = Vec(S > 0 ? c10::load((scalar_t*)data[S]) : scalar_t(0)); + int64_t i = 0; + for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) { + auto args1 = dereference_vec(&data[1], opt_scalar, S, i); + auto args2 = dereference_vec(&data[1], opt_scalar, S, i + Vec::size()); + auto out1 = c10::guts::apply(vop, std::move(args1)); + auto out2 = c10::guts::apply(vop, std::move(args2)); + out1.store(data[0] + i * sizeof(scalar_t)); + out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t)); + } + if (i < n) { + int64_t strides[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + strides[arg] = (S > 0 && arg == S) ? 
0 : sizeof(scalar_t); + } + basic_loop(data, strides, i, n, std::forward(op)); + } +} + + +template +inline void unroll_contiguous_scalar_checks( + const int64_t* /*strides*/, + std::index_sequence<>, + cb_t&& cb) { + cb(0); +} + +template +inline void unroll_contiguous_scalar_checks( + const int64_t* strides, + std::index_sequence, + cb_t&& cb) { + if (is_contiguous_scalar(strides)) { + cb(INDEX0 + 1); + } else { + unroll_contiguous_scalar_checks(strides, std::index_sequence{}, std::forward(cb)); + } +} + +template +struct VectorizedLoop2d { + op_t op; + vop_t vop; + + using traits = function_traits; + static constexpr int ntensors = traits::arity + 1; + using data_t = std::array; + + VectorizedLoop2d(op_t op, vop_t vop): + op(std::move(op)), vop(std::move(vop)) {} + + static void advance(data_t &data, const int64_t *outer_strides) { + for (const auto arg : c10::irange(data.size())) { + data[arg] += outer_strides[arg]; + } + } + + void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) { + data_t data; + std::copy_n(base, ntensors, data.data()); + const int64_t *outer_strides = &strides[ntensors]; + + if (is_contiguous(strides)) { + for ([[maybe_unused]] const auto i : c10::irange(size1)) { + vectorized_loop(data.data(), size0, 0, op, vop); + advance(data, outer_strides); + } + } else { + using Indices = std::make_index_sequence; + unroll_contiguous_scalar_checks(strides, Indices{}, [&](size_t idx) { + if (idx) { + for ([[maybe_unused]] const auto i : c10::irange(size1)) { + vectorized_loop(data.data(), size0, idx, op, vop); + advance(data, outer_strides); + } + } else { + for ([[maybe_unused]] const auto i : c10::irange(size1)) { + basic_loop(data.data(), strides, 0, size0, op); + advance(data, outer_strides); + } + } + }); + } + } +}; + +template +VectorizedLoop2d make_vectorized_loop2d( + op_t &&op, vop_t &&vop) { + return VectorizedLoop2d(std::forward(op), std::forward(vop)); +} + +template +void cpu_kernel(TensorIteratorBase& iter, 
func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) { + using traits = function_traits; + // this could be extended to work with void return types + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + // dynamic casting not currently supported on CPU + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + iter.for_each([&](char** data, const int64_t* strides, int64_t n) { + // basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that + // iter.for_each is ever sending to the loop lambda + basic_loop(data, strides, 0, n, op); + }, grain_size); + iter.cast_outputs(); +} + +// This function helps write elementwise kernels that requires multiple outputs. +// It follows the similar structure of cpu_kernel. +// Instead of `basic_loop` function, a new `multiple_outputs_loop` function is +// manipulated to handle multiple return values. +// For now `needs_dynamic_casting` check is not added as the passed lambda (`func_t`) +// of `multiple_outputs_loop` returns `std::tuple` instead of `scalar_t`. +// The `gpu_kernel_multiple_outputs` is also implemented without this check, +// We could extend `needs_dynamic_casting` to support both `std::tuple` and +// `thrust::tuple` in the future. 
+template +void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) { + using traits = function_traits; + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + + iter.for_each([&](char** data, const int64_t* strides, int64_t n) { + multiple_outputs_loop(data, strides, 0, n, op); + }, grain_size); + iter.cast_outputs(); +} + +template +void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) { + using traits = function_traits; + // this could be extended to work with void return types + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + // dynamic casting not currently supported on CPU, but some kernels (like Fill) + // explicitly dynamic_cast, so we give the opt-out of checking. + if constexpr (check_dynamic_cast) { + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + } + + iter.for_each(make_vectorized_loop2d(std::forward(op), std::forward(vop)), grain_size); + iter.cast_outputs(); +} + +template +void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) { + using traits = function_traits; + constexpr bool result_void = std::is_void_v; + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity && + ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1))); + // dynamic casting not currently supported on CPU + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) { + basic_loop(data, strides, 0, n, op); + }, range); + iter.cast_outputs(); +} + +template +void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) { + cpu_serial_kernel(iter, std::forward(op), {0, iter.numel()}); +} + +template +void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) { + using traits = function_traits; + // this 
could be extended to work with void return types + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + // dynamic casting not currently supported on CPU + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + iter.serial_for_each(make_vectorized_loop2d(std::forward(op), std::forward(vop)), range); + iter.cast_outputs(); +} + +template +void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) { + cpu_serial_kernel_vec(iter, std::forward(op), std::forward(vop), {0, iter.numel()}); +} + +}} // namespace at::native:: diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c9a079cc203f26f5da2123300ff251203feb3835 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class Tensor; + +namespace native { + +using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&); + +DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel) +DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel) + +}} // at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..abdb4945c98c91209ec4f6bb9cb62092db74b2e3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using pixel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t); +DECLARE_DISPATCH(pixel_shuffle_fn, pixel_shuffle_kernel) +DECLARE_DISPATCH(pixel_shuffle_fn, 
pixel_unshuffle_kernel) + +} // at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Reduce.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..6c9efbb0f6e7f5d849a3e2ae48ca22ab147d6915 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Reduce.h @@ -0,0 +1,310 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace at::native { inline namespace CPU_CAPABILITY { + +using namespace vec; + +#define VEC_LOOP_HEADER(func_t, data) \ + using scalar_t = typename function_traits::result_type; \ + using Vec = Vectorized; \ + char* out_ptr = data[0]; \ + (void) out_ptr; + +// reduction that is contiguous over the input in dim 0 +template +inline bool is_contiguous_reduction(const int64_t* strides) { + return strides[0] == 0 && + strides[1] == sizeof(typename traits::arg2_t); +} + +// reduction that is contiguous over the input in dim 1 +template +inline bool is_outer_reduction(const int64_t* strides) { + return strides[0] == 0 && + strides[2] == sizeof(typename traits::result_type) && + strides[3] == sizeof(typename traits::arg2_t); +} + +template +inline void vectorized_reduction(char** data, int64_t n, int64_t stride, + func_t op, vec_func_t vop, bool reduce) { + VEC_LOOP_HEADER(func_t, data) + const char* in1_ptr = data[1]; + Vec acc[4]; + for (const auto j : c10::irange(4)) { + acc[j] = Vec::loadu(in1_ptr + j * Vec::size() * sizeof(scalar_t)); + } + for (const auto i : c10::irange(1, n)) { + const char* ptr = in1_ptr + stride * i; + acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size() * sizeof(scalar_t)))); + acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size() * sizeof(scalar_t)))); + acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t)))); + acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t)))); + } + if (reduce) { + 
scalar_t buffer[Vec::size()]; + acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3])); + acc[0].store(buffer); + for (const auto j : c10::irange(1, Vec::size())) { + buffer[0] = op(buffer[0], buffer[j]); + } + auto dst = (scalar_t*)out_ptr; + *dst = op(*dst, buffer[0]); + } else { + for (const auto j : c10::irange(4)) { + auto dst = out_ptr + j * Vec::size() * sizeof(scalar_t); + acc[j] = vop(acc[j], Vec::loadu(dst)); + acc[j].store(dst); + } + } +} + +template +inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) { + for ([[maybe_unused]] const auto j : c10::irange(n)) { + f(); + data[0] += strides[0]; + data[1] += strides[1]; + } +} + +// computes the reduction out = op(out, in) +template +inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) { + VEC_LOOP_HEADER(func_t, data) + constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); + int64_t count = n / (4 * Vec::size()); + if (count > 0) { + vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true); + } + char* ptrs[3] = { data[0], data[0], data[1] }; + int64_t strides[] = { 0, 0, sizeof(scalar_t) }; + basic_loop(ptrs, strides, count * 4 * Vec::size(), n, op); +} + +// computes the reduction out = op(out, in) +template +inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { + VEC_LOOP_HEADER(func_t, data) + + // reduce down each column of 4 * Vec::size() elements. 
+ constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); + int64_t outer_stride[2] = { vector_stride, vector_stride }; + UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] { + vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false); + }); + + // reduce down the remaining columns + int64_t step[] = { sizeof(scalar_t), sizeof(scalar_t) }; + int64_t remaining = size1 % (4 * Vec::size()); + UNARY_OUTER_LOOP(data, step, remaining, [&] { + char* ptrs[3] = { data[0], data[0], data[1] }; + int64_t strides[] = { 0, 0, inner_stride }; + basic_loop(ptrs, strides, 0, size0, op); + }); +} + +template +static void set_result(const int index, const res_t result, const TensorIteratorBase &iter, const int num_outputs) { + // static_assert(std::is_same_v, "data types must match"); + if (index < num_outputs) { + char *out = (char *) iter.data_ptr(index); + *(res_t *) out = result; + } +} + +template +static void set_results(const res_t result, const TensorIteratorBase &iter, const int num_outputs) { + AT_ASSERT(num_outputs == 1); + set_result(0, result, iter, num_outputs); +} + +template +inline std::enable_if_t +for_each_in_tuple(const std::tuple& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) { + return i; +} + +template +inline std::enable_if_t +for_each_in_tuple(const std::tuple& t, const TensorIteratorBase &iter, const int num_outputs) { + if (i < (size_t)num_outputs) { + set_result(i, std::get(t), iter, num_outputs); + return for_each_in_tuple(t, iter, num_outputs); + } + return i; +} + +template +static void set_results(const std::tuple& result, const TensorIteratorBase &iter, const int num_outputs) { + AT_ASSERT(num_outputs >= 1); + std::size_t result_size = for_each_in_tuple(result, iter, num_outputs); + AT_ASSERT((size_t)num_outputs == result_size); +} + +template +struct all_same : std::conjunction< + std::is_same... +> {}; + +// data_t is the input/output data type. 
+// acc_t is a type that contains all the necessary data +// to continue reducing. +// index_t is a one-dimensional index +// +// ops_t is such that &ops_t::reduce, &ops_t::combine, and &ops_t::project exist and satisfy +// the following. +// reduce: (acc_t, data_t, index_t) -> acc_t adds one data point to the accumulated value. +// combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one. +// project: acc_t -> out_t finishes the reduction, getting the required output. +// +// Additionally, acc_t must be default-constructible: +// acc_t {} is an identity for combine, +// and project(acc_t {}) is the value of the operation on zero elements. +// +// The point of `combine` is to support parallelization - +// the idea is to one sequence of `reduce` calls per thread of execution, +// and then to combine them at the end with `combine`. +// +// If there is more than one output element, +// our parallelization strategy is to use one thread for each of them, +// which means that `combine` will never be called. +// +// If, on the other hand, there is only one, then we split the input into +// into several pieces, reduce each separately, and then combine them. 
+ +template +void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) { + using rf_t = decltype(&ops_t::reduce); + using cf_t = decltype(&ops_t::combine); + using pf_t = decltype(&ops_t::project); + using r_traits = binary_function_traits; + using c_traits = binary_function_traits; + using p_traits = unary_function_traits; + using acc_t = typename p_traits::arg1_t; + using data_t = typename r_traits::arg2_t; + static_assert( + all_same< + acc_t, + init_t, + typename r_traits::arg1_t, + typename r_traits::result_type, + typename c_traits::arg1_t, + typename c_traits::arg2_t, + typename c_traits::result_type>::value, + "all accumulate types must match"); + static_assert( + std::is_default_constructible_v, + "the accumulate type must be default-constructible" + ); + const int num_outputs = iter.noutputs(); + iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIteratorBase &sub_iter) { + auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t { + int ntensors = sub_iter.ntensors(); + sub_iter.serial_for_each([&acc, &ops, num_outputs, ntensors, begin](char** data, const int64_t* strides, int64_t size) { + AT_ASSERT(ntensors - num_outputs == 1); + char *in = data[ntensors - 1]; + int64_t stride = strides[ntensors - 1]; + for (const auto i : c10::irange(size)) { + acc = ops.reduce(acc, c10::load(in), begin + i); + in += stride; + } + }, {begin, end}); + return ops.translate_idx(acc, sub_iter.view_offsets()[0]); + }; + acc_t total_acc = init; + auto numel = sub_iter.numel(); + if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 || + at::in_parallel_region()) { + total_acc = reduction_body(total_acc, 0, numel); + } else { + int max_threads = at::get_num_threads(); + AT_ASSERT(max_threads > 0); + static_assert( + !std::is_same_v, + "Concurrently modifying different references into std::vector is UB." 
+ ); + std::vector buffer((unsigned)max_threads, init); + at::parallel_for(0, numel, internal::GRAIN_SIZE, + [&](int64_t begin, int64_t end) { + auto& acc = buffer[at::get_thread_num()]; + acc = reduction_body(acc, begin, end); + } + ); + for (const auto i : c10::irange(max_threads)) { + total_acc = ops.combine(total_acc, buffer[i]); + } + } + set_results(ops.project(total_acc), sub_iter, num_outputs); + }); +} + +template +void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) { + using traits = binary_function_traits; + static_assert( + all_same< + typename traits::result_type, + typename traits::arg1_t, + typename traits::arg2_t>::value, + "all types must match"); + + iter.output_base().fill_(ident); + iter.parallel_reduce([&](char** data, const int64_t* strides, int64_t size0, int64_t size1) { + int64_t outer_strides[] = { strides[2], strides[3] }; + if (is_contiguous_reduction(strides)) { + // input is contiguous in dim 0, output is reduced in dim 0 + UNARY_OUTER_LOOP(data, outer_strides, size1, [&] { + vectorized_inner_reduction(data, size0, op, vop); + }); + } else if (is_outer_reduction(strides)) { + // input and output are contiguous in dim 1 + int64_t inner_stride = strides[1]; // stride of input in dim 0 + vectorized_outer_reduction(data, inner_stride, size0, size1, op, vop); + } else { + UNARY_OUTER_LOOP(data, outer_strides, size1, [&] { + char* ptrs[3] = { data[0], data[0], data[1] }; + int64_t inner_strides[3] = { strides[0], strides[0], strides[1] }; + basic_loop(ptrs, inner_strides, 0, size0, op); + }); + } + }); +} + +// when reduction is on most inner dimension (dim 0 in TensorIterator) +// and input has contiguous most inner dimension, `binary_kernel_reduce_lastdim` +// can be used. 
+inline bool is_reduce_lastdim(TensorIteratorBase& iter) { + return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0) + && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1); +} + +template +void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce_op) { + auto shape = iter.shape(); + int64_t dim_size = shape[0]; + int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / dim_size); + TensorIterator sub_iter(iter); + // create sub iterator to parallel on all non-reduce-dims + sub_iter.narrow(0, 0, 1); + auto loop = [&](char** data, const int64_t* strides, int64_t size) { + char* out = data[0]; + char* in = data[1]; + for (int64_t i = 0; i < size; ++i) { + reduce_op(out, in, dim_size); + out += strides[0]; + in += strides[1]; + } + }; + sub_iter.for_each(loop, grain_size); +} + +}} // namespace at::native:: diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..fd7c4a2750a6c953fb30d29cd13c6924c99bec39 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h @@ -0,0 +1,238 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +using namespace vec; + +#define AT_DISPATCH_REDUCTION_TYPES(op, ...) 
\ + [&] { \ + switch (op) { \ + case ReductionType::SUM: { \ + static constexpr auto reduce = ReductionType::SUM; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::MEAN: { \ + static constexpr auto reduce = ReductionType::MEAN; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::MIN: { \ + static constexpr auto reduce = ReductionType::MIN; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::MAX: { \ + static constexpr auto reduce = ReductionType::MAX; \ + return __VA_ARGS__(); \ + } \ + case ReductionType::PROD: { \ + static constexpr auto reduce = ReductionType::PROD; \ + return __VA_ARGS__(); \ + } \ + } \ + }() + +template +inline vec_scalar_t init_value() { + using acc_t = vec_scalar_t; + acc_t val; + if (reduce == ReductionType::SUM || + reduce == ReductionType::MEAN) { + val = static_cast(0); + } else if (reduce == ReductionType::PROD) { + val = static_cast(1); + } else if (reduce == ReductionType::MAX) { + val = -std::numeric_limits::infinity(); + } else { + TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN); + val = std::numeric_limits::infinity(); + } + return val; +} + +template +inline vec_scalar_t init_value(const std::optional& initial) { + using acc_t = vec_scalar_t; + if (initial.has_value()) { + return initial.value().to(); + } else { + return init_value(); + } +} + +template +inline void init(scalar_t* out, int64_t size, const vec_scalar_t& val) { + using Vec = Vectorized>; + map( + [val](Vec x) { return Vec(val); }, + out, + out, + size); +} + +template +inline void init(scalar_t* out, int64_t size, const std::optional& initial) { + using acc_t = vec_scalar_t; + acc_t val = init_value(initial); + init(out, size, val); +} + +// overload with `include_self`, used by scatter_reduce +template +inline void init(scalar_t* out, int64_t size, bool include_self = false) { + using acc_t = vec_scalar_t; + if (!include_self) { + acc_t val = init_value(); + init(out, size, val); + } +} + +template +inline void _init(scalar_t* self_ptr, 
at::opmath_type* buffer_ptr, int64_t size, bool include_self) { + if (!include_self) { + init, reduce>(buffer_ptr, size, include_self); + } else { + vec::convert(self_ptr, buffer_ptr, size); + } +} + +template +inline std::enable_if_t, scalar_t> +_max(const scalar_t& x, const scalar_t& y) { + return at::_isnan(y) ? y : std::max(x, y); +} + +template +inline Vectorized _max(const Vectorized& x, const Vectorized& y) { + // vec::maximum propagates NaN + return vec::maximum(x, y); +} + +template +inline std::enable_if_t, Vec2> +_max(const vec_t& x, const vec_t& y) { + // vec::maximum propagates NaN + return maximum(x, y); +} + +template +inline std::enable_if_t, scalar_t> +_min(const scalar_t& x, const scalar_t& y) { + return at::_isnan(y) ? y : std::min(x, y); +} + +template +inline Vectorized _min(const Vectorized& x, const Vectorized& y) { + // vec::minimum propagates NaN + return vec::minimum(x, y); +} + +template +inline std::enable_if_t, Vec2> +_min(const vec_t& x, const vec_t& y) { + // vec::minimum propagates NaN + return minimum(x, y); +} + +template , int> = 0> +inline void map_acc( + const Op& vec_fun, + accumut* output_data, + const accumut* input_data, + const scalar_t* input_data2, + int64_t size) { + using Vec = vec::Vectorized; + using aVec = vec::Vectorized; + int64_t d = 0; + constexpr int64_t kVecSize = Vec::size(); + constexpr int64_t kaVecSize = aVec::size(); + for (d = 0; d < size - (size % kVecSize); d += kVecSize) { + Vec data2_vec = Vec::loadu(input_data2 + d); + auto [data2_avec0, data2_avec1] = convert_to_float(data2_vec); + aVec input_vec0 = aVec::loadu(input_data + d); + aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize); + vec_fun(input_vec0, data2_avec0).store(output_data + d); + vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize); + } + if (size - d > 0) { + int64_t tail_size = size - d; + Vec data2_vec = Vec::loadu(input_data2 + d, tail_size); + auto [data2_avec0, data2_avec1] = convert_to_float(data2_vec); + if 
(tail_size > kaVecSize) { + aVec input_vec0 = aVec::loadu(input_data + d); + aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize); + vec_fun(input_vec0, data2_avec0).store(output_data + d); + vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize); + } else { + aVec input_vec0 = aVec::loadu(input_data + d, tail_size); + vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size); + } + } +} + +// for Max and Min, propagate NaN: +template +inline T update(const T& x, const T& y) { + if (reduce == ReductionType::SUM || + reduce == ReductionType::MEAN) { + return x + y; + } else if (reduce == ReductionType::PROD) { + return x * y; + } else if (reduce == ReductionType::MAX) { + return _max(x, y); + } else { + TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN); + return _min(x, y); + } +} + +template +inline void update(scalar_t* out, const scalar_t* data, int64_t K) { + using Vec = vec::Vectorized>; + map2( + [](Vec x, Vec y) { return update(x, y); }, + out, + out, + data, + K); +} + +template , int> = 0> +inline void update(at::opmath_type* out, const scalar_t* data, int64_t K) { + using opmath_t = at::opmath_type; + using Vec = vec::Vectorized; + map_acc( + [](Vec x, Vec y) { return update(x, y); }, + out, + out, + data, + K); +} + +template +inline void write(scalar_t* out, int64_t count, int64_t K) { + using Vec = vec::Vectorized>; + if (reduce == ReductionType::MEAN) { + if (count > 0) { + vec::map( + [count](Vec x) { return x / Vec(count); }, + out, + out, + K); + } + } +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..70c9fb9eb20316c32087cb58f5d7fcc0c0f203d2 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { +#if !defined(C10_MOBILE) +using fp16_gemv_fn = void(*)(int, int, float, const Half*, int, const Half*, int, float, Half*, int); +DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub) + +using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int); +DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub) + +using fp16_dot_fn = float(*)(const int64_t, const Half*, const int64_t, const Half*, const int64_t); +DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_stub) + +using bf16_dot_fn = float(*)(const int64_t, const BFloat16*, const int64_t, const BFloat16*, const int64_t); +DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_stub) + +inline namespace CPU_CAPABILITY { +float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len); +float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len); +} // inline namespace CPU_CAPABILITY +#endif // !defined(C10_MOBILE) +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b5081e1822455fcc0927e258222a9a56ee94051c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace at::native { + +using sampled_addmm_sparse_csr_fn = void(*)(const Tensor&, const Tensor&, const Scalar&, const Scalar&, const Tensor&); + +DECLARE_DISPATCH(sampled_addmm_sparse_csr_fn, sampled_addmm_sparse_csr_stub) + +} // at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..e95de13ac0bb9f399aee0ac664e1200dacc845a8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h @@ -0,0 +1,146 @@ +// Copyright 2004-present Facebook. All Rights Reserved. +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native::detail { + +struct InputMeta { + void* data_ptr; + int64_t inner_size; + + InputMeta(const Tensor& t, int64_t dim, int64_t inner) + : data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {} +}; + +// This kernel is used by two TensorList types: +// 1. stack_serial_kernel uses at::ArrayRef +// 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with +// ProcessedNodeInputWrapper. +// When making changes, make sure that they are compatible with both types! +template +void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + dim >= 0 && dim <= result.dim(), + "dim out of range in stack_serial_kernel_impl"); + int64_t outer = + result.numel() / (result.sizes()[dim] * result.strides()[dim]); + scalar_t* result_data = result.data_ptr(); + int64_t ninputs = tensors.size(); + std::vector inputs; + inputs.reserve(ninputs); + for (const auto& tensor : tensors) { + inputs.emplace_back(tensor, dim, tensor.strides()[dim]); + } + + using Vec = vec::Vectorized; + scalar_t* result_ptr = result_data; + for (const auto i : c10::irange(outer)) { + for (const auto j : c10::irange(ninputs)) { + int64_t local_inner = inputs[j].inner_size; + scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner; + + if (local_inner < Vec::size()) { + for (const auto k : c10::irange(local_inner)) { + result_ptr[k] = input_ptr[k]; + } + } else { + vec::map( + [](Vec x) { return x; }, result_ptr, input_ptr, 
local_inner); + } + result_ptr += local_inner; + } + } +} + +// Checks to see whether native stack can be invoked under these conditions: +// - result and input tensors are contiguous +// - only one thread is used +// - no type promotion has to occur +// - tensors dtype is Double or Float +template +bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) { + TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors"); + const Tensor& first_tensor = tensors[0]; + // stack dimension should be in range [0,firstTensor.dim()) + // dim == firstTensor.dim() is a valid input, but it is handled by default code path + // that uses unsqueeze + if (dim >= first_tensor.dim()) return false; + // Native stack doesn't apply any tensor is skipped. + if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false; + // there should be no type promotion + if (result.dtype() != first_tensor.dtype()) return false; + + auto first_tensor_mem_format = first_tensor.suggest_memory_format(); + ScalarType dtype = first_tensor.scalar_type(); + + if (!result.is_contiguous(first_tensor_mem_format)) { + return false; + } + + // fast path only works for Double and Float + if (dtype != ScalarType::Double && dtype != ScalarType::Float) { + return false; + } + + // check remainder of inputs +#ifndef STRIP_ERROR_MESSAGES + auto const &first_tensor_shape = first_tensor.sizes(); +#endif + for (const auto i : c10::irange(1, tensors.size())) { + auto const &tensor = tensors[i]; + TORCH_CHECK(tensors[i].sizes() == first_tensor.sizes(), + "stack expects each tensor to be equal size, but got ", first_tensor_shape, + " at entry 0 and ", tensor.sizes(), " at entry ", i); + + // every tensor must be contiguous + // tensor sizes and strides must be the same + // there should be no type promotion + if (!tensor.is_contiguous(first_tensor_mem_format) || + tensor.strides() != first_tensor.strides() || + tensor.dtype() != dtype) { + return false; + } + } + + // fast 
native stack should only be used when it is not worth using multiple threads + // or there is only one thread. Note that we aren't checking result.numel() here because + // it may not have been resized and we want to defer that cost till later. + int64_t numel_in_stack = first_tensor.numel() * tensors.size(); + return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1; +} + +template +struct CanUseNativeSerialStack; + +template +struct CanUseNativeSerialStack { + static bool call(Tensor& result, TensorListType tensors, int64_t dim) { + // Inputs cannot alias the output tensor + for (const auto i : c10::irange(tensors.size())) { + auto lap = at::get_overlap_status(result, tensors[i]); + TORCH_CHECK(lap != at::MemOverlapStatus::Partial && + lap != at::MemOverlapStatus::Full, 0, + "unsupported operation: the input tensors cannot refer to any of the " + "output memory locations. Found overlap in input tensor ", i); + } + + return can_use_native_serial_stack_impl(result, tensors, dim); + } +}; + +template +struct CanUseNativeSerialStack { + static bool call(Tensor& result, TensorListType tensors, int64_t dim) { + return can_use_native_serial_stack_impl(result, tensors, dim); + } +}; + +} // namespace at::native::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8bc86a036e2ef501824f51bb704bce06856c8e1a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using forward_fn = void (*)(const Tensor&, const Tensor&); +using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&); + +DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel) +DECLARE_DISPATCH(forward_fn, 
log_softmax_lastdim_kernel) +DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel) +DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel) + +using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t); +using backward_fn_with_dim = + void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t); + +DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel) +DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel) +DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel) +DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel) +} +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..336d30a941d2703cf0246b86621b73e35d17cf86 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); + +DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub) +DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub) +DECLARE_DISPATCH(spmm_reduce_backward_input_fn, 
spmm_reduce_backward_input_stub) +DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub) +DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub) +DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub) + +} // at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/StackKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/StackKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3ff30c4bc6310f2899856e6d7032ded1662f1271 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/StackKernel.h @@ -0,0 +1,12 @@ +// Copyright 2004-present Facebook. All Rights Reserved. +#pragma once + +#include +#include + +namespace at::native { + +using stack_serial_fn = void(*)(Tensor &, TensorList, int64_t); +DECLARE_DISPATCH(stack_serial_fn, stack_serial_stub) + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h new file mode 100644 index 0000000000000000000000000000000000000000..5b545509b1d99e6bbe7ca6ed40d312b9f4c3a5d5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h @@ -0,0 +1,1376 @@ +/* +The Python Imaging Library (PIL) is + + Copyright © 1997-2011 by Secret Labs AB + Copyright © 1995-2011 by Fredrik Lundh + +Pillow is the friendly PIL fork. It is + + Copyright © 2010-2022 by Alex Clark and contributors + +Like PIL, Pillow is licensed under the open source HPND License +*/ + +// This code is heavily inspired from PILLOW-SIMD's implementation: +// https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c + +#pragma once +#ifdef CPU_CAPABILITY_AVX2 +// TODO: This file only supports AVX2. 
We could split the AVX kernels into +// smaller logical blocks in order to port them into the Vec.h logic. This would +// allow to support other vectorization architectures and perhaps also support +// the non-vectorized fallback (we'd need to make sure it's not slower than the +// current fallback). + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + + +namespace { + +static inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) { + int32_t v; + if (i32_aligned) { + v = *(const int32_t*)ptr; + } else { + std::memcpy(&v, ptr, 4); + } + return _mm_cvtsi32_si128(v); +} + +static inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) { + return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned)); +} + +static inline void _write_endline_rgb_as_uint32( + uint8_t* C10_RESTRICT output, + uint32_t data +) { + // data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...) + // Here we explicitly set X as R1 + uint8_t* data_ptr = reinterpret_cast(&data); + data_ptr[3] = output[3]; + std::memcpy(output, data_ptr, 4); +} + +at::Tensor unpack_rgb(const at::Tensor& packed_tensor) { + // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into + // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded + // into as 32 bits. This generalizes to num_channels <= 4 and also works for + // non-channels_last tensors. 
+ + const uint8_t* packed = (const uint8_t*)packed_tensor.const_data_ptr(); + auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2); + auto num_channels = packed_tensor.size(0); + + constexpr int rgba_size = 4; + auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte)); + uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr(); + + auto stride_i = packed_tensor.stride(2); + auto stride_j = packed_tensor.stride(0); + + for (const auto i : c10::irange(num_pixels)) { + for (const auto j : c10::irange(rgba_size)) { + unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0; + } + } + return unpacked_tensor; +} + +void pack_rgb( + const at::Tensor& unpacked_tensor, // IN + const at::Tensor& packed_tensor // OUT +) { + // Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout. + + uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr(); + uint8_t* packed = (uint8_t*)packed_tensor.data_ptr(); + auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2); + auto num_channels = packed_tensor.size(0); + + auto unpacked_increment = unpacked_tensor.size(0); + auto packed_increment = packed_tensor.stride(2); + auto packed_stride = packed_tensor.stride(0); + + TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4); + + for ([[maybe_unused]] const auto i : c10::irange(num_pixels)) { + for (const auto j : c10::irange(num_channels)) { + packed[j * packed_stride] = unpacked[j]; + } + unpacked += unpacked_increment; + packed += packed_increment; + } +} + +void ImagingResampleHorizontalConvolution8u4x( + uint8_t* C10_RESTRICT lineOut0, + uint8_t* C10_RESTRICT lineOut1, + uint8_t* C10_RESTRICT lineOut2, + uint8_t* C10_RESTRICT lineOut3, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn0, + const uint8_t* C10_RESTRICT lineIn1, + const uint8_t* C10_RESTRICT lineIn2, + const uint8_t* C10_RESTRICT lineIn3, + int64_t 
in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line); + +void ImagingResampleHorizontalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line); + +void ImagingResampleVerticalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + const uint8_t* C10_RESTRICT lineIn, + int64_t xsize, + int64_t ids_min, + int64_t ids_size, + const int16_t* k, + unsigned int coefs_precision, + int64_t num_channels); + +template +void ImagingResampleHorizontal( + const at::Tensor & unpacked_output, + const at::Tensor & unpacked_input, + int ksize, + const std::vector& horiz_indices_weights, + unsigned int horiz_weights_precision) { + + // Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs. + + // Input data is stored as + // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...] + // Weights are float values computed for each output pixel and rescaled to uint16: + // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]] + // We want to compute the output as following: + // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...] + // where + // oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... + r[yoffset + xmin[i] + K-1] * w[i, K-1] + // oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1] + // oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... 
+ b[yoffset + xmin[i] + K-1] * w[i, K-1] + // + + // TODO: we may want to merge that into the fallback code (currently called + // basic_loop_aa_horizontal) + // Although this may not be needed if / when we port all this code to use + // Vec.h since this would potentially give us another fall-back implem + + const int16_t* kk = (int16_t*)(horiz_indices_weights[3].const_data_ptr()); + + auto xout = unpacked_output.size(2); + auto yout = unpacked_output.size(1); + auto xin = unpacked_input.size(2); + TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0)); + + const int64_t* idx_ptr_xmin = horiz_indices_weights[0].const_data_ptr(); + const int64_t* idx_ptr_size = horiz_indices_weights[1].const_data_ptr(); + + uint8_t* unpacked_output_p = unpacked_output.data_ptr(); + const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr(); + + int64_t yy = 0; + auto xout_stride = xout * num_channels; + auto xin_stride = xin * num_channels; + for (; yy < yout - 3; yy += 4) { + ImagingResampleHorizontalConvolution8u4x( + unpacked_output_p + yy * xout_stride, + unpacked_output_p + (yy + 1) * xout_stride, + unpacked_output_p + (yy + 2) * xout_stride, + unpacked_output_p + (yy + 3) * xout_stride, + xout, + unpacked_input_p + yy * xin_stride, + unpacked_input_p + (yy + 1) * xin_stride, + unpacked_input_p + (yy + 2) * xin_stride, + unpacked_input_p + (yy + 3) * xin_stride, + xin, + idx_ptr_xmin, + idx_ptr_size, + kk, + ksize, + horiz_weights_precision, + num_channels, + yy + 3 == yout - 1); + } + for (; yy < yout; yy++) { + ImagingResampleHorizontalConvolution8u( + unpacked_output_p + yy * xout_stride, + xout, + unpacked_input_p + yy * xin_stride, + xin, + idx_ptr_xmin, + idx_ptr_size, + kk, + ksize, + horiz_weights_precision, + num_channels, + yy == yout - 1); + } +} + +void ImagingResampleVertical( + const at::Tensor & unpacked_output, + const at::Tensor & unpacked_input, + int ksize, + const std::vector& vert_indices_weights, + unsigned int vert_weights_precision) { + + 
// Interpolation vertical pass: we compute y-axis interpolation outputs. + // Input data is stored as + // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...] + // Weights are float values computed for each output pixel and rescaled to uint16: + // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]] + // We want to compute the output as following: + // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...] + // where + // oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + // oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + // oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... + b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + + // TODO: we may want to merge that into the fallback code (currently called + // basic_loop_aa_vertical) + // Although this may not be needed if / when we port all this code to use + // Vec.h since this would potentially give us another fall-back implem + const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr()); + + const int64_t* idx_ptr_xmin = vert_indices_weights[0].const_data_ptr(); + const int64_t* idx_ptr_size = vert_indices_weights[1].const_data_ptr(); + + uint8_t* unpacked_output_p = unpacked_output.data_ptr(); + const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr(); + + auto xout = unpacked_output.size(2); + auto yout = unpacked_output.size(1); + const auto num_channels = unpacked_input.size(0); + TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0)); + + auto xout_stride = xout * num_channels; + for (const auto yy : c10::irange(yout)) { + const auto* k = &kk[yy * ksize]; + auto ids_min = idx_ptr_xmin[yy]; + auto ids_size = idx_ptr_size[yy]; + ImagingResampleVerticalConvolution8u( + unpacked_output_p + yy * xout_stride, + unpacked_input_p, + xout, + ids_min, + ids_size, + k, + vert_weights_precision, + num_channels); + } +} + +// 
This is the only public entry point in this file. It supports bilinear or bicubic +// mode for uint8 dtype when C <= 4, with or without antialias. The +// implem is based on PIL-SIMD. +// Its equivalent implementation (fallback) for when AVX isn't supported or when +// C > 4 is separable_upsample_generic_Nd_kernel_impl() There are a bunch of +// future improvement that can be done: look for the TODOs in this file. +// For details on how the weights are computed and how the multiplications are +// run on int (instead of float weights), see +// [ Weights computation for uint8_t and multiplication trick ] +// For details on how the AVX kernels are implemented, see +// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5 +// See also [ Support for antialias=False as a subcase of antialias=True ] to +// learn more about how the antialias=False case is computed. The same holds +// here: all these kernels are general enough to handle an arbitrary number of +// weights, but when aa=False they could be optimized further. +template +void upsample_avx_bilinear_bicubic_uint8( + const at::Tensor& input_, + const at::Tensor& output, + bool align_corners, + const scale_type& scales, + bool antialias) { + auto batch_size = input_.size(0); + auto num_channels = input_.size(1); + auto xin = input_.size(3); + auto yin = input_.size(2); + auto xout = output.size(3); + auto yout = output.size(2); + + if (xin == xout && yin == yout) { + output.copy_(input_); + return; + } + + at::Tensor input = input_; + if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) { + // If input is not contiguous with memory format channels first or channels last, + // we explicitly convert the input to contiguous channels last memory format. 
+ // This simplifies the rest of the code and let us assume that the format is only contiguous channels first or channels last, + // Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may + // have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input + // directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex. + input = input.contiguous(at::MemoryFormat::ChannelsLast); + } + + auto need_horizontal = xout != xin; + auto need_vertical = yout != yin; + + int ksize_horiz, ksize_vert; + std::vector horiz_indices_weights, vert_indices_weights; + unsigned int horiz_weights_precision, vert_weights_precision; + + bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast); + bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast); + + if (need_horizontal) { + int interp_dim = 3; + auto stride = (skip_unpacking) ? num_channels : 4; + std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) = + F::compute_index_ranges_int16_weights( + /*input_size=*/xin, + /*output_size=*/xout, + /*stride=*/stride, + /*ndims=*/4, + /*reshape_dim=*/interp_dim, + /*align_corners=*/align_corners, + /*opt_scale=*/scales[interp_dim - 2], + /*antialias=*/antialias, + /*align_i32=*/true); + } + + if (need_vertical) { + int interp_dim = 2; + auto stride = (skip_unpacking) ? 
num_channels * xout : 4 * xout; + std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) = + F::compute_index_ranges_int16_weights( + /*input_size=*/yin, + /*output_size=*/yout, + /*stride=*/stride, + /*ndims=*/4, + /*reshape_dim=*/interp_dim, + /*align_corners=*/align_corners, + /*opt_scale=*/scales[interp_dim - 2], + /*antialias=*/antialias, + /*align_i32=*/true); + } + + at::Tensor buffer_horiz, buffer_vert; + // Minor optimization: we can avoid allocating an extra buffer if we're performing + // horizontal-only or vertical-only interpolation, and if the tensor doesn't + // need repacking + if (need_horizontal && (need_vertical || !skip_packing)) { + auto c = (skip_unpacking) ? num_channels : 4; + buffer_horiz = at::empty({c, yin, xout}, input.options()); + } + if (need_vertical && !skip_packing) { + auto c = (skip_unpacking) ? num_channels : 4; + buffer_vert = at::empty({c, yout, xout}, input.options()); + } + + for (const auto i : c10::irange(batch_size)) { + + at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]); + at::Tensor unpacked_output; + + if (need_horizontal) { + at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i]; + + if (skip_unpacking && num_channels == 3) { + ImagingResampleHorizontal<3>( + unpacked_output_temp, + unpacked_input, + ksize_horiz, + horiz_indices_weights, + horiz_weights_precision); + } else { + ImagingResampleHorizontal<4>( + unpacked_output_temp, + unpacked_input, + ksize_horiz, + horiz_indices_weights, + horiz_weights_precision); + } + unpacked_output = unpacked_input = unpacked_output_temp; + } + if (need_vertical) { + unpacked_output = (skip_packing) ? 
output[i] : buffer_vert; + + ImagingResampleVertical( + unpacked_output, + unpacked_input, + ksize_vert, + vert_indices_weights, + vert_weights_precision + ); + } + + TORCH_INTERNAL_ASSERT(unpacked_output.defined()); + + if (!skip_packing) { + pack_rgb(unpacked_output, output[i]); + } + } +} + +void ImagingResampleHorizontalConvolution8u4x( + uint8_t* C10_RESTRICT lineOut0, + uint8_t* C10_RESTRICT lineOut1, + uint8_t* C10_RESTRICT lineOut2, + uint8_t* C10_RESTRICT lineOut3, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn0, + const uint8_t* C10_RESTRICT lineIn1, + const uint8_t* C10_RESTRICT lineIn2, + const uint8_t* C10_RESTRICT lineIn3, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line) { + + // Interpolation horizontal pass processing together 4 vertical lines. + // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA + // we can encode 4 values as a single uint32 value. + // - We split the size of weight vector for a given output index as a sum: + // ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1. + // - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values + // in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1"). + + // Define shuffling masks (low/high) for num_channels 4 and 3 + // Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA: + // [r1 g1 b1 a1 r2 g2 b2 a2 ... | R1 G1 B1 A1 R2 G2 B2 A2 ... ] -> + // [r1 0 r2 0 g1 0 g2 0 b1 0 b2 0 a1 0 a2 0 | R1 0 R2 0 G1 0 G2 0 B1 0 B2 0 A1 0 A2 0] + // Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:: + // [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... 
R3 G3 B3 A3 R4 G4 B4 A4 ] -> + // [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0] + + const auto mask_low_c4 = _mm256_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + const auto mask_high_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8); + const auto mask_low_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0, + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_high_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6, + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6); + + const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4; + const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4; + + const auto stride = num_channels * sizeof(uint8_t); + + TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4); + + // out_xsize = output width, out_x = output x index + // ids_min is the input offset index corresponding to out_x + // ids_size is the interpolation size for out_x + + // Let's precompute ids_size limits for block 4 and block 2. + // + // In block 4 (4 means we process 4 weight values together), we read input data + // with _mm_loadu_si128, i.e. 16 bytes, per one line: + // lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 16.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft + // RGBA: b4_delta = b4_delta_soft = 3 + // RGB : b4_delta = 5 + // RGB : b4_delta_soft = 4 + const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 
5 : 4); + + // In block 2 (2 means we process 2 weights values together), we read input data + // with _mm_loadl_epi64, i.e. 8 bytes, per one line: + // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 8.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft + // RGBA: b2_delta = b2_delta_soft = 1 + // RGB : b2_delta = 2 + // RGB : b2_delta_soft = 1 + const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1); + + const auto max_out_x_strided = out_xsize * stride; + const auto max_in_x_strided = in_xsize * stride; + + const auto zero = _mm256_setzero_si256(); + const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1)); + + for (const auto out_x : c10::irange(out_xsize)) { + const auto ids_min = idx_ptr_xmin[out_x]; + const auto ids_size = idx_ptr_size[out_x]; + const auto * k = &kk[out_x * kmax]; + int64_t i = 0; + + auto sss0 = initial; + auto sss1 = initial; + + const auto * lineIn0_min = lineIn0 + ids_min; + const auto * lineIn1_min = lineIn1 + ids_min; + const auto * lineIn2_min = lineIn2 + ids_min; + const auto * lineIn3_min = lineIn3 + ids_min; + + // block 4 + for (; i < ids_size - b4_delta; i += 4) { + // Load 4 values from weight vector + // mmk0 = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + // mmk1 = [wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ...] 
+ const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]); + const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]); + + // RGBA: Load 8 pixels (4 per line) from input lines 0 and 1: + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // R0 G0 B0 A0 R1 G1 B1 A1 R2 G2 B2 A2 R3 G3 B3 A3 + // ] + // RGB: Load 10 pixels (5 per line) + // source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // R0 G0 B0 R1 G1 B1 R2 G2 B2 R3 G3 B3 R4 G4 B4 R5 + // ] + auto source = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))), + _mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1); + + // Apply mask_low: + // RGBA: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0] + // RGB: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0] + auto pix1 = _mm256_shuffle_epi8(source, mask_low); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0)); + + // Apply mask_high: + // RGBA: + // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 A2 0 A3 0] + // RGB: + // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 0 0 0 0] + auto pix2 = _mm256_shuffle_epi8(source, mask_high); + // Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1)); + + // Same as above to next lines 2 and 3: + auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))), + _mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1); + auto pix3 = _mm256_shuffle_epi8(source2, mask_low); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0)); + auto pix4 = _mm256_shuffle_epi8(source2, mask_high); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1)); + } + + // 
block 2 + for (; i < ids_size - b2_delta; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]); + + // Load 4 pixels (2 per line) from input lines 0 and 1: + // RGBA: source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // R0 G0 B0 A0 R1 G1 B1 A1 0 0 0 0 0 0 0 0 + // ] + // RGB: source1 = [ + // r0 g0 b0 r1 g1 b1 r2 0 0 0 0 0 0 0 0 + // R0 G0 B0 R1 G1 B1 R2 0 0 0 0 0 0 0 0 + // ] + auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))), + _mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1); + // Apply mask_low: + // RGBA: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0] + // RGB: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0] + auto pix1 = _mm256_shuffle_epi8(source1, mask_low); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + + // Same as above for lines 2 and 3: + auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))), + _mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1); + auto pix2 = _mm256_shuffle_epi8(source2, mask_low); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + } + + // block 1 + const auto i32_aligned = num_channels == 4; + for (; i < ids_size - 1; i++) { + // Load 1 value from weight vector + // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...] 
+ const auto mmk = _mm256_set1_epi32(k[i]); + + // Load 2 pixels (one per line) from input lines 0 and 1: + // RGBA: pix1 = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0 + // R0 0 0 0 G0 0 0 0 B0 0 0 0 A0 0 0 0 + // ] + // RGB: pix1 = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0 + // R0 0 0 0 G0 0 0 0 B0 0 0 0 R1 0 0 0 + // ] + auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)), + mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1); + // Compute output value as C += w0 * C0 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + + // Same as above for lines 2 and 3 + auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)), + mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + } + + if (i == ids_size - 1) { + // last element + auto mmk = _mm256_set1_epi32(k[i]); + // For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes + // lines 0, 1 and 2 wont go out of allocated memory bounds + auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)), + mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1); + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk)); + + auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned); + __m128i p1; + if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) { + uint8_t input[4]; + std::memcpy(input, lineIn3_min + stride * i, 3); + p1 = mm_cvtepu8_epi32(input, true); + } else { + p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned); + } + auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + } + + // Convert fixed point values back to integers (truncating) + sss0 
= _mm256_srai_epi32(sss0, coefs_precision); + sss1 = _mm256_srai_epi32(sss1, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0) + sss0 = _mm256_packs_epi32(sss0, zero); + sss1 = _mm256_packs_epi32(sss1, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d 0 0 0 0) + sss0 = _mm256_packus_epi16(sss0, zero); + sss1 = _mm256_packus_epi16(sss1, zero); + + // Write the output into single uint32 + // (a b c d) -> x_uint32 + auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0)); + auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1)); + auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1)); + auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1)); + + const auto out_x_strided = stride * out_x; + + if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) { + // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write + // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1). + // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct + // value which was previously computed by another line. In other words, it means that we can not overwrite + // it by simply writing 4 bytes from the register to the output. We'll do the following: + // v----------| + // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...] + // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1) + // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1) + // Output = [... R G B | R1 G1 B1 R2 ...] 
+ + _write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0); + _write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1); + _write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2); + + if (C10_UNLIKELY(is_last_line)) { + // When we handle the last line, we can not access the next 4 bytes + // as they are out of memory bounds. + std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels); + } else { + _write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3); + } + } else if (num_channels == 3) { + // Memcpy 4-bytes is faster than 3-bytes and here + // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value + // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...) + std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4); + std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4); + std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4); + std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4); + } else { + // num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned + *(uint32_t *)(lineOut0 + out_x_strided) = o0; + *(uint32_t *)(lineOut1 + out_x_strided) = o1; + *(uint32_t *)(lineOut2 + out_x_strided) = o2; + *(uint32_t *)(lineOut3 + out_x_strided) = o3; + } + } +} + +void ImagingResampleHorizontalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line) { + + // Interpolation horizontal pass processing only one vertical line. + // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA + // we can encode 4 values as a single uint32 value. 
+ // - We split the size of weight vector for a given output index as a sum: + // ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1 + // - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in + // in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1"). + + // Define various shuffling masks + const auto kmask_low = _mm256_set_epi8( + 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, + 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0); + const auto kmask_high = _mm256_set_epi8( + 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, + 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4); + const auto kmask_hl = _mm256_set_epi8( + 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, + 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0); + + const auto mask_low_c4 = _mm256_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + const auto mask_high_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8); + const auto mask_low_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0, + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_high_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6, + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6); + const auto mask_hl_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6, + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_hl_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + + const auto mask_low128_c3 = _mm_set_epi8( + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto 
mask_low128_c4 = _mm_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + + const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4; + const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4; + const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4; + const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4; + + // out_xsize = output width, out_x = output x index + // ids_min is the input offset index corresponding to out_x + // ids_size is the interpolation size for out_x + + const auto stride = num_channels * sizeof(uint8_t); + const auto zero = _mm_setzero_si128(); + + TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4); + + // Let's precompute ids_size limits for block 8, block 4 and block 2 + // + // In block 8 (8 means we process 8 weight values together), we read at + // most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB) + // lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min) + // --> i <= ids_size - 32.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft + // RGBA: b8_delta = b8_delta_soft = 7 + // RGB : b8_delta = 10 + // RGB : b8_delta_soft = 9 + const auto b8_delta = (stride == 4) ? 7 : ((is_last_line) ? 10 : 9); + + // In block 4 (4 means we process 4 weight values together), we read + // 16 bytes of input data. 
+ // lineIn + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 16.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft + // RGBA: b4_delta = b4_delta_soft = 3 + // RGB : b4_delta = 5 + // RGB : b4_delta_soft = 4 + const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4); + + // In block 2 (2 means we process 2 weight values together), we read + // 8 bytes of input data. + // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 8.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft + // RGBA: b2_delta = b2_delta_soft = 1 + // RGB : b2_delta = 2 + // RGB : b2_delta_soft = 1 + const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 
2 : 1); + + const auto max_out_x_strided = out_xsize * stride; + const auto max_in_x_strided = in_xsize * stride; + + for (const auto out_x : c10::irange(out_xsize)) { + __m128i sss; + const auto ids_min = idx_ptr_xmin[out_x]; + const auto ids_size = idx_ptr_size[out_x]; + const auto * k = &kk[out_x * kmax]; + int64_t i = 0; + + const auto * lineIn_min = lineIn + ids_min; + + if (ids_size < 8) { + sss = _mm_set1_epi32(1 << (coefs_precision - 1)); + } else { + // Lower part will be added to higher, use only half of the error + auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2)); + + // block 8 + for (; i < ids_size - b8_delta; i += 8) { + // Load 8 values from weight vector + auto tmp = _mm_loadu_si128((__m128i*)&k[i]); + // ksource = [ + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7 + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7 + // ] + auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + + // RGBA: Load 8 pixels from input: + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7 + // ] + // RGB: Load 10 pixels from input (however we can process only 8 pixels): + // source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9 + // ] + auto source = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i *) (lineIn_min + stride * i))), + _mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1); + + // Extract lower part of each lane, cast to epi16 and reoder RGBARGBA -> RRGGBBAA + // RGBA: pix1 = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 + // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0 + // ] + // RGB: pix1 = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 + // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 0 0 0 0 + // ] + auto pix1 = _mm256_shuffle_epi8(source, mask_low); + // mmk1 = [ + // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 
wh_1 ... ... + // wl_4 wh_4 wl_5 wh_5 wl_4 wh_4 wl_5 wh_5 ... ... + // ] + auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low); + // Compute output value as + // C += w0 * C0 + w1 * C1 + // C += w4 * C4 + w5 * C5 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1)); + + // Same as above for higher part of each lane + auto pix2 = _mm256_shuffle_epi8(source, mask_high); + auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high); + // Compute output value as + // C += w2 * C2 + w3 * C3 + // C += w6 * C6 + w7 * C7 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2)); + } + + // block 4 + for (; i < ids_size - b4_delta; i += 4) { + // Load 4 values from weight vector + auto tmp = _mm_loadl_epi64((__m128i *) &k[i]); + // ksource = [ + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0 + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0 + // ] + auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + + // Load pixels from input line + tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i)); + // RGBA: source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // ] + // RGB: source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // ] + auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + + // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA + // RGBA: pix = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 + // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 + // ] + // RGB: pix = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 + // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 + // ] + auto pix = _mm256_shuffle_epi8(source, mask_hl); + // mmk = [ + // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ... + // wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ... ... 
+ // ] + auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl); + // Compute output value as + // C += w0 * C0 + w1 * C1 + // C += w2 * C2 + w3 * C3 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk)); + } + + // Sum results between the lanes + sss = _mm_add_epi32( + _mm256_extracti128_si256(sss256, 0), + _mm256_extracti128_si256(sss256, 1)); + } + + // block 2 + for (; i < ids_size - b2_delta; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + // Load pixels from input line + // RGBA: source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // ] + // RGB: source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0 + // ] + auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i)); + // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA + auto pix = _mm_shuffle_epi8(source, mask_low128); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // block 1 + const auto i32_aligned = num_channels == 4; + for (; i < ids_size - 1; i++) { + // Load 1 value from weight vector + // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...] 
+ auto mmk = _mm_set1_epi32(k[i]); + // Load one pixel from input line + // RGBA: pix = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0 + // ] + // RGB: pix = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0 + // ] + auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned); + // Compute output value as C += w0 * C0 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + if (i == ids_size - 1) { + // last element + auto mmk = _mm_set1_epi32(k[i]); + __m128i pix; + auto p = lineIn_min + stride * i; + if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) { + uint8_t input[4]; + std::memcpy(input, p, 3); + pix = mm_cvtepu8_epi32(input, true); + } else { + pix = mm_cvtepu8_epi32(p, i32_aligned); + } + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // Convert fixed point values back to integers (truncating) + sss = _mm_srai_epi32(sss, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0) + sss = _mm_packs_epi32(sss, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d 0 0 0 0) + sss = _mm_packus_epi16(sss, zero); + // Write the output into single uint32 + // (a b c d) -> x_uint32 + auto o = _mm_cvtsi128_si32(sss); + const auto out_x_strided = stride * out_x; + if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) { + if (C10_UNLIKELY(is_last_line)) { + // When we handle the last line, we can not access the next 4 bytes + // as they are out of memory bounds. + std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3); + } else { + // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write + // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1). 
+ // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct + // value which was previously computed by another line. In other words, it means that we can not overwrite + // it by simply writing 4 bytes from the register to the output. We'll do the following: + // v----------| + // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...] + // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1) + // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1) + // Output = [... R G B | R1 G1 B1 R2 ...] + _write_endline_rgb_as_uint32(lineOut + out_x_strided, o); + } + } else if (num_channels == 3) { + // Memcpy 4-bytes is faster than 3-bytes and here + // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value + // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...) + std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4); + } else { + // num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned + *(uint32_t *)(lineOut + out_x_strided) = o; + } + } +} + +void ImagingResampleVerticalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + const uint8_t* C10_RESTRICT lineIn, + int64_t xsize, + int64_t ids_min, + int64_t ids_size, + const int16_t* k, + unsigned int coefs_precision, + int64_t num_channels) { + + // Interpolation vertical pass processing one line. + // - We process x-axis data with blocks of 8, 2 and 1 + // - We split the size of weight vector for a given output index as a sum: K = n * 2 + m. 
+ + // xsize = output width, also equals to input width + // ids_size = interpolation size + // ids_min = input y start index + const auto stride = num_channels * sizeof(uint8_t); + + TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4); + + const int64_t data_size = xsize * stride; + const int64_t data_stride = stride; + constexpr auto vec_size = 256 / 8; + + const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1)); + const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1)); + const auto zero = _mm_setzero_si128(); + const auto zero_256 = _mm256_setzero_si256(); + + int64_t j = 0; + // block 8 + const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride; + for (; j < data_size - vec_size; j += b8_usable_vec_stride) { + auto sss0 = initial_256; + auto sss1 = initial_256; + auto sss2 = initial_256; + auto sss3 = initial_256; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]); + + // RGBA: Load 8 pixels per line + // source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7 + // ] + // RGB: Load 10 pixels per line (however we can process only 8 pixels): + // source1 = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9 + // ] + auto source1 = + _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i)); + auto source2 = + _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1))); + + // Interleave source1 and source2 from the low half of each 128-bit lane + // and cast the result to epi16 + // RGBA: pix1 = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0 + // ] + // RGB: pix1 = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0 + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0 + // ] + auto source_lo = _mm256_unpacklo_epi8(source1, source2); + 
auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256); + // Compute output value as + // C += w0 * c0 + w1 * C0 + // C += w0 * c1 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + + // RGBA: pix2 = [ + // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 a2 0 A2 0 + // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 a3 0 A3 0 + // ] + // RGB: pix2 = [ + // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 0 0 0 0 + // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 0 0 0 0 + // ] + auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256); + // Compute output value as + // C += w0 * c2 + w1 * C2 + // C += w0 * c3 + w1 * C3 for each channel in 32-bit precision + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + + // Same as above for the high half of each 128-bit lane + auto source_hi = _mm256_unpackhi_epi8(source1, source2); + auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256); + sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk)); + auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256); + sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk)); + } + // Same processing as above but with a single weight value + for (; i < ids_size; i += 1) { + auto mmk = _mm256_set1_epi32(k[i]); + + auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size)); + + auto source_lo = _mm256_unpacklo_epi8(source1, zero_256); + auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256); + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + + auto source_hi = _mm256_unpackhi_epi8(source1, zero_256); + auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256()); + sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk)); + auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256()); + sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk)); + } + // Convert fixed point values back to integers (truncating) + sss0 = 
_mm256_srai_epi32(sss0, coefs_precision); + sss1 = _mm256_srai_epi32(sss1, coefs_precision); + sss2 = _mm256_srai_epi32(sss2, coefs_precision); + sss3 = _mm256_srai_epi32(sss3, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) + sss0 = _mm256_packs_epi32(sss0, sss1); + sss2 = _mm256_packs_epi32(sss2, sss3); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) + sss0 = _mm256_packus_epi16(sss0, sss2); + + // Stores 32 bytes + _mm256_storeu_si256((__m256i*)(lineOut + j), sss0); + } + + // TODO: Do we also need block 4 ??? + // block 2 + const auto b2_usable_vec_stride = (8 / data_stride) * data_stride; + for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) { + auto sss0 = initial; + auto sss1 = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... 
] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + + // Load 2 pixels per line + // RGBA: source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // ] + // RGB: source1 = [ + // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0 + // ] + auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size)); + auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size)); + // Interleave source1 and source2 and cast the result to epi16 + // RGBA: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // ] + // RGB: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0 + // ] + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision + sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk)); + // RGBA: pix = [ + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0 + // ] + // RGB: pix = [ + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0 + // ] + pix = _mm_unpackhi_epi8(source, zero); + // Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision + sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk)); + } + // Same processing as above but with a single weight value + for (; i < ids_size; i += 1) { + auto mmk = _mm_set1_epi32(k[i]); + + auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size)); + + auto source = _mm_unpacklo_epi8(source1, zero); + auto pix1 = _mm_unpacklo_epi8(source, zero); + sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk)); + auto pix2 = _mm_unpackhi_epi8(source, zero); + sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk)); + } + // Convert fixed point values back to integers (truncating) + sss0 = _mm_srai_epi32(sss0, coefs_precision); + sss1 = _mm_srai_epi32(sss1, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) + sss0 = _mm_packs_epi32(sss0, sss1); + // Convert 
packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) + sss0 = _mm_packus_epi16(sss0, sss0); + // Store 2 pixels to the output + _mm_storel_epi64((__m128i*)(lineOut + j), sss0); + } + + // block 1 + const auto b1_usable_vec_stride = (4 / data_stride) * data_stride; + const auto i32_aligned = num_channels == 4; + for (; j < data_size - 4; j += b1_usable_vec_stride) { + auto sss = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + + // Load one pixel per line + // RGBA: source1 = [ + // r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0 + // ] + // RGB: source1 = [ + // r0 g0 b0 r1 0 0 0 0 0 0 0 0 0 0 0 0 + // ] + auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned); + auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned); + + // Interleave source1 and source2 and cast the result to epi16 + // RGBA: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // ] + // RGB: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0 + // ] + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + for (; i < ids_size; i++) { + auto mmk = _mm_set1_epi32(k[i]); + auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned); + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + sss = _mm_srai_epi32(sss, coefs_precision); + sss = _mm_packs_epi32(sss, zero); + sss = _mm_packus_epi16(sss, zero); + + auto o = _mm_cvtsi128_si32(sss); + + // Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3 + // It is OK to write 4th byte (e.g. 
X) as on the next step we will overwrite it with new data. + // We also wont go out of bounds of lineOut memory allocation + std::memcpy(lineOut + j, (uint8_t *) &o, 4); + } + + for (; j < data_size; j += data_stride) { + auto sss = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + // For RGBA we can use (ids_size - 1) as tighter limit but for RGB we can read outside memory boundary + // for the last remaining line + for (; i < ids_size - 2; i += 2) { + // Load two coefficients at once + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + + // Load 2 lines + auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned); + auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned); + + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // Same processing as above but with a single weight value + for (; i < ids_size; i++) { + auto mmk = _mm_set1_epi32(k[i]); + + const uint8_t * p = lineIn_min + i * data_size; + __m128i pix; + // There is no much perf gain using more detailed condition like + // num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size + // const int64_t in_max_size = data_size * in_ysize; + if (num_channels == 3) { + uint8_t input[4]; + std::memcpy(input, p, 3); + pix = mm_cvtepu8_epi32(input, true); + } else { + pix = mm_cvtepu8_epi32(p, true); + } + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // Convert fixed point values back to integers (truncating) + sss = _mm_srai_epi32(sss, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) + sss = _mm_packs_epi32(sss, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) + sss = _mm_packus_epi16(sss, zero); + // Store one 
pixel to the output + auto o = _mm_cvtsi128_si32(sss); + if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) { + std::memcpy(lineOut + j, (uint8_t *) &o, 3); + } else { + std::memcpy(lineOut + j, (uint8_t *) &o, 4); + } + } +} + +} // anonymous namespace +#endif // CPU_CAPABILITY_AVX2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..efcaf4d1c7aa1df4cff5f78faa35d63c7831236c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using weight_norm_fn = void(*)( + TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, int64_t); +using weight_norm_backward_fn = void(*)( + TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, + const TensorBase&, const TensorBase&, int64_t); + +DECLARE_DISPATCH(weight_norm_fn, weight_norm_stub) +DECLARE_DISPATCH(weight_norm_backward_fn, weight_norm_backward_stub) + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h new file mode 100644 index 0000000000000000000000000000000000000000..f4fd3b7bc461fbf82e8b4a16dd9453e46e124efa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h @@ -0,0 +1,522 @@ +#pragma once +/* + AVX implementation of sin, cos, sincos, exp and log + + Based on "sse_mathfun.h", by Julien Pommier + http://gruntthepeon.free.fr/ssemath/ + + Copyright (C) 2012 Giovanni Garberoglio + Interdisciplinary Laboratory for Computational Science (LISC) + Fondazione Bruno Kessler and University of Trento + via Sommarive, 18 + I-38123 Trento (Italy) + + 
This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + (this is the zlib license) +*/ + +#include + +/* The original source of this file has been modified. */ +#if defined(CPU_CAPABILITY_AVX2) + +#if defined(__GNUC__) +# define ALIGN32_BEG __attribute__((aligned(32))) +#elif defined(_WIN32) +# define ALIGN32_BEG __declspec(align(32)) +#endif + +typedef __m256 v8sf; // vector of 8 float (avx2) +typedef __m256i v8si; // vector of 8 int (avx2) + +/* declare some AVX constants -- why can't I figure a better way to do that? 
*/ +#define _PS256_CONST(Name, Val) \ + static const ALIGN32_BEG float _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } +#define _PI32_CONST256(Name, Val) \ + static const ALIGN32_BEG int _pi32_256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } +#define _PS256_CONST_TYPE(Name, Type, Val) \ + static const ALIGN32_BEG Type _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } + +_PS256_CONST(1 , 1.0f); +_PS256_CONST(0p5, 0.5f); +/* the smallest non denormalized float number */ +_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000); +_PS256_CONST_TYPE(mant_mask, int, 0x7f800000); +_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000); + +_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000); +_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000); + +_PI32_CONST256(0, 0); +_PI32_CONST256(1, 1); +_PI32_CONST256(inv1, ~1); +_PI32_CONST256(2, 2); +_PI32_CONST256(4, 4); +_PI32_CONST256(0x7f, 0x7f); + +_PS256_CONST(cephes_SQRTHF, 0.707106781186547524); +_PS256_CONST(cephes_log_p0, 7.0376836292E-2); +_PS256_CONST(cephes_log_p1, - 1.1514610310E-1); +_PS256_CONST(cephes_log_p2, 1.1676998740E-1); +_PS256_CONST(cephes_log_p3, - 1.2420140846E-1); +_PS256_CONST(cephes_log_p4, + 1.4249322787E-1); +_PS256_CONST(cephes_log_p5, - 1.6668057665E-1); +_PS256_CONST(cephes_log_p6, + 2.0000714765E-1); +_PS256_CONST(cephes_log_p7, - 2.4999993993E-1); +_PS256_CONST(cephes_log_p8, + 3.3333331174E-1); +_PS256_CONST(cephes_log_q1, -2.12194440e-4); +_PS256_CONST(cephes_log_q2, 0.693359375); + + +/* natural logarithm computed for 8 simultaneous float + return NaN for x <= 0 +*/ +inline v8sf log256_ps(v8sf x) { + v8si imm0; + v8sf one = *(v8sf*)_ps256_1; + + //v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps()); + v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS); + + x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */ + + // can be done with AVX2 + imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23); + + /* keep only the 
fractional part */ + x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask); + x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5); + + // this is again another AVX2 instruction + imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f); + v8sf e = _mm256_cvtepi32_ps(imm0); + + e = _mm256_add_ps(e, one); + + /* part2: + if( x < SQRTHF ) { + e -= 1; + x = x + x - 1.0; + } else { x = x - 1.0; } + */ + //v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF); + v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS); + v8sf tmp = _mm256_and_ps(x, mask); + x = _mm256_sub_ps(x, one); + e = _mm256_sub_ps(e, _mm256_and_ps(one, mask)); + x = _mm256_add_ps(x, tmp); + + v8sf z = _mm256_mul_ps(x,x); + + v8sf y = *(v8sf*)_ps256_cephes_log_p0; + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8); + y = _mm256_mul_ps(y, x); + + y = _mm256_mul_ps(y, z); + + tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1); + y = _mm256_add_ps(y, tmp); + + + tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5); + y = _mm256_sub_ps(y, tmp); + + tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2); + x = _mm256_add_ps(x, y); + x = _mm256_add_ps(x, tmp); + x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN + return x; +} + +_PS256_CONST(exp_hi, 88.3762626647949f); +_PS256_CONST(exp_lo, -88.3762626647949f); + +_PS256_CONST(cephes_LOG2EF, 1.44269504088896341); +_PS256_CONST(cephes_exp_C1, 0.693359375); 
+_PS256_CONST(cephes_exp_C2, -2.12194440e-4); + +_PS256_CONST(cephes_exp_p0, 1.9875691500E-4); +_PS256_CONST(cephes_exp_p1, 1.3981999507E-3); +_PS256_CONST(cephes_exp_p2, 8.3334519073E-3); +_PS256_CONST(cephes_exp_p3, 4.1665795894E-2); +_PS256_CONST(cephes_exp_p4, 1.6666665459E-1); +_PS256_CONST(cephes_exp_p5, 5.0000001201E-1); + +inline v8sf exp256_ps(v8sf x) { + v8sf tmp = _mm256_setzero_ps(), fx; + v8si imm0; + v8sf one = *(v8sf*)_ps256_1; + + x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi); + x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF); + fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5); + + /* how to perform a floorf with SSE: just below */ + //imm0 = _mm256_cvttps_epi32(fx); + //tmp = _mm256_cvtepi32_ps(imm0); + + tmp = _mm256_floor_ps(fx); + + /* if greater, subtract 1 */ + //v8sf mask = _mm256_cmpgt_ps(tmp, fx); + v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); + mask = _mm256_and_ps(mask, one); + fx = _mm256_sub_ps(tmp, mask); + + tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1); + v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2); + x = _mm256_sub_ps(x, tmp); + x = _mm256_sub_ps(x, z); + + z = _mm256_mul_ps(x,x); + + v8sf y = *(v8sf*)_ps256_cephes_exp_p0; + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5); + y = _mm256_mul_ps(y, z); + y = _mm256_add_ps(y, x); + y = _mm256_add_ps(y, one); + + /* build 2^n */ + imm0 = _mm256_cvttps_epi32(fx); + // another two AVX2 instructions + imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f); + imm0 = _mm256_slli_epi32(imm0, 23); + v8sf pow2n = 
_mm256_castsi256_ps(imm0); + y = _mm256_mul_ps(y, pow2n); + return y; +} + +_PS256_CONST(minus_cephes_DP1, -0.78515625); +_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4); +_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8); +_PS256_CONST(sincof_p0, -1.9515295891E-4); +_PS256_CONST(sincof_p1, 8.3321608736E-3); +_PS256_CONST(sincof_p2, -1.6666654611E-1); +_PS256_CONST(coscof_p0, 2.443315711809948E-005); +_PS256_CONST(coscof_p1, -1.388731625493765E-003); +_PS256_CONST(coscof_p2, 4.166664568298827E-002); +_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI + + +/* evaluation of 8 sines at onces using AVX intrinsics + + The code is the exact rewriting of the cephes sinf function. + Precision is excellent as long as x < 8192 (I did not bother to + take into account the special handling they have for greater values + -- it does not return garbage for arguments over 8192, though, but + the extra precision is missing). + + Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the + surprising but correct result. + +*/ +inline v8sf sin256_ps(v8sf x) { // any x + v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y; + v8si imm0, imm2; + + sign_bit = x; + /* take the absolute value */ + x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); + /* extract the sign bit (upper one) */ + sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask); + + /* scale by 4/Pi */ + y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); + + /* + Here we start a series of integer operations, which are in the + realm of AVX2. 
+ If we don't have AVX, let's perform them using SSE2 directives + */ + + /* store the integer part of y in mm0 */ + imm2 = _mm256_cvttps_epi32(y); + /* j=(j+1) & (~1) (see the cephes sources) */ + // another two AVX2 instruction + imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1); + imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1); + y = _mm256_cvtepi32_ps(imm2); + + /* get the swap sign flag */ + imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4); + imm0 = _mm256_slli_epi32(imm0, 29); + /* get the polynom selection mask + there is one polynom for 0 <= x <= Pi/4 + and another one for Pi/4 +#include + +namespace at::native { + +using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&); +using int4pack_mm_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&); +using int8pack_mm_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&); +using dyn_quant_pack_4bit_weight_fn = void (*)( + Tensor&, + const Tensor&, + const Tensor&, + const std::optional& bias, + const int64_t, + const int64_t, + const int64_t); +using dyn_quant_matmul_4bit_fn = void (*)( + const Tensor&, + const Tensor&, + const Tensor&, + const int64_t, + const int64_t, + const int64_t, + const int64_t); + +DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub) +DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub) +DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub) +DECLARE_DISPATCH( + dyn_quant_pack_4bit_weight_fn, + dyn_quant_pack_4bit_weight_stub) +DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, dyn_quant_matmul_4bit_stub) + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h new file mode 100644 index 0000000000000000000000000000000000000000..13244af3b34a0f36defc69fa7fc219e93aae7757 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace at::native { + +inline ScalarType first_type() { + return ScalarType::Undefined; +} + +template +inline ScalarType first_type(const Tensor& arg, const Args&... parameters) { + return arg.defined() ? arg.scalar_type() : first_type(parameters...); +} + +template +inline bool is_mixed_type(const Tensor& input, const Args&... parameters) { + const auto parameter_type = first_type(parameters...); + return ((parameter_type != ScalarType::Undefined) && + (parameter_type != input.scalar_type())); +} + +// currently on CPU, mixed data type is only supported +// when input is 'BFloat16' or 'Half' and parameters are 'Float' +inline void check_mixed_data_type(const Tensor& input) { + TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()), + "mixed dtype (CPU): all inputs must share same datatype."); +} + +template +inline void check_mixed_data_type(const Tensor& input, const Tensor& parameter, const Args&... parameters) { + TORCH_CHECK(!parameter.defined() || parameter.scalar_type() == ScalarType::Float, + "mixed dtype (CPU): expect parameter to have scalar type of Float"); + check_mixed_data_type(input, parameters...); +} + +inline ScalarType param_scalar_type(const Tensor& t, bool is_mixed_type) { + return is_mixed_type ? 
ScalarType::Float : t.scalar_type(); +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/moments_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/moments_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..6f403d60ea7c09849e1ea9d48625532aa51117c3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/moments_utils.h @@ -0,0 +1,202 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +template using opmath_t = at::opmath_type; + +constexpr int64_t kChunkSize = 16; + +template +void AddMoments( + int64_t m0_add, + const T& m1_add, + const T& m2_add, + int64_t& m0, + T& m1, + T& m2) { + const int64_t n = m0 + m0_add; + const T c = n == 0 ? static_cast(0) : static_cast(m0_add) / static_cast(n); + const T delta = m1_add - m1; + m1 += c * delta; + m2 += m2_add + delta * delta * c * static_cast(m0); + m0 = n; +} + +template +C10_ALWAYS_INLINE void AddMomentsVec( + int64_t m0_add, + const vec::Vectorized& m1_add, + const vec::Vectorized& m2_add, + int64_t& m0, + vec::Vectorized& m1, + vec::Vectorized& m2) { + using Vec = vec::Vectorized; + const int64_t n = m0 + m0_add; + const T c = n == 0 ? 
static_cast(0) : static_cast(m0_add) / static_cast(n); + const Vec c_vec(c); + const Vec delta = m1_add - m1; + m1 += c_vec * delta; + m2 += m2_add + delta * delta * c_vec * Vec(static_cast(m0)); + m0 = n; +} + +template +inline std::enable_if_t>, void> +UpdateMomentsVec( + int64_t m0, + const T* X_ptr, + const std::array>, kChunkSize>& c_vecs, + int64_t& m0_stk0, + vec::Vectorized>& m1_stk0, + vec::Vectorized>& m2_stk0) { + using Vec = vec::Vectorized>; + Vec m1_vec(0); + Vec m2_vec(0); + for (const auto j : c10::irange(m0)) { + const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size()); + const Vec delta_vec = x_vec - m1_vec; + m1_vec += delta_vec * c_vecs[j]; + m2_vec += delta_vec * (x_vec - m1_vec); + } + AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0); +} + +// each bfloat16/half vector will be converted to two float vectors, +// and accumulated successively on m1_stk0/m2_stk0. +template +inline std::enable_if_t>, void> +UpdateMomentsVec( + int64_t m0, + const T* X_ptr, + const std::array>, kChunkSize>& c_vecs, + int64_t& m0_stk0, + vec::Vectorized>& m1_stk0, + vec::Vectorized>& m2_stk0) { + using Vec = vec::Vectorized; + using fVec = vec::Vectorized>; + fVec m1_fvec0(0), m1_fvec1(0); + fVec m2_fvec0(0), m2_fvec1(0); + for (const auto j : c10::irange(m0)) { + const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size()); + auto [x_fvec0, x_fvec1] = convert_to_float(x_bvec); + const fVec delta_fvec0 = x_fvec0 - m1_fvec0; + const fVec delta_fvec1 = x_fvec1 - m1_fvec1; + m1_fvec0 += delta_fvec0 * c_vecs[j]; + m1_fvec1 += delta_fvec1 * c_vecs[j]; + m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0); + m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1); + } + AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0); + AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0); +} + +// Compute rowwise moments by Welford algorithm and cascade sum to improve +// numerical stability. 
+// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +// https://en.wikipedia.org/wiki/Pairwise_summation +template +std::pair, opmath_t> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) { + using math_t = opmath_t; + + constexpr int64_t kVecSize = vec::Vectorized::size(); + constexpr int64_t kAccVecSize = vec::Vectorized::size(); + const int64_t n = N / kVecSize; + const int64_t m = divup(n, kChunkSize); + const int64_t depth = utils::CeilLog2(m); + + using Vec = vec::Vectorized; + const Vec kZeroVec(math_t(0)); + c10::SmallVector m0_stk(depth, 0); + c10::SmallVector m1_stk(depth, kZeroVec); + c10::SmallVector m2_stk(depth, kZeroVec); + + for (const auto i : c10::irange(m)) { + const T* X_ptr = X + i * kChunkSize * kVecSize; + const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize); + static std::array c_vecs = ([]() { + std::array result; + for (const auto i : c10::irange(kChunkSize)) { + result[i] = Vec(math_t(1) / static_cast(i + 1)); + } + return result; + })(); + UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]); + + int64_t mask = i + 1; + for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) { + AddMomentsVec( + m0_stk[j - 1], + m1_stk[j - 1], + m2_stk[j - 1], + m0_stk[j], + m1_stk[j], + m2_stk[j]); + m0_stk[j - 1] = 0; + m1_stk[j - 1] = kZeroVec; + m2_stk[j - 1] = kZeroVec; + mask >>= 1; + } + } + for (const auto i : c10::irange(1, depth)) { + AddMomentsVec( + m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]); + } + + std::array m1_arr{}; + std::array m2_arr{}; + m1_stk[0].store(m1_arr.data()); + m2_stk[0].store(m2_arr.data()); + + int64_t m0 = 0; + math_t m1 = 0; + math_t m2 = 0; + for (int64_t i = n * kVecSize; i < N; ++i) { + math_t x = static_cast(X[i]); + const math_t delta = x - m1; + ++m0; + m1 += delta / static_cast(m0); + m2 += delta * (x - m1); + } + // for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result + int64_t m0_add = n * kVecSize / kAccVecSize; + for 
(const auto i : c10::irange(kAccVecSize)) { + AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2); + } + + return std::make_pair(m1, m2 / static_cast(N - ddof)); +} + +template +std::pair, opmath_t> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) { + using Vec = vec::Vectorized; + constexpr int64_t kVecSize = Vec::size(); + const int64_t n = N / kVecSize; + const int64_t m = divup(n, kChunkSize); + const int64_t depth = utils::CeilLog2(m); + if (depth <= 4) { + return RowwiseMomentsImpl(X, N, ddof); + } else if (depth <= 8) { + return RowwiseMomentsImpl(X, N, ddof); + } else if (depth <= 16) { + return RowwiseMomentsImpl(X, N, ddof); + } else if (depth <= 32) { + return RowwiseMomentsImpl(X, N, ddof); + } else { + return RowwiseMomentsImpl(X, N, ddof); + } +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3d4c39afe3c75c8887c34142fcefe63e7dc940cb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/utils.h @@ -0,0 +1,212 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef USE_FBGEMM +#include +#endif + +namespace at::native { + +template +inline void _store(T* dst, at::vec::Vectorized src) { + src.store(dst); +} + +inline void _store(at::BFloat16* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_float_bfloat16(src, src); + res.store(dst, at::vec::Vectorized::size()); +} + +inline void _store(at::Half* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_float_half(src, src); + res.store(dst, at::vec::Vectorized::size()); +} + +inline namespace CPU_CAPABILITY { + +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T& x, const T& X, Args&&... 
args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T& x, const T& X, Args&&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +// Helper struct for bfloat16/float16 vectorization +// Useful when you need float as immediate dtype or accumulate dtype +using namespace vec; +struct Vec2 { + Vectorized val0, val1; + Vec2(Vectorized v0, Vectorized v1) : val0(v0), val1(v1) {} + Vec2(float v) : val0(v), val1(v) {} + static Vec2 loadu(const BFloat16* ptr) { + auto [v0, v1] = convert_bfloat16_float(Vectorized::loadu(ptr)); + return {v0, v1}; + } + static Vec2 loadu(const Half* ptr) { + auto [v0, v1] = convert_half_float(Vectorized::loadu(ptr)); + return {v0, v1}; + } + static Vec2 loadu(const float* ptr) { + return {Vectorized::loadu(ptr), Vectorized::loadu(ptr + Vectorized::size())}; + } + void store(BFloat16* ptr) const { + Vectorized val = convert_float_bfloat16(val0, val1); + val.store(ptr); + } + void store(Half* ptr) const { + Vectorized val = convert_float_half(val0, val1); + val.store(ptr); + } + void store(float* ptr) const { + val0.store(ptr); + val1.store(ptr + Vectorized::size()); + } +}; +inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; } +inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; } +inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; } +inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; } +inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; } +inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; } + 
+template struct VectorizedType { using type = Vectorized; }; +template <> struct VectorizedType { using type = Vec2; }; +template <> struct VectorizedType { using type = Vec2; }; +template using VecType = typename VectorizedType::type; + +// Helper for mixed data type parameter Vec::load +inline std::tuple, Vectorized> load2f(const BFloat16* ptr) { + return convert_bfloat16_float(Vectorized::loadu(ptr)); +} + +inline std::tuple, Vectorized> load2f(const Half* ptr) { + return convert_half_float(Vectorized::loadu(ptr)); +} + +inline std::tuple, Vectorized> load2f(const float* ptr) { + using Vec = Vectorized; + return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size())); +} + +inline std::tuple, Vectorized> load2f(const BFloat16* ptr, int64_t count) { + return convert_bfloat16_float(Vectorized::loadu(ptr, count)); +} + +inline std::tuple, Vectorized> load2f(const Half* ptr, int64_t count) { + return convert_half_float(Vectorized::loadu(ptr, count)); +} + +inline std::tuple, Vectorized> load2f(const float* ptr, int64_t count) { + using Vec = Vectorized; + if (count > Vec::size()) { + return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size())); + } else { + return std::make_tuple(Vec::loadu(ptr, count), Vec(0)); + } +} + +} // namespace + +namespace utils { + +template +T CeilLog2(const T& x) { + if (x <= 2) { + return 1; + } + // Last set bit is floor(log2(x)), floor + 1 is ceil + // except when x is an exact powers of 2, so subtract 1 first + return static_cast(llvm::findLastSet(static_cast(x) - 1)) + 1; +} + +// matrix transpose: +// src has shape of M by N, with leading dimension of ld_src +// dst has shape of N by M, with leading dimension of ld_dst +template +inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { + for (int64_t j = 0; j < N; j++) { + for (int64_t i = 0; i < M; i++) { + dst[j * ld_dst + i] = c10::load(&(src[i * ld_src + j])); + } + } +} + +#ifdef USE_FBGEMM 
+template <> +inline void transpose(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); +} + +template <> +inline void transpose(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) { + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); +} +#endif + +template +inline void parallel_sparse_csr( + const TensorAccessor& crow_acc, + const int64_t M, + const int64_t nnz, + const F& f) { + TORCH_CHECK(crow_acc.size(0) == M + 1); + + // directly parallel on `M` may lead to load imbalance, + // statically determine thread partition here to average payload + // for each thread. + int num_threads = at::get_num_threads(); + std::vector thread_splits(num_threads + 1, M); + + int64_t thread_averge_payload = std::max((int64_t)1, divup(nnz, num_threads)); + + thread_splits[0] = 0; + int64_t sum = 0; + int64_t t = 1; + for (const auto m : c10::irange(M)) { + int64_t row_start = crow_acc[m]; + int64_t row_end = crow_acc[m + 1]; + sum += row_end - row_start; + if (sum > t * thread_averge_payload) { + thread_splits[t] = m; + t++; + } + } + // need to restore the last index, + // due to rounding error when calculating `thread_averge_payload`. 
+ thread_splits[num_threads] = M; + + at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) { + int tid = at::get_thread_num(); + int64_t begin = thread_splits[tid]; + int64_t end = thread_splits[tid + 1]; + f(begin, end); + }); +} + +} // namespace utils + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/zmath.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/zmath.h new file mode 100644 index 0000000000000000000000000000000000000000..2b4f44db085c997f9fdadd49eb078f3cc67a36f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/zmath.h @@ -0,0 +1,250 @@ +#pragma once + +// Complex number math operations that act as no-ops for other dtypes. +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +template +inline VALUE_TYPE zabs (SCALAR_TYPE z) { + return z; +} + +template<> +inline c10::complex zabs > (c10::complex z) { + return c10::complex(std::abs(z)); +} + +template<> +inline float zabs , float> (c10::complex z) { + return std::abs(z); +} + +template<> +inline c10::complex zabs > (c10::complex z) { + return c10::complex(std::abs(z)); +} + +template<> +inline double zabs , double> (c10::complex z) { + return std::abs(z); +} + +// This overload corresponds to non-complex dtypes. +// The function is consistent with its NumPy equivalent +// for non-complex dtypes where `pi` is returned for +// negative real numbers and `0` is returned for 0 or positive +// real numbers. +// Note: `nan` is propagated. +template +inline VALUE_TYPE angle_impl (SCALAR_TYPE z) { + if (at::_isnan(z)) { + return z; + } + return z < 0 ? 
c10::pi : 0; +} + +template<> +inline c10::complex angle_impl > (c10::complex z) { + return c10::complex(std::arg(z), 0.0); +} + +template<> +inline float angle_impl , float> (c10::complex z) { + return std::arg(z); +} + +template<> +inline c10::complex angle_impl > (c10::complex z) { + return c10::complex(std::arg(z), 0.0); +} + +template<> +inline double angle_impl , double> (c10::complex z) { + return std::arg(z); +} + +template +constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) { + return z; //No-Op +} + +template<> +constexpr c10::complex real_impl > (c10::complex z) { + return c10::complex(z.real(), 0.0); +} + +template<> +constexpr float real_impl , float> (c10::complex z) { + return z.real(); +} + +template<> +constexpr c10::complex real_impl > (c10::complex z) { + return c10::complex(z.real(), 0.0); +} + +template<> +constexpr double real_impl , double> (c10::complex z) { + return z.real(); +} + +template +constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) { + return 0; +} + +template<> +constexpr c10::complex imag_impl > (c10::complex z) { + return c10::complex(z.imag(), 0.0); +} + +template<> +constexpr float imag_impl , float> (c10::complex z) { + return z.imag(); +} + +template<> +constexpr c10::complex imag_impl > (c10::complex z) { + return c10::complex(z.imag(), 0.0); +} + +template<> +constexpr double imag_impl , double> (c10::complex z) { + return z.imag(); +} + +template +inline TYPE conj_impl (TYPE z) { + return z; //No-Op +} + +template<> +inline c10::complex conj_impl > (c10::complex z) { + return c10::complex{z.real(), -z.imag()}; +} + +template<> +inline c10::complex conj_impl > (c10::complex z) { + return c10::complex(z.real(), -z.imag()); +} + +template<> +inline c10::complex conj_impl > (c10::complex z) { + return c10::complex(z.real(), -z.imag()); +} + +template +inline TYPE ceil_impl (TYPE z) { + return std::ceil(z); +} + +template <> +inline c10::complex ceil_impl (c10::complex z) { + return c10::complex(std::ceil(z.real()), 
std::ceil(z.imag())); +} + +template <> +inline c10::complex ceil_impl (c10::complex z) { + return c10::complex(std::ceil(z.real()), std::ceil(z.imag())); +} + +template +inline c10::complex sgn_impl (c10::complex z) { + if (z == c10::complex(0, 0)) { + return c10::complex(0, 0); + } else { + return z / zabs(z); + } +} + +template +inline TYPE floor_impl (TYPE z) { + return std::floor(z); +} + +template <> +inline c10::complex floor_impl (c10::complex z) { + return c10::complex(std::floor(z.real()), std::floor(z.imag())); +} + +template <> +inline c10::complex floor_impl (c10::complex z) { + return c10::complex(std::floor(z.real()), std::floor(z.imag())); +} + +template +inline TYPE round_impl (TYPE z) { + return std::nearbyint(z); +} + +template <> +inline c10::complex round_impl (c10::complex z) { + return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag())); +} + +template <> +inline c10::complex round_impl (c10::complex z) { + return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag())); +} + +template +inline TYPE trunc_impl (TYPE z) { + return std::trunc(z); +} + +template <> +inline c10::complex trunc_impl (c10::complex z) { + return c10::complex(std::trunc(z.real()), std::trunc(z.imag())); +} + +template <> +inline c10::complex trunc_impl (c10::complex z) { + return c10::complex(std::trunc(z.real()), std::trunc(z.imag())); +} + +template ::value, int> = 0> +inline TYPE max_impl (TYPE a, TYPE b) { + if (_isnan(a) || _isnan(b)) { + return std::numeric_limits::quiet_NaN(); + } else { + return std::max(a, b); + } +} + +template ::value, int> = 0> +inline TYPE max_impl (TYPE a, TYPE b) { + if (_isnan(a)) { + return a; + } else if (_isnan(b)) { + return b; + } else { + return std::abs(a) > std::abs(b) ? 
a : b; + } +} + +template ::value, int> = 0> +inline TYPE min_impl (TYPE a, TYPE b) { + if (_isnan(a) || _isnan(b)) { + return std::numeric_limits::quiet_NaN(); + } else { + return std::min(a, b); + } +} + +template ::value, int> = 0> +inline TYPE min_impl (TYPE a, TYPE b) { + if (_isnan(a)) { + return a; + } else if (_isnan(b)) { + return b; + } else { + return std::abs(a) < std::abs(b) ? a : b; + } +} + +} // end namespace +} //end at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Activation.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Activation.h new file mode 100644 index 0000000000000000000000000000000000000000..37425a166de31e206ab2b0d4d5c0188c6e92349e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Activation.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at::native { + +void launch_glu_backward_kernel(const TensorIteratorBase& iter, + int64_t gI_stride, int64_t I_stride); + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter); + +void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); +void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h new file mode 100644 index 0000000000000000000000000000000000000000..8efb8f98b1220e52077ab2d4bbcef62d7d91b3c2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h @@ -0,0 +1,44 @@ +// DON'T include this except from Binary*.cu files. It should not leak into +// headers. 
+#pragma once +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native::binary_internal { + +template +struct DivFunctor { + __device__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a / b; + } +}; + +template +struct MulFunctor { + __device__ T operator()(T a, T b) const { + return a * b; + } +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead +// [-Werror=int-in-bool-context] +template <> +struct MulFunctor { + __device__ bool operator()(bool a, bool b) const { + return a && b; + } +}; +void div_true_kernel_cuda(TensorIteratorBase& iter); +void div_trunc_kernel_cuda(TensorIteratorBase& iter); +} // namespace at::native::binary_internal diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c4c3af83ccd80739b293a307da686bd78eabf35d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh @@ -0,0 +1,327 @@ +#pragma once +#include + +// Jiterator functions are guarded behind this macro +#if AT_USE_JITERATOR() + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace at::native { + +template +// warning : unused parameter when tuple is empty. +constexpr auto tuple_to_array_helper(const Tuple& t [[maybe_unused]], std::index_sequence seq) { + constexpr auto size = seq.size(); + return std::array{static_cast(&std::get(t))...}; +} + +// Helper function convert tuple to std::array +// for passing the arguments to CUDA Kernel +// NOTE: We capture tuple by reference, +// so the pointers in returned array are only valid +// till tuple is alive. 
+template +constexpr auto tuple_to_array(const std::tuple& extra_args) { + constexpr auto tuple_size = sizeof...(Args); + return tuple_to_array_helper(extra_args, std::make_index_sequence{}); +} + +struct JittedVecKernelCache { + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) + at::cuda::jit::NvrtcFunction vec1; + at::cuda::jit::NvrtcFunction vec2; + at::cuda::jit::NvrtcFunction vec4; + at::cuda::jit::NvrtcFunction vec8; +#ifdef USE_ROCM + at::cuda::jit::NvrtcFunction vec16; +#endif + +}; + +struct JittedKernelVariantCache { + JittedVecKernelCache vec; + at::cuda::jit::NvrtcFunction noncontiguous; + at::cuda::jit::NvrtcFunction dynamic_contiguous; + at::cuda::jit::NvrtcFunction dynamic_noncontiguous; +}; + +inline c10::SmallBuffer pack_kernel_args( + std::initializer_list args, + c10::ArrayRef extra_args) { + c10::SmallBuffer ret(args.size() + extra_args.size()); + std::copy(args.begin(), args.end(), ret.data()); + std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size()); + return ret; +} + +template +void launch_jitted_unrolled_kernel( + std::mutex &jiterator_mutex, + at::cuda::jit::NvrtcFunction &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, + int64_t N, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s, + bool contiguous, + at::cuda::jit::BinaryFuncVariant scalar_pos, + const void* scalar_val, + c10::ArrayRef extra_args) { + + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + + int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type); + int bws = tws * num_threads(); + //casting result to int is always safe, intermediate is int64 and won't overflow + const uint32_t grid = (N + bws - 1) / bws; + + if (!fn_cache.function) { + const std::lock_guard lock{jiterator_mutex}; + if (!fn_cache.function) { + constexpr bool dynamic_casting = !std::is_same() || + !std::is_same(); + auto code = 
at::cuda::jit::generate_code( + desc, contiguous, dynamic_casting, scalar_pos, tws); + fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name); + } + } + + auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u}, + {num_threads(), 1u, 1u}); +} + +template +void launch_jitted_vectorized_kernel( + std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data, + at::cuda::jit::BinaryFuncVariant scalar_pos, + const void *scalar_val, c10::ArrayRef extra_args) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + + int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type); + int bws = tws * num_threads(); + // N is still int64_t for the computation, but it's always safe to cast result to int + const uint32_t grid = (N + bws - 1) / bws; + + int vec_size = at::cuda::jit::can_vectorize_up_to( + desc, c10::ArrayRef(data.data(), data.size())); + +#ifndef USE_ROCM + const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); + const int optimal_vec_size = 16 / static_cast(input_size); + vec_size = std::min(optimal_vec_size, vec_size); + // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC + // that causes some numerical mismatches with uint8 on sm80 and sm90. + // TODO: Revisit this after CUDA 12.8 update. 
+ if (input_size < 2) { + vec_size = std::min(vec_size, 4); + } +#endif + + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) + // fn_ptr is set to the appropriate function based on the vec size and GPU used + at::cuda::jit::NvrtcFunction* fn_ptr = nullptr; + +#ifdef USE_ROCM + if (vec_size == 16) { + fn_ptr = &fn_cache.vec16; + } else +#endif + if (vec_size == 8) { + fn_ptr = &fn_cache.vec8; + } else if (vec_size == 4) { + fn_ptr = &fn_cache.vec4; + } else if (vec_size == 2) { + fn_ptr = &fn_cache.vec2; + } else if (vec_size ==1) { + fn_ptr = &fn_cache.vec1; + } else { + TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel"); + } + + bool vectorized = vec_size > 1; + + if (!fn_ptr->function) { + const std::lock_guard lock{jiterator_mutex}; + if (!fn_ptr->function) { // cache miss! + + // Generates program + auto code = at::cuda::jit::generate_code( + desc, /*contiguous=*/true, /*dynamic_casting=*/false, + scalar_pos, tws, vectorized, vec_size); + std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name; + + // Acquires the program + *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name); + } + } + + if (vectorized) { + auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); + } else { +// NVCC complains about unused variables l and s. +// It should be false positive in most cases, so we suppress the warnings. 
+#pragma nv_diagnostic push +#pragma nv_diag_suppress 177 + auto ic = TrivialOffsetCalculator(); + auto oc = TrivialOffsetCalculator<1>(); + auto l = memory::LoadWithoutCast(); + auto s = memory::StoreWithoutCast(); + + auto args = pack_kernel_args( + {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); +#pragma nv_diagnostic pop + } +} + +template +void jitted_gpu_kernel_generic( + std::mutex &jiterator_mutex, + JittedKernelVariantCache &cache, + const at::cuda::jit::KernelDescriptor &desc, + at::cuda::jit::BinaryFuncVariant scalar_pos, + c10::ArrayRef extra_args, + TensorIteratorBase& iter, + const bool dynamic_casting, + const void *scalar_val) { + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + constexpr int ntensors = arity + 1; + std::array data; + for (auto i : c10::irange(ntensors)) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + bool contiguous = iter.is_contiguous(); + + // Decides which of 4 kernel types to launch + // Variations are: + // - Case 1: no dynamic casting and contiguous + // - Case 2: no dynamic casting and noncontiguous + // - Case 3: dynamic casting and contiguous + // - Case 4: dynamic casting and noncontiguous + // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl + + if (!dynamic_casting) { + if (contiguous) { + // Case 1: no dynamic casting and contiguous + launch_jitted_vectorized_kernel( + jiterator_mutex, cache.vec, desc, + numel, data, scalar_pos, scalar_val, extra_args); + return; + } + + // Case 2: no dynamic casting and noncontiguous + auto input_offset_calculator = make_input_offset_calculator(iter); + auto output_offset_calculator = make_output_offset_calculator(iter); + auto loader = memory::LoadWithoutCast(); + auto storer = 
memory::StoreWithoutCast(); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.noncontiguous, desc, numel, data, + input_offset_calculator, output_offset_calculator, loader, + storer, contiguous, scalar_pos, scalar_val, extra_args); + return; + } + + // Cases 3 and 4 are handled below + // Both require construction of a storer (this asserts 1 output) and one or more loaders + + // Creates store cast to output (the zeroth tensor in TensorIterator) + auto storer = memory::StoreWithCast<1>(iter); + + // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors) + auto loader = memory::LoadWithCast(iter); + + if (contiguous) { + // Case 3: dynamic casting and contiguous + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); + return; + } + + // Case 4: dynamic casting and noncontiguous + auto input_offset_calculator = make_input_offset_calculator(iter); + auto output_offset_calculator = make_output_offset_calculator(iter); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); +} + +// NOTE: static to reduce chances of name collision. +template < + char const* name, + typename result_type, + typename f_inputs_type, + int arity, + at::cuda::jit::BinaryFuncVariant scalar_pos = + at::cuda::jit::BinaryFuncVariant::NoScalar, + typename... 
ExtraArgs> +static void jitted_gpu_kernel_impl( + TensorIteratorBase& iter, + const std::string &f, + const bool dynamic_casting, + at::opmath_type scalar_val, + const std::tuple& extra_args) { + + // TODO: Memory use can probably be optimized by re-using kernels across GPUs with + // the same compute capability + static std::mutex jiterator_mutex; + static std::vector device_caches(c10::cuda::device_count()); + + constexpr int nInputs = arity; + constexpr int nOutputs = 1; // TODO: Support more than 1 output + static const auto desc = at::cuda::jit::make_kernel_descriptor< + result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs); + + auto &cache = device_caches[iter.device().index()]; + auto extra_args_array = tuple_to_array(extra_args); + return jitted_gpu_kernel_generic( + jiterator_mutex, + cache, + desc, + scalar_pos, + extra_args_array, + iter, + dynamic_casting, + &scalar_val + ); +} + +} // at::native + +#endif // AT_USE_JITERATOR() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9b104a7966363c034ba5fd67a0ca34547bf08865 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh @@ -0,0 +1,1010 @@ +#pragma once + +// This file provides two functions to help write GPU elementwise kernels: +// +// gpu_kernel(TensorIterator iter, ) +// gpu_kernel_with_scalars(TensorIterator iter, ) +// +// The gpu_kernel_with_scalars generates specializations that support a +// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar +// is lifted to a kernel parameter instead of copying to device memory. +// This should be used in conjunction with TensorIterator::allow_cpu_scalars_, +// which is the default for TensorIterator::binary_op. Otherwise, all inputs +// and the output must be on the GPU. 
+// +// For example, to write a reciprocal kernel for GPU float Tensors: +// +// gpu_kernel(iter, []GPU_LAMBDA(float a) { +// return 1.0f / a; +// }); +// +// To write a multiplication kernel for GPU float Tensors where one argument +// may be a CPU scalar: +// +// gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) { +// return a * b; +// }); +// +// See BinaryOpsKernel.cu for the complete implementation +// + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifdef __NVCC__ +#define ASSERT_HOST_DEVICE_LAMBDA(type) \ + static_assert( \ + __nv_is_extended_host_device_lambda_closure_type(type), \ + #type " must be a __host__ __device__ lambda") +#else +#define ASSERT_HOST_DEVICE_LAMBDA(type) +#endif + +namespace at::native { + +#ifdef USE_ROCM +// Custom configuration for vectorized elementwise kernel +// with template instantiation. +namespace vectorized_templated_config { +constexpr int num_threads() { + return 512; +} + +constexpr int elems_per_thread() { + return 32; +} + +constexpr int block_work_size() { + return elems_per_thread() * num_threads(); +} +} // namespace vectorized_templated_config +#endif + +template +constexpr auto sum_of_sizes(args_t args, std::index_sequence) { + if constexpr (sizeof...(Is) == 0) { + return 0; + } else { + return (sizeof(std::tuple_element_t) + ...); + } +} + +#ifdef USE_ROCM +template +constexpr auto elems_per_thread(){ + if constexpr (io_sizes == 1) { + return 16; + } else if constexpr (io_sizes < 4) { + return 8; + } else { + return 4; + } +} +#else +template +constexpr auto elems_per_thread(){ + if constexpr (io_sizes == 1) { + return 16; + } else { + return 8; + } +} +#endif + + +//thread work size of 8 regresses the perf of elementwise kernel on cuda +//this doesn't change ROCm behavior as thread_work_size is already 4 on ROCm +constexpr int elementwise_thread_work_size() {return 4;} +constexpr int elementwise_block_work_size() { + return 
elementwise_thread_work_size() * num_threads(); +} + +template +constexpr auto io_block_work_size() { + return num_threads() * elems_per_thread(); +} + +#ifdef USE_ROCM +template +constexpr auto input_size(args_t args, std::index_sequence) { + if constexpr (sizeof...(Is) == 0) { + return 0; + } else { + return sizeof(std::tuple_element_t<0, args_t>); + } +} + +template +constexpr auto calc_optimal_vec_size() { + static_assert(vec_size != 0); + static_assert(io_size != 0); + if constexpr (io_size == 1 && vec_size >= 16) { + return 16; + } else if constexpr (io_size <= 2 && vec_size >= 8) { + return 8; + } else if constexpr (io_size <= 4 && vec_size >= 4) { + return 4; + } else if constexpr (vec_size >= 4) { + return 4; + } else if constexpr (vec_size >= 2) { + return 2; + } else { + return 1; + } +} +#endif + +template +constexpr auto calc_io_size(){ + using traits = function_traits; + using args_t = typename traits::ArgsTuple; +#ifdef USE_ROCM + constexpr auto input_size = at::native::input_size(args_t{}, std::make_index_sequence>{}); + constexpr auto output_size = sizeof(typename traits::result_type); + return (input_size > 0) ? ((input_size < output_size) ? input_size : output_size) : output_size; +#else + constexpr auto input_size = at::native::sum_of_sizes(args_t{}, std::make_index_sequence>{}); + constexpr auto output_size = sizeof(typename traits::result_type); + return input_size + output_size; +#endif +} + +#ifndef USE_ROCM +// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel +// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be +// used on sm_90 and sm_100 exclusively. 
+template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { + if constexpr (vec_size == 8) { +#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 + using traits = function_traits; + constexpr auto io_size = calc_io_size(); + int remaining = N - io_block_work_size() * blockIdx.x; + + if (remaining < io_block_work_size()) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast, + elems_per_thread()>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + elementwise_kernel_helper( + f, memory::policies::vectorized()>(data)); + } +#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 + } else { + using traits = function_traits; + constexpr auto io_size = calc_io_size(); + int remaining = N - io_block_work_size() * blockIdx.x; + + if (remaining < io_block_work_size()) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast, + elems_per_thread()>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory 
access + elementwise_kernel_helper( + f, memory::policies::vectorized()>(data)); + } + } +} + +#else // USE_ROCM +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { + using traits = function_traits; + constexpr auto io_size = calc_io_size(); +#ifdef __gfx942__ + constexpr int tws = (io_size >= 2) ? 8 : 16; +#else + constexpr int tws = elems_per_thread(); +#endif + constexpr int bws = tws * num_threads(); + int remaining = N - bws * blockIdx.x; + + if (remaining < bws) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast, + tws>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + constexpr auto optimal_vec_size = calc_optimal_vec_size(); + elementwise_kernel_helper( + f, memory::policies::vectorized(data)); + } +} +#endif // USE_ROCM + +template < + typename func_t, + typename array_t, + int elems_per_thread, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel( + int N, + func_t f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + int remaining = N - elems_per_thread * num_threads() * blockIdx.x; + auto policy = memory::policies:: + unroll( + data, remaining, ic, oc, l, s); + elementwise_kernel_helper(f, policy); +} + +// this function assume trivial 1d and no dynamic casting +template +static inline void launch_vectorized_kernel( + 
int64_t N, + const func_t& f, + array_t data) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + using traits = function_traits; + constexpr auto io_size = calc_io_size(); + auto stream = at::cuda::getCurrentCUDAStream(); +#ifdef USE_ROCM + int vec_size = memory::can_vectorize_up_to(data); + c10::DeviceIndex curDevice = -1; + AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice)); + int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? ((io_size >= 2) ? 8 : 16) : elems_per_thread(); +#else + using cpp_type = typename function_traits::result_type; + const uint16_t max_vec_size = memory::can_vectorize_up_to(data); + uint16_t vec_size = 16 / static_cast(sizeof(cpp_type)); + vec_size = std::min(vec_size, max_vec_size); + // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC + // that causes some numerical mismatches with uint8 on sm80 and sm90. + // TODO: Revisit this after CUDA 12.8 update. + cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index()); + const int computeCapability = p->major * 10 + p->minor; + if (computeCapability != 90 && computeCapability != 100) { + vec_size = std::min(vec_size, 4); + } + if constexpr (sizeof(cpp_type) < 2) { + vec_size = std::min(vec_size, 4); + } + int tws = elems_per_thread(); +#endif + int bws = tws * num_threads(); + int64_t grid = (N + bws - 1) / bws; + switch (vec_size) { +#ifdef USE_ROCM + case 16: + vectorized_elementwise_kernel<16, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; +#endif + case 8: + vectorized_elementwise_kernel<8, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 4: + vectorized_elementwise_kernel<4, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + vectorized_elementwise_kernel<2, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 1: { + auto input_calc = TrivialOffsetCalculator(); + auto 
output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + int64_t grid_unrolled = (N + elementwise_block_work_size() - 1) / elementwise_block_work_size(); + unrolled_elementwise_kernel + <<>>( + N, f, data, input_calc, output_calc, loader, storer); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + } + default: + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); + } +} + +#ifdef USE_ROCM +template < + int vec_size, + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + typename OutputType, + typename... InputTypes> +C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads()) +__global__ void vectorized_templated_elementwise_kernel( + int N, + func_t f, + array_t data, + inp_calc_t inp_calc, + out_calc_t out_calc, + loader_t loader, + storer_t storer) { + int remaining = N - + vectorized_templated_config::block_work_size() * + (gridDim.x - blockIdx.x - 1); + constexpr bool reverted_idx = true; + + if (remaining < + vectorized_templated_config::block_work_size()) { // if this block handles + // the reminder, + // just do a naive unrolled loop + auto policy = memory::policies::unroll_base< + vectorized_templated_config::num_threads(), + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + vectorized_templated_config::elems_per_thread()>( + data, remaining, inp_calc, out_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + auto policy = memory::policies::vectorized_templated< + vec_size, + array_t, + vectorized_templated_config::elems_per_thread(), + vectorized_templated_config::num_threads(), + OutputType, + InputTypes...>(data); + elementwise_kernel_helper(f, policy); + } +} + +// This function assume trivial 1d and supports template specialization +// to avoid dynamic casting. 
+// Input vectorization size is based on runtime information, i.e. +// the actual data types of the input and output tensor and cannot +// be determined using the functor type, as in regular non-templated +// vectorized kernels. The caller is in charge of selecting the correct input +// vectorization length. +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + typename OutputType, + typename... InputTypes> +static inline void launch_vectorized_templated_kernel( + int64_t N, + const func_t& f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + using traits = function_traits; + int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) / + vectorized_templated_config::block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + int vec_size = memory::can_vectorize_up_to(data); + switch (vec_size) { + case 8: + vectorized_templated_elementwise_kernel< + 8, + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + OutputType, + InputTypes...> + <<>>( + N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 4: + vectorized_templated_elementwise_kernel< + 4, + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + OutputType, + InputTypes...> + <<>>( + N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + vectorized_templated_elementwise_kernel< + 2, + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + OutputType, + InputTypes...> + <<>>( + N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + default: + // vector size 1 is not handled as part of vectorize_templated kernel + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); + } +} +#endif + +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename 
out_calc_t, + typename loader_t, + typename storer_t> +static inline void launch_unrolled_kernel( + int64_t N, + const func_t& f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + + int64_t grid = (N + elementwise_block_work_size() - 1) / elementwise_block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + unrolled_elementwise_kernel + <<>>(N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void elementwise_kernel(int N, func_t f) { + int tid = threadIdx.x; + int nv = nt * vt; + int idx = nv * blockIdx.x + tid; +#pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx); + idx += nt; + } + } +} + +template +static void launch_legacy_kernel(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + dim3 block(nt); + dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + auto stream = at::cuda::getCurrentCUDAStream(); + elementwise_kernel<<>>(N, f); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +#ifdef USE_ROCM +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void elementwise_kernel_manual_unroll(int N, func_t f) { + int tid = threadIdx.x; + constexpr int nv = nt * vt; + int idx = nv * blockIdx.x + tid; + if ((idx + nt*(vt-1)) < N) { + f(idx, true); + } else { +#pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx, false); + idx += nt; + } + } + } +} + +template +static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + dim3 block(nt); + dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + auto stream = at::cuda::getCurrentCUDAStream(); + elementwise_kernel_manual_unroll<<>>(N, f); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} +#endif + +template +C10_HOST_DEVICE typename 
traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::load::type>( + data[INDEX] + i * strides[INDEX])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, i, Indices{}); +} + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::fetch_and_cast::type>( + dtypes[I], data[I] + i * strides[I])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, dtypes, i, Indices{}); +} + +template +void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + std::array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { + return launch_vectorized_kernel(numel, f, data); + } + auto 
offset_calc = ::make_offset_calculator(iter); +#ifndef USE_ROCM + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4; + launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = invoke(f, &data[1], &offsets[1], 1); + }); +#else + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 4 : 8; + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + if constexpr (unroll_factor == 4) { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx+grp_sz); + auto offsets2 = offset_calc.get(idx+grp_sz*2); + auto offsets3 = offset_calc.get(idx+grp_sz*3); + arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]); + arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]); + arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]); + arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]); + auto tmp0 = invoke(f, &data[1], &offsets0[1], 1); + auto tmp1 = invoke(f, &data[1], &offsets1[1], 1); + auto tmp2 = invoke(f, &data[1], &offsets2[1], 1); + auto tmp3 = invoke(f, &data[1], &offsets3[1], 1); + *out0 = tmp0; + *out1 = tmp1; + *out2 = tmp2; + *out3 = tmp3; + } else { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx+grp_sz); + auto offsets2 = offset_calc.get(idx+grp_sz*2); + auto offsets3 = offset_calc.get(idx+grp_sz*3); + auto offsets4 = offset_calc.get(idx+grp_sz*4); + auto offsets5 = offset_calc.get(idx+grp_sz*5); + auto offsets6 = offset_calc.get(idx+grp_sz*6); + auto offsets7 = offset_calc.get(idx+grp_sz*7); + arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]); + arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]); + arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]); + arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]); + arg0_t* out4 = (arg0_t*)(data[0] + offsets4[0]); + arg0_t* out5 = (arg0_t*)(data[0] + offsets5[0]); + arg0_t* out6 = (arg0_t*)(data[0] + offsets6[0]); + arg0_t* out7 = 
(arg0_t*)(data[0] + offsets7[0]); + auto tmp0 = invoke(f, &data[1], &offsets0[1], 1); + auto tmp1 = invoke(f, &data[1], &offsets1[1], 1); + auto tmp2 = invoke(f, &data[1], &offsets2[1], 1); + auto tmp3 = invoke(f, &data[1], &offsets3[1], 1); + auto tmp4 = invoke(f, &data[1], &offsets4[1], 1); + auto tmp5 = invoke(f, &data[1], &offsets5[1], 1); + auto tmp6 = invoke(f, &data[1], &offsets6[1], 1); + auto tmp7 = invoke(f, &data[1], &offsets7[1], 1); + *out0 = tmp0; + *out1 = tmp1; + *out2 = tmp2; + *out3 = tmp3; + *out4 = tmp4; + *out5 = tmp5; + *out6 = tmp6; + *out7 = tmp7; + } + } else { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = invoke(f, &data[1], &offsets[1], 1); + } + }); +#endif +} + +#ifdef USE_ROCM +namespace { +template < + typename TupleLike, + typename FirstParamTy, + typename SecondParamTy, + size_t arity, + size_t arg_num = 0> +struct check_binary_functor_types_for_specialization { + constexpr static inline bool check() { + if constexpr (arity != 2) + return false; + if constexpr (arg_num == 0) { + using SelectedType = std::tuple_element_t; + if constexpr (std::is_same_v) + return check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + arity, + arg_num + 1>::check(); + } else if constexpr (arg_num == 1) { + using SelectedType2 = std::tuple_element_t; + if constexpr (std::is_same_v) + return check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + arity, + arg_num + 1>::check(); + } + return false; + } +}; + +// Bottom case: if we got this far, assume correct type matching except +// when there are no arguments (arity == 0). 
+template < + typename TupleLike, + typename FirstParamTy, + typename SecondParamTy, + size_t arity> +struct check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + arity, + arity> { + constexpr static inline bool check() { + if constexpr (arity != 0) + return true; + return false; + } +}; + +template +struct check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + 0, + 0> { + constexpr static inline bool check() { + return false; + } +}; + +// The following is a list of type specializations for vectorized_templated +// elementwise kernel. The three types refer to runtime types of the output +// tensor, first tensor argument, and the second tensor argument used for a +// binary functor. +constexpr std::array rt_binary_specializations = { + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value})}; + +bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) { + if (iter.ninputs() != 2) + return false; + for (auto spec : rt_binary_specializations) + if (iter.dtype(0) == spec[0] && iter.input_dtype(0) == spec[1] && + iter.input_dtype(1) == spec[2]) + return true; + return false; +} + +template +struct type_specialized_kernel_launcher { + template < + typename func_t, + typename array_t, + typename inp_calc_t, + 
typename out_calc_t, + typename loader_t, + typename storer_t> + static void apply( + ScalarType ret_t, + ScalarType arg0_t, + ScalarType arg1_t, + int64_t numel, + func_t f, + array_t data, + inp_calc_t input_offset_calculator, + out_calc_t output_offset_calculator, + loader_t loader, + storer_t storer) { + if (ret_t == rt_binary_specializations[arg_index][0] && + arg0_t == rt_binary_specializations[arg_index][1] && + arg1_t == rt_binary_specializations[arg_index][2]) + launch_vectorized_templated_kernel< + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + decltype(c10::impl::ScalarTypeToCPPType< + rt_binary_specializations[arg_index][0]>::t), + decltype(c10::impl::ScalarTypeToCPPType< + rt_binary_specializations[arg_index][1]>::t), + decltype(c10::impl::ScalarTypeToCPPType< + rt_binary_specializations[arg_index][2]>::t)>( + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); + } +}; + +} // namespace +#endif + +template +void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { + if (!needs_dynamic_casting::check(iter)) { + return gpu_kernel_impl_nocast(iter, f); + } + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + std::array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { +#ifdef USE_ROCM + // Attempt to call specialized vectorized elementwise kernel + // that enables interleaving. 
+ if (check_binary_rt_types_for_specialization(iter) && + memory::can_vectorize_up_to(data) > 1) { + // constexpr to reduce the amount of kernels generated for + // vectorized templated elementwise and limit which functors are actually + // applied to the load and store at compile time. + using func_tuple = typename traits::ArgsTuple; + if constexpr ( + std::is_same_v && traits::arity == 2 && + check_binary_functor_types_for_specialization< + func_tuple, + float, + float, + traits::arity, + /*arg_num=*/0>::check()) { + // If we got here, we know we are in one of the specialized cases. We + // need to translate the runtime type to a statically known type. This + // is effectively hoisting to the host the switch over runtime type in + // the kernel in fetch_and_cast. Loader, storer, offset calculators are + // only needed for the reminder loop. + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithCast(iter); + auto storer = memory::StoreWithCast<1>(iter); + memory::detail::static_unroll< + type_specialized_kernel_launcher, + rt_binary_specializations.size()>:: + with_args( + iter.dtype(0), + iter.input_dtype(0), + iter.input_dtype(1), + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); + return; + } + } + std::array dtypes; + auto inner_strides = iter.get_inner_strides(); + std::array strides; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + strides[i] = inner_strides[i]; + } + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + void* out0 = data[0] + strides[0] * idx; + void* out1 = data[0] + strides[0] * (idx + grp_sz); + void* out2 = data[0] + strides[0] * (idx + grp_sz * 2); + void* out3 = data[0] + strides[0] * (idx + grp_sz * 3); + arg0_t result0 = invoke(f, &data[1], &strides[1], &dtypes[1], idx); + arg0_t result1 = 
invoke(f, &data[1], &strides[1], &dtypes[1], (idx + grp_sz)); + arg0_t result2 = invoke(f, &data[1], &strides[1], &dtypes[1], (idx + grp_sz * 2)); + arg0_t result3 = invoke(f, &data[1], &strides[1], &dtypes[1], (idx + grp_sz * 3)); + c10::cast_and_store(dtypes[0], out0, result0); + c10::cast_and_store(dtypes[0], out1, result1); + c10::cast_and_store(dtypes[0], out2, result2); + c10::cast_and_store(dtypes[0], out3, result3); + } else { + void* out = data[0] + strides[0] * idx; + arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx); + c10::cast_and_store(dtypes[0], out, result); + } + }); +#else + auto loader = memory::LoadWithCast(iter); + auto storer = memory::StoreWithCast<1>(iter); + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + launch_unrolled_kernel( + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); +#endif + } else { + std::array dtypes; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + } + auto offset_calc = ::make_offset_calculator(iter); + launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1); + c10::cast_and_store(dtypes[0], out, result); + }); + } +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..d47a7fa776f1b681b26dc5ec8b4548604d359946 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +struct TupleInfoCPU { + template + using tuple = thrust::tuple; + + 
template + static constexpr auto tie(Types&... args) noexcept { + return thrust::tie(args...); + } +}; + +template +using CompositeRandomAccessorCPU = + CompositeRandomAccessor; + +template +void swap( + references_holder rh1, + references_holder rh2 +) { + return thrust::swap(rh1.data(), rh2.data()); +} + +template +auto get(references_holder rh) -> decltype(thrust::get(rh.data())) { + return thrust::get(rh.data()); +} + +}} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Copy.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..f4b90b7b513235d3c25d966ac049d9a63a8fa1da --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Copy.h @@ -0,0 +1,11 @@ +#pragma once + +namespace at { +struct TensorIteratorBase; + +namespace native { + +void direct_copy_kernel_cuda(TensorIteratorBase& iter); + +} +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h new file mode 100644 index 0000000000000000000000000000000000000000..06276c72c53acca0f5c40cb8fc142a32669217d4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h @@ -0,0 +1,494 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native::detail { + +// Enum representing the FFT type +enum class CuFFTTransformType : int8_t { + C2C, // Complex-to-complex + R2C, // Real-to-complex + C2R, // Complex-to-real +}; + +// This struct is used to let us easily compute hashes of the +// parameters. +// It will be the **key** to the plan cache. 
+struct CuFFTParams +{ + int64_t signal_ndim_; // between 1 and max_rank, i.e., 1 <= signal_ndim <= 3 + // These include additional batch dimension as well. + int64_t sizes_[max_rank + 1]; + int64_t input_strides_[max_rank + 1]; + int64_t output_strides_[max_rank + 1]; + CuFFTTransformType fft_type_; + ScalarType value_type_; + + CuFFTParams() = default; + + CuFFTParams(IntArrayRef in_strides, IntArrayRef out_strides, + IntArrayRef signal_sizes, CuFFTTransformType fft_type, ScalarType value_type) { + // Padding bits must be zeroed for hashing + memset(this, 0, sizeof(*this)); + signal_ndim_ = signal_sizes.size() - 1; + fft_type_ = fft_type; + value_type_ = value_type; + + TORCH_INTERNAL_ASSERT(in_strides.size() == signal_sizes.size()); + TORCH_INTERNAL_ASSERT(out_strides.size() == signal_sizes.size()); + TORCH_INTERNAL_ASSERT(1 <= signal_ndim_ && signal_ndim_ <= max_rank); + + std::copy(signal_sizes.cbegin(), signal_sizes.cend(), sizes_); + std::copy(in_strides.cbegin(), in_strides.cend(), input_strides_); + std::copy(out_strides.cbegin(), out_strides.cend(), output_strides_); + } +}; + +static_assert(std::is_trivial_v ); + +// Returns true if the transform type has complex input +inline bool cufft_complex_input(CuFFTTransformType type) { + switch (type) { + case CuFFTTransformType::C2C: + case CuFFTTransformType::C2R: + return true; + + case CuFFTTransformType::R2C: + return false; + } + TORCH_INTERNAL_ASSERT(false); +} + +// Returns true if the transform type has complex output +inline bool cufft_complex_output(CuFFTTransformType type) { + switch (type) { + case CuFFTTransformType::C2C: + case CuFFTTransformType::R2C: + return true; + + case CuFFTTransformType::C2R: + return false; + } + TORCH_INTERNAL_ASSERT(false); +} + +// Create transform type enum from bools representing if input and output are complex +inline CuFFTTransformType GetCuFFTTransformType(bool complex_input, bool complex_output) { + if (complex_input && complex_output) { + return 
CuFFTTransformType::C2C; + } else if (complex_input && !complex_output) { + return CuFFTTransformType::C2R; + } else if (!complex_input && complex_output) { + return CuFFTTransformType::R2C; + } + TORCH_INTERNAL_ASSERT(false, "Real to real FFTs are not supported"); +} + + +class CuFFTHandle { + ::cufftHandle handle_; +public: + + CuFFTHandle() { + CUFFT_CHECK(cufftCreate(&handle_)); + } + + ::cufftHandle & get() { return handle_; } + const ::cufftHandle & get() const { return handle_; } + + ~CuFFTHandle() { +// Not using fftDestroy() for rocFFT to work around double freeing of handles +#if !defined(USE_ROCM) + cufftDestroy(handle_); +#endif + } +}; + +__forceinline__ +static bool is_pow_of_two(int64_t x) { + return (x & (x - 1)) == 0; +} + +using cufft_size_type = long long int; + +using CuFFTDimVector = c10::SmallVector; + +// Struct representing a tensor in CuFFT's data layout for planning transforms +// See NOTE [ cuFFT Embedded Strides ]. +struct CuFFTDataLayout { + CuFFTDimVector embed; + cufft_size_type stride, dist; + bool must_clone, simple; +}; + +// Returns a cufft embedding for a contiguous signal of the given size. +// e.g. if the input is cloned, this will be the resulting data layout +// See NOTE [ cuFFT Embedded Strides ]. +inline CuFFTDataLayout cufft_simple_embed(IntArrayRef sizes, bool onesided) { + CuFFTDataLayout layout; + layout.simple = true; + layout.must_clone = false; + layout.embed.assign(sizes.cbegin() + 1, sizes.cend()); + if (onesided) { + layout.embed.back() = sizes.back() / 2 + 1; + } + layout.stride = 1; + layout.dist = 1; + for (const auto& len : layout.embed) { + layout.dist *= len; + } + return layout; +} + +// Convert strides to a CuFFT embedded representation. +// If strides cannot be embedded, returns a simple layout and sets must_clone flag +// See NOTE [ cuFFT Embedded Strides ]. 
+inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bool onesided) { + const auto signal_ndim = strides.size() - 1; + CuFFTDataLayout layout; + auto last_stride = strides[signal_ndim]; + layout.must_clone = (last_stride <= 0); + + const auto last_dim_size = onesided ? + sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim]; + const auto signal_numel = c10::multiply_integers(sizes.slice(1, sizes.size() - 2)) * last_dim_size; + + // Zero stides are not allowed, even if the batch size is one. + // If that happens just set a dummy case + if (sizes[0] == 1) { + layout.dist = signal_numel; + } else if (strides[0] == 0) { + layout.must_clone = true; + } else { + layout.dist = strides[0]; + } + + // Calculate the embedding shape, or set must_clone if the strides cannot be embedded + layout.embed.resize(signal_ndim); + for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) { + auto stride = strides[i]; + if (sizes[i] == 1) { + layout.embed[i] = 1; + } else if (stride > 0 && stride % last_stride == 0) { + layout.embed[i] = stride / last_stride; + last_stride = stride; + } else { + layout.must_clone = true; + } + } + + if (layout.must_clone) { + // If the input needs to be cloned, assume it will be contiguous + layout = cufft_simple_embed(sizes, onesided); + layout.must_clone = true; + } else { + layout.embed[0] = sizes[1]; + layout.stride = strides[signal_ndim]; + // Determine if layout represents a simple embedding (contiguous data) + layout.simple = [&] { + for (const auto i : c10::irange(1, signal_ndim - 1)) { + if (layout.embed[i] != sizes[i + 1]) { + return false; + } + } + + return (layout.stride == 1 && layout.dist == signal_numel && + layout.embed.back() == last_dim_size); + }(); + } + return layout; +} + +// This class contains all the information needed to execute a cuFFT plan: +// 1. the plan +// 2. whether to clone input before executing the plan +// 3. 
the workspace size needed +// +// This class will be the **value** in the plan cache. +// It **owns** the raw plan via a unique_ptr. +class CuFFTConfig { +public: + + // Only move semantics is enought for this class. Although we already use + // unique_ptr for the plan, still remove copy constructor and assignment op so + // we don't accidentally copy and take perf hit. + CuFFTConfig(const CuFFTConfig&) = delete; + CuFFTConfig& operator=(CuFFTConfig const&) = delete; + + explicit CuFFTConfig(const CuFFTParams& params): + CuFFTConfig( + IntArrayRef(params.input_strides_, params.signal_ndim_ + 1), + IntArrayRef(params.output_strides_, params.signal_ndim_ + 1), + IntArrayRef(params.sizes_, params.signal_ndim_ + 1), + params.fft_type_, + params.value_type_) {} + + // For complex types, strides are in units of 2 * element_size(dtype) + // sizes are for the full signal, including batch size and always two-sided + CuFFTConfig(IntArrayRef in_strides, IntArrayRef out_strides, + IntArrayRef sizes, CuFFTTransformType fft_type, ScalarType dtype): + fft_type_(fft_type), value_type_(dtype) { + + // signal sizes (excluding batch dim) + CuFFTDimVector signal_sizes(sizes.begin() + 1, sizes.end()); + + // input batch size + const int64_t batch = sizes[0]; + const int64_t signal_ndim = sizes.size() - 1; + + // Since cuFFT has limited non-unit stride support and various constraints, we + // use a flag to keep track throughout this function to see if we need to + // input = input.clone(); + +#if defined(USE_ROCM) + // clone input to avoid issues with hipfft clobering the input and failing tests + clone_input = true; +#else + clone_input = false; +#endif + + // For half, base strides on the real part of real-to-complex and + // complex-to-real transforms are not supported. Since our output is always + // contiguous, only need to check real-to-complex case. 
+ if (dtype == ScalarType::Half) { + // cuFFT on half requires compute capability of at least SM_53 + auto dev_prop = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(dev_prop->major >= 5 && !(dev_prop->major == 5 && dev_prop->minor < 3), + "cuFFT doesn't support signals of half type with compute " + "capability less than SM_53, but the device containing input half " + "tensor only has SM_", dev_prop->major, dev_prop->minor); + for (const auto i : c10::irange(signal_ndim)) { + TORCH_CHECK(is_pow_of_two(sizes[i + 1]), + "cuFFT only supports dimensions whose sizes are powers of two when" + " computing in half precision, but got a signal size of", + sizes.slice(1)); + } + clone_input |= in_strides.back() != 1; + } + + CuFFTDataLayout in_layout; + if (clone_input) { + in_layout = cufft_simple_embed(sizes, fft_type == CuFFTTransformType::C2R); + } else { + in_layout = as_cufft_embed(in_strides, sizes, fft_type == CuFFTTransformType::C2R); + } + auto out_layout = as_cufft_embed(out_strides, sizes, fft_type == CuFFTTransformType::R2C); + TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be represented as CuFFT embedding"); + clone_input |= in_layout.must_clone; + + // Check if we can take advantage of simple data layout. + // + // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu. + + const bool simple_layout = in_layout.simple && out_layout.simple; + cudaDataType itype, otype, exec_type; + const auto complex_input = cufft_complex_input(fft_type); + const auto complex_output = cufft_complex_output(fft_type); + if (dtype == ScalarType::Float) { + itype = complex_input ? CUDA_C_32F : CUDA_R_32F; + otype = complex_output ? CUDA_C_32F : CUDA_R_32F; + exec_type = CUDA_C_32F; + } else if (dtype == ScalarType::Double) { + itype = complex_input ? CUDA_C_64F : CUDA_R_64F; + otype = complex_output ? CUDA_C_64F : CUDA_R_64F; + exec_type = CUDA_C_64F; + } else if (dtype == ScalarType::Half) { + itype = complex_input ? 
CUDA_C_16F : CUDA_R_16F; + otype = complex_output ? CUDA_C_16F : CUDA_R_16F; + exec_type = CUDA_C_16F; + } else { + TORCH_CHECK(false, "cuFFT doesn't support tensor of type: ", dtype); + } + + // disable auto allocation of workspace to use THC allocator + CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0)); + + size_t ws_size_t; + + // make plan + if (simple_layout) { + // If with unit-stride, we tell cuFFT by setting inembed == onembed == NULL. + // In such case, cuFFT ignores istride, ostride, idist, and odist + // by assuming istride = ostride = 1. + // + // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu. + CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(), + /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype, + /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype, + batch, &ws_size_t, exec_type)); + } else { + CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(), + in_layout.embed.data(), in_layout.stride, in_layout.dist, itype, + out_layout.embed.data(), out_layout.stride, out_layout.dist, otype, + batch, &ws_size_t, exec_type)); + } + ws_size = static_cast(ws_size_t); + } + + const cufftHandle &plan() const { return plan_ptr.get(); } + + CuFFTTransformType transform_type() const { return fft_type_; } + ScalarType data_type() const { return value_type_; } + bool should_clone_input() const { return clone_input; } + int64_t workspace_size() const { return ws_size; } + +private: + CuFFTHandle plan_ptr; + bool clone_input; + int64_t ws_size; + CuFFTTransformType fft_type_; + ScalarType value_type_; +}; + +#if defined(USE_ROCM) + // Note that the max plan number for CUDA version < 10 has to be 1023 + // due to a bug that fails on the 1024th plan + constexpr int64_t CUFFT_MAX_PLAN_NUM = 1023; + constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM; +#else + constexpr int64_t CUFFT_MAX_PLAN_NUM = std::numeric_limits::max(); + // The default max cache size 
chosen for CUDA version > 10 is arbitrary. + // This number puts a limit on how big of a plan cache should we maintain by + // default. Users can always configure it via cufft_set_plan_cache_max_size. + constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = 4096; +#endif +static_assert(0 <= CUFFT_MAX_PLAN_NUM && CUFFT_MAX_PLAN_NUM <= std::numeric_limits::max(), + "CUFFT_MAX_PLAN_NUM not in size_t range"); +static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 && CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM, + "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range"); + +// This cache assumes that the mapping from key to value never changes. +// This is **NOT** thread-safe. Please use a mutex when using it **AND** the +// value returned from try_emplace_value. +// The contract of using this cache is that try_emplace_value should only be +// used when the max_size is positive. +class CuFFTParamsLRUCache { +public: + using kv_t = typename std::pair; + using map_t = typename std::unordered_map, + typename std::list::iterator, + ParamsHash, + ParamsEqual>; + using map_kkv_iter_t = typename map_t::iterator; + + + CuFFTParamsLRUCache() : CuFFTParamsLRUCache(CUFFT_DEFAULT_CACHE_SIZE) {} + + CuFFTParamsLRUCache(int64_t max_size) { + _set_max_size(max_size); + } + + CuFFTParamsLRUCache(CuFFTParamsLRUCache&& other) noexcept : + _usage_list(std::move(other._usage_list)), + _cache_map(std::move(other._cache_map)), + _max_size(other._max_size) {} + + CuFFTParamsLRUCache& operator=(CuFFTParamsLRUCache&& other) noexcept { + _usage_list = std::move(other._usage_list); + _cache_map = std::move(other._cache_map); + _max_size = other._max_size; + return *this; + } + + // If key is in this cache, return the cached config. Otherwise, emplace the + // config in this cache and return it. + // Return const reference because CuFFTConfig shouldn't be tampered with once + // created. 
+ const CuFFTConfig &lookup(CuFFTParams params) { + AT_ASSERT(_max_size > 0); + + map_kkv_iter_t map_it = _cache_map.find(params); + // Hit, put to list front + if (map_it != _cache_map.end()) { + _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second); + return map_it->second->second; + } + + // Miss + // remove if needed + if (_usage_list.size() >= _max_size) { + auto last = _usage_list.end(); + last--; + _cache_map.erase(last->first); + _usage_list.pop_back(); + } + + // construct new plan at list front, then insert into _cache_map + _usage_list.emplace_front(std::piecewise_construct, + std::forward_as_tuple(params), + std::forward_as_tuple(params)); + auto kv_it = _usage_list.begin(); + _cache_map.emplace(std::piecewise_construct, + std::forward_as_tuple(kv_it->first), + std::forward_as_tuple(kv_it)); + return kv_it->second; + } + + void clear() { + _cache_map.clear(); + _usage_list.clear(); + } + + void resize(int64_t new_size) { + _set_max_size(new_size); + auto cur_size = _usage_list.size(); + if (cur_size > _max_size) { + auto delete_it = _usage_list.end(); + for (size_t i = 0; i < cur_size - _max_size; i++) { + delete_it--; + _cache_map.erase(delete_it->first); + } + _usage_list.erase(delete_it, _usage_list.end()); + } + } + + size_t size() const { return _cache_map.size(); } + + size_t max_size() const noexcept { return _max_size; } + + std::mutex mutex; + +private: + // Only sets size and does value check. Does not resize the data structures. + void _set_max_size(int64_t new_size) { + // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since + // CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check + // first. 
+ TORCH_CHECK(new_size >= 0, + "cuFFT plan cache size must be non-negative, but got ", new_size); + TORCH_CHECK(new_size <= CUFFT_MAX_PLAN_NUM, + "cuFFT plan cache size can not be larger than ", CUFFT_MAX_PLAN_NUM, ", but got ", new_size); + _max_size = static_cast(new_size); + } + + std::list _usage_list; + map_t _cache_map; + size_t _max_size; +}; + +// Since ATen is separated into CPU build and CUDA build, we need a way to call +// these functions only when CUDA is loaded. We use CUDA hooks for this purpose +// (at cuda/detail/CUDAHooks.cpp), and call the hooked functions from the actual +// native function counterparts (at native/SpectralOps.cpp), i.e., +// _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size +// _cufft_get_plan_cache_size, and _cufft_clear_plan_cache. +int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index); +void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size); +int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index); +void cufft_clear_plan_cache_impl(DeviceIndex device_index); + +} // namespace at::native::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..f20baa9568661037954ccd6e8c425018969175a4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace at { namespace native { + +// This means that max dim is 3 + 2 = 5 with batch dimension and possible +// complex dimension +constexpr int max_rank = 3; + +static inline std::string _cudaGetErrorEnum(cufftResult error) +{ + switch (error) + { + case CUFFT_SUCCESS: + return "CUFFT_SUCCESS"; + case CUFFT_INVALID_PLAN: + return "CUFFT_INVALID_PLAN"; + case CUFFT_ALLOC_FAILED: + return 
"CUFFT_ALLOC_FAILED"; + case CUFFT_INVALID_TYPE: + return "CUFFT_INVALID_TYPE"; + case CUFFT_INVALID_VALUE: + return "CUFFT_INVALID_VALUE"; + case CUFFT_INTERNAL_ERROR: + return "CUFFT_INTERNAL_ERROR"; + case CUFFT_EXEC_FAILED: + return "CUFFT_EXEC_FAILED"; + case CUFFT_SETUP_FAILED: + return "CUFFT_SETUP_FAILED"; + case CUFFT_INVALID_SIZE: + return "CUFFT_INVALID_SIZE"; + case CUFFT_UNALIGNED_DATA: + return "CUFFT_UNALIGNED_DATA"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; + case CUFFT_NO_WORKSPACE: + return "CUFFT_NO_WORKSPACE"; + case CUFFT_NOT_IMPLEMENTED: + return "CUFFT_NOT_IMPLEMENTED"; +#if !defined(USE_ROCM) + case CUFFT_LICENSE_ERROR: + return "CUFFT_LICENSE_ERROR"; +#endif + case CUFFT_NOT_SUPPORTED: + return "CUFFT_NOT_SUPPORTED"; + default: + std::ostringstream ss; + ss << "unknown error " << error; + return ss.str(); + } +} + +static inline void CUFFT_CHECK(cufftResult error) +{ + if (error != CUFFT_SUCCESS) { + std::ostringstream ss; + ss << "cuFFT error: " << _cudaGetErrorEnum(error); + TORCH_CHECK(false, ss.str()); + } +} + +}} // at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh new file mode 100644 index 0000000000000000000000000000000000000000..15db32850d989a7b5e781500e525356834b44e13 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh @@ -0,0 +1,25 @@ +#pragma once + +namespace at::native { +#if defined(USE_ROCM) +// take these out when ROCm implements std:: math functions +#include +template +static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val); + +template <> +__forceinline__ __device__ float device_sqrt(float val) { + return ::sqrtf(val); +} + +template <> +__forceinline__ __device__ 
double device_sqrt(double val) { + return ::sqrt(val); +} +#else +template +__forceinline__ __device__ double device_sqrt(scalar_t val) { + return std::sqrt(val); +} +#endif +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h new file mode 100644 index 0000000000000000000000000000000000000000..1fc9195ac53769ff06d85fbd990c542bd34593da --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h @@ -0,0 +1,697 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +// launch bounds used for kernels utilizing TensorIterator +const uint32_t block_size_bound = 256; +const uint32_t grid_size_bound = 4; +// At the time of writing, there is no curand_* call that increments the offset by more than 4. +// See: https://docs.nvidia.com/cuda/archive/11.8.0/curand/group__DEVICE.html +const uint32_t max_generator_offsets_per_curand_call = 4; + +// utility function that calculates proper philox_offset +// for distributions utilizing TensorIterator. For distributions using +// TensorIterator, we are using a grid-stride loop with each +// thread yielding one element per thread. For the edge of the grid-stride +// loop, if the tensor size is large, the unroll loop will kick in and the float4 +// from curand4 will start getting utilized (for common tensor sizes, we end up +// using rand.x from each thread). 
The philox_offset calculation was changed to +// (number of elements per thread * maximum generator increment per "curand_*" call), which makes +// sure that philox offset increment is not less than the number of randoms used +// in each thread. +std::tuple calc_execution_policy(const int64_t total_elements, const uint32_t unroll_factor) { + const uint64_t numel = static_cast(total_elements); + const uint32_t block_size = block_size_bound; + dim3 dim_block(block_size); + dim3 grid((numel + block_size - 1) / block_size); + uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + grid.x = std::min( + static_cast(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm, + grid.x); + //number of times random will be generated per thread, to offset philox counter in thc random state + uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll_factor) + 1) * max_generator_offsets_per_curand_call; + return std::make_tuple(counter_offset, grid, dim_block); +} + +// grid stride loop kernel for distributions +template +C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound) +__global__ void distribution_elementwise_grid_stride_kernel(int64_t numel, + PhiloxCudaState philox_args, + const dist_t dist_func, + const transform_t transform_func) { + auto [seed, offset] = at::cuda::philox::unpack(philox_args); + int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seed, idx, offset, &state); + + int64_t rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) * + blockDim.x * gridDim.x * unroll_factor; + for(int64_t linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { + auto rand = dist_func(&state); + #pragma unroll + for (int ii = 0; ii < unroll_factor; ii++) { + int64_t li = linear_index + blockDim.x * gridDim.x * ii; + if (li < numel) { + transform_func(li, 
static_cast((&rand.x)[ii])); + } + } + __syncthreads(); + } +} + +/** + * distribution_nullary_kernel is analogous to gpu_kernel in + * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses + * TensorIterator to launch a kernel. However, the differences are + * - it launches a grid-stride loop based kernel. The kernel is not + * generic like elementwise_kernel in Loops.cuh and is specialized + * for the distribution kernels here. + * - For big size tensors, we can launch multiple kernels recursively + * (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox + * offset calculation is done in this function. + * + * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh + * to have grid-stride loop kernel and then use that to launch our distribution + * kernels? Note that we need a grid-stride loop kernel because, we found by testing + * that it achieves peak effective bandwidth. + */ +template +void distribution_nullary_kernel(at::TensorIteratorBase& iter, + RNG gen, + const dist_t& dist_func, + const transform_t transform_func) { + const int unroll_factor = sizeof(dist_func_return_t) / sizeof(accscalar_t); + TORCH_CHECK(unroll_factor >= 1, "unroll_factor must be >= 1."); + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + auto [counter_offset, grid, block] = calc_execution_policy(numel, unroll_factor); + PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_nullary_kernel(sub_iter, + gen, dist_func, transform_func); + } + return; + } + + char* out_data = (char*)iter.data_ptr(0); + + auto stream = at::cuda::getCurrentCUDAStream(); + if (iter.is_trivial_1d()) { + auto strides = iter.get_inner_strides(); + int stride0 = strides[0]; + 
distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + scalar_t* out = (scalar_t*)&out_data[stride0 * idx]; + *out = transform_func(rand); + } + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + auto offset_calc = make_offset_calculator<1>(iter); + distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + auto offsets = offset_calc.get(idx); + scalar_t* out = (scalar_t*)&out_data[offsets[0]]; + *out = transform_func(rand); + } + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +// Binary kernel +template +__global__ void distribution_binary_elementwise_kernel( + int numel, + func_t f, + PhiloxCudaState philox_args, + typename function_traits::result_type *output_data, + const typename function_traits::template arg<1>::type *input_data_1, + const typename function_traits::template arg<2>::type *input_data_2, + inp_offset_calc_t inp_calc, + out_offset_calc_t out_calc) { + auto seeds = at::cuda::philox::unpack(philox_args); + + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + + input_t_1 inputs_1[thread_work_size()]; + input_t_2 inputs_2[thread_work_size()]; + + int base_index = block_work_size() * blockIdx.x; + int remaining = std::min(numel - base_index, block_work_size()); + + curandStatePhilox4_32_10_t state; + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // load data into registers + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + break; + } + int input_idx = thread_idx + base_index; + auto offsets = inp_calc.get(input_idx); + inputs_1[i] = input_data_1[offsets[0]]; + inputs_2[i] = input_data_2[offsets[1]]; + + thread_idx += num_threads(); + } + + // compute and store + 
thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + break; + } + int input_idx = thread_idx + base_index; + auto offsets = out_calc.get(input_idx); + output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]); + thread_idx += num_threads(); + } +} + +template +void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) { + static_assert(std::is_same_v::template arg<0>::type, curandStatePhilox4_32_10_t&>, "the first argument of functor must be curandStatePhilox4_32_10_t"); + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + using output_t = typename function_traits::result_type; + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_binary_kernel(sub_iter, philox_args, f); + } + return; + } + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing()); + + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + output_t *output_data = static_cast(iter.data_ptr(0)); + const input_t_1 *input_data_1 = static_cast(iter.data_ptr(1)); + const input_t_2 *input_data_2 = static_cast(iter.data_ptr(2)); + + int64_t grid = (numel + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + + if (iter.is_contiguous()) { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +} // namespace +}} // namespace at::native + + +namespace at { +namespace native { +namespace 
templates { +namespace cuda { + +// ==================================================== Random ======================================================== + +template +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) { +#ifdef FBCODE_CAFFE2 + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] { + if (( + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) && range >= 1ULL << 32) + { + // define lambda to mod with range and add base + auto random_func = [range, base] __device__ (uint64_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [range, base] __device__ (uint32_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 { + return curand4(state); + }, + random_func); + } + }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +#else + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] { + if (range >= 1ULL << 28) // allow approx 5% skew in uniform int generation using % + { + // define lambda to mod with range and add base + auto random_func = [range, base] __device__ (uint64_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + 
return ret; + }, + random_func); + } else { + auto random_func = [range, base] __device__ (uint32_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 { + return curand4(state); + }, + random_func); + } + }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +#endif +} + +// This is the special kernel to handle single specific case: +// from(inclusive) = std::numeric_limits::lowest() +// to(exclusive) = None (= std::numeric_limits::max() + 1) +template +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] { + if (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + auto random_func = [] __device__ (uint64_t rand) { + return transformation::uniform_int_full_range(rand); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16"); + } + }); +} + +template +struct RandomFromToKernel { + void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen) { + random_from_to_kernel(iter, range, base, check_generator(gen)); + } + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_full_64_bits_range_kernel(iter, check_generator(gen)); + } +}; + +template +void random_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), 
"random_kernel_cuda", [&] { + if (std::is_same_v || std::is_same_v) { + auto random_func = [] __device__ (uint64_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [] __device__ (uint32_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 { + return curand4(state); + }, + random_func); + } + }); +} + +template +struct RandomKernel { + void operator()(TensorIteratorBase& iter, RNG gen) { + random_kernel(iter, gen); + } +}; + +// ==================================================================================================================== + +template +void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same_v) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_uniform2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_uniform4(state); }, + transform); + } +} + +template +void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same_v) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_normal2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_normal4(state); }, + transform); + } +} + +// ==================================================== Normal 
======================================================== + +template +void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) { + auto iter = TensorIterator::borrowing_nullary_op(self); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda to multiply std and add mean + auto normal_func = [mean, std] __device__ (accscalar_t rand) { + return static_cast(transformation::normal(rand, mean, std)); + }; + normal_and_transform(iter, gen, normal_func); + }); +} + +template +struct NormalKernel { + void operator()(const TensorBase &self, double mean, double std, std::optional gen) { + normal_kernel(self, mean, std, check_generator(gen)); + } +}; + +// ==================================================== Uniform ======================================================== + +template +void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] { + auto from = static_cast(from_); + auto to = static_cast(to_); + using opmath_t = at::opmath_type; + auto range = static_cast(to-from); + // define lambda to reverse bounds, multiply 'range' and add 'from_' + auto uniform_func = [range, from, to] __device__ (opmath_t rand) { + // Compute output value before reversing the bounds + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947 + auto value = static_cast(rand * range + from); + // reverse the bounds of curand4 from (0, 1] to [0, 1) + // Note that this method is from legacy THCTensorRandom and is likely to give + // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and + // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s. 
+ // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706 + auto reverse_bound_value = value == to ? from : value; + return reverse_bound_value; + }; + uniform_and_transform(iter, gen, uniform_func); + }); +} + +template +struct UniformKernel { + void operator()(TensorIteratorBase& iter, double from, double to, std::optional gen) { + uniform_kernel(iter, from, to, check_generator(gen)); + } +}; + +// ================================================== LogNormal ======================================================= + +template +void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda for log_normal transformation + auto log_normal_func = [mean, std] __device__ (accscalar_t rand) { + return static_cast(transformation::log_normal(transformation::normal(rand, mean, std))); + }; + normal_and_transform(iter, gen, log_normal_func); + }); +} + +template +struct LogNormalKernel { + void operator()(TensorIteratorBase& iter, double mean, double std, std::optional gen) { + log_normal_kernel(iter, mean, std, check_generator(gen)); + } +}; + +// =================================================== Geometric ====================================================== + +template +void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for geometric transformation + auto geometric_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::geometric(rand, p)); + }; + uniform_and_transform(iter, gen, geometric_func); + }); +} + +template +struct GeometricKernel { + void 
operator()(TensorIteratorBase& iter, double p, std::optional gen) { + geometric_kernel(iter, p, check_generator(gen)); + } +}; + +// ================================================== Exponential ===================================================== + +template +void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) { + TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] { + using accscalar_t = at::acc_type; + auto lambda = static_cast(lambda_); + // define lambda for exponential transformation + auto exponential_func = [lambda] __device__ (accscalar_t rand) { + return static_cast(transformation::exponential(rand, lambda)); + }; + uniform_and_transform(iter, gen, exponential_func); + }); +} + +template +struct ExponentialKernel { + void operator()(TensorIteratorBase& iter, double lambda, std::optional gen) { + exponential_kernel(iter, lambda, check_generator(gen)); + } +}; + +// ==================================================== Cauchy ======================================================== + +template +void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] { + using accscalar_t = at::acc_type; + auto median = static_cast(median_); + auto sigma = static_cast(sigma_); + // define lambda for cauchy transformation + auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) { + return static_cast(transformation::cauchy(rand, median, sigma)); + }; + uniform_and_transform(iter, gen, cauchy_func); + }); +} + +template +struct CauchyKernel { + void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional gen) { + cauchy_kernel(iter, 
median, sigma, check_generator(gen)); + } +}; + +// ==================================================== Bernoulli ===================================================== + +template +void bernoulli_tensor_cuda_kernel( + const TensorBase &ret, const at::TensorBase &p, + PhiloxCudaState philox_args) { + auto functor = [philox_args] __device__( + int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4, + const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) { + auto seeds = at::cuda::philox::unpack(philox_args); + curandStatePhilox4_32_10_t state; + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // See Note [Register spilling in curand call for CUDA < 10] + float4 rand = curand_uniform4(&state); + switch (n) { + case 4: { + CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1); + v4 = static_cast(rand.w <= p4); + [[fallthrough]]; + } + case 3: { + CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1); + v3 = static_cast(rand.z <= p3); + [[fallthrough]]; + } + case 2: { + CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1); + v2 = static_cast(rand.y <= p2); + [[fallthrough]]; + } + case 1: { + CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1); + v1 = static_cast(rand.x <= p1); + } + } + }; + // The template argument `4` below indicates that we want to operate on four + // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details. 
+ at::cuda::CUDA_tensor_apply2(ret, p, functor); +} + +template +void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) { + PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(10); + } + TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type()); + // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else + const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat; + auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type)); + auto p = expand_inplace(self, p_cuda); + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] { + if (std::is_same_v) { + return bernoulli_tensor_cuda_kernel(self, *p, rng_engine_inputs); + } else { + return bernoulli_tensor_cuda_kernel(self, *p, rng_engine_inputs); + } + }); +} + +template +void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for bernoulli transformation + auto bernoulli_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::bernoulli(rand, p)); + }; + uniform_and_transform(iter, gen, bernoulli_func); + }); +} + +template +struct BernoulliKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + bernoulli_kernel(iter, p, check_generator(gen)); + } + void operator()(const TensorBase &self, const TensorBase &p_, std::optional gen) { + bernoulli_kernel(self, p_, check_generator(gen)); + } +}; + +}}}} diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Distributions.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Distributions.h new file mode 100644 index 0000000000000000000000000000000000000000..1a34fdfdf31494faab439544578be8aaf950dc32 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Distributions.h @@ -0,0 +1,25 @@ +#pragma once + +namespace at { +struct CUDAGeneratorImpl; +struct TensorIteratorBase; +class TensorBase; + +namespace native { + +void launch_poisson_cuda_kernel( + const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen); + +void launch_gamma_kernel( + const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen); + +void launch_binomial_cuda_kernel( + TensorIteratorBase &iter, CUDAGeneratorImpl *gen); + +void launch_dirichlet_kernel(TensorIteratorBase &iter); + +void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter); + +void launch_dirichlet_grad_kernel(TensorIteratorBase &iter); + +}} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..54e9be9f7c7c559d2dade6b237cd880fdac48717 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh @@ -0,0 +1,21 @@ +#pragma once +#include +#include +#include +#include + +namespace at::native { + +Tensor embedding_backward_cuda_kernel( + const Tensor &grad, + const Tensor &orig_indices, + const Tensor &sorted_indices, + const Tensor &count, + int64_t num_weights, + int padding_idx = -1, + bool mode_mean = false, + const Tensor &offset2bag = Tensor(), + const Tensor &bag_size = Tensor(), + const Tensor &per_sample_weights = Tensor()); + +} // namespace at::native diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh new file mode 100644 index 0000000000000000000000000000000000000000..645b095c5a6e5074069a636574b925c25c9c28f5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh @@ -0,0 +1,738 @@ +#pragma once +#include +#include +#include +#include + +namespace at::native { + +namespace { + +// TODO(crcrpar): Handle version bump in codegen. +// rel: +// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482 +inline void increment_version(TensorList tensors) { + for (const auto& t : tensors) { + t.unsafeGetTensorImpl()->bump_version(); + } +} + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListScalarListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +template +__device__ bool init_args( + T** args, + FusedOptimizerTensorListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + 
args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +template +__device__ void load_args( + T r_args[][kILP], + T** args, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const auto i = i_start + threadIdx.x + ii * blockDim.x; + for (int r_index = 0; r_index < depth; r_index++) { + r_args[r_index][ii] = 0; + if (i < n && i < chunk_size) { + r_args[r_index][ii] = args[r_index][i]; + } + } + } +} + +template +__device__ void store_args( + T* dst, + T* src, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const int64_t i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) + dst[i] = src[ii]; + } +} + +template +__device__ __forceinline__ void binary_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + const int64_t n, + const int64_t chunk_size, + const bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(scalar))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args + // has depth 1 + load_args<1>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + 
op(static_cast(r_args[0][ii]), + static_cast(scalar))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } +} + +template +__device__ __forceinline__ void pointwise_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + const int64_t n, + const int64_t chunk_size, + const bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); + load_store(r_args[2], args[2], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + static_cast(r_args[0][ii]) + + scalar * + op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args + // has depth 3 + load_args<3>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + static_cast(r_args[0][ii]) + + scalar * + op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } +} + +// +// Binary Functors +// +template +struct BinaryOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t scalar) { + const int tensor_loc = tl.block_to_tensor[blockIdx.x]; + const int chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + 
init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + binary_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct BinaryOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + binary_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct BinaryOpListAlphaFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t alpha) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 
0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct BinaryOpScalarTensorFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + T* scalar, + opmath_t alpha) { + const int tensor_loc = tl.block_to_tensor[blockIdx.x]; + const int chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op( + static_cast(r_args[0][ii]), + static_cast(alpha) * static_cast(*scalar))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 1 (for inplace) or 2 (for out of place), + // r_args has depth 1 + load_args<1>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op( + static_cast(r_args[0][ii]), + static_cast(alpha) * static_cast(*scalar))); + } + store_args(args[res_arg_index], r_args[0], 
i_start, chunk_size, n); + } + } + } +}; + +// +// Unary Functors +// + +template +struct ZeroFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata<1>& tl) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const auto all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = 0; + } + // store + load_store(args[0], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = 0; + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct UnaryOpFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], 
args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + static_cast(op(static_cast(r_args[0][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + static_cast(op(static_cast(r_args[0][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +// +// Pointwise Functors +// + +template +struct PointwiseOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t scalar) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + pointwise_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct PointwiseOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + pointwise_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct PointwiseOpListFunctor { + using opmath_t = 
at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[depth - 1][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); + } + // store + load_store(args[2], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); + } + store_args(args[2], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + static_assert(depth == 3 || depth == 4, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, 
chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); + load_store(r_args[2], args[2], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + static_cast(r_args[2][ii])); + } + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + static_cast(r_args[2][ii])); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t alpha) { + static_assert(depth == 2 || depth == 3, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; 
+ i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + alpha); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + alpha); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + static_assert(depth == 2 || depth == 3, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + const opmath_t scalar = tl.scalar_vals[tensor_loc]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + 
scalar); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + scalar); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct power_functor { + C10_DEVICE T operator()(const T& a, const T& b) const { + return at::native::pow_(a, b); + } +}; + +template +struct reverse_power_functor { + C10_DEVICE T operator()(const T& a, const T& b) const { + return at::native::pow_(b, a); + } +}; + +} // namespace +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9b08911b1d9500f200d55e031ffff7658d5491f0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace at::native { + +// std:: does not have clamp functors +template +struct minimum { + __device__ T operator()(const T& a, const T& b) const { + return (_isnan(a) || a < b) ? a : b; + } +}; + +template +struct maximum { + __device__ T operator()(const T& a, const T& b) const { + return (_isnan(a) || a > b) ? 
a : b; + } +}; + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh new file mode 100644 index 0000000000000000000000000000000000000000..392b97d7cd48a6566a24035fabd4fa19b5584086 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh @@ -0,0 +1,321 @@ +#pragma once +#include +#include + +namespace at::native { + +using detail::GridSamplerInterpolation; +using detail::GridSamplerPadding; + +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). +// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template +__forceinline__ __device__ +scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1.f) * size - 1) / 2; + } +} + +// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize +// except that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. 
+template +__forceinline__ __device__ +scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size, + bool align_corners, scalar_t *grad_in) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + *grad_in = static_cast(size - 1) / 2; + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + *grad_in = static_cast(size) / 2; + return ((coord + 1.f) * size - 1) / 2; + } +} + +// Clips coordinates to between 0 and clip_limit - 1 +template +__forceinline__ __device__ +scalar_t clip_coordinates(scalar_t in, int clip_limit) { + return ::min(static_cast(clip_limit - 1), ::max(in, static_cast(0))); +} + +// clip_coordinates_set_grad works similarly to clip_coordinates except that +// it also returns the `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) { + // Note that it is important for the gradient calculation that borders + // are considered out of bounds. + if (in <= static_cast(0)) { + *grad_in = static_cast(0); + return static_cast(0); + } else { + scalar_t max = static_cast(clip_limit - 1); + if (in >= max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template +__forceinline__ __device__ +scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = ::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. 
+ scalar_t extra = ::fmod(in, span); + int flips = static_cast(::floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +// reflect_coordinates_set_grad works similarly to reflect_coordinates except +// that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high, + scalar_t *grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast(0); + return static_cast(0); + } + int grad_in_mult_; + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + // `fmod` returns same sign as `in`, which is positive after the `if` above. + scalar_t extra = ::fmod(in, span); + int flips = static_cast(::floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast(-grad_in_mult_); + return span - extra + min; + } +} + +template +__forceinline__ __device__ +scalar_t safe_downgrade_to_int_range(scalar_t x){ + // -100.0 does not have special meaning. This is just to make sure + // it's not within_bounds_2d or within_bounds_3d, and does not cause + // undefined behavior. See #35506. 
+ if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast(x))) + return static_cast(-100.0); + return x; +} + +template +__forceinline__ __device__ +scalar_t compute_coordinates(scalar_t coord, int size, + GridSamplerPadding padding_mode, + bool align_corners) { + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates(coord, 0, 2*(size - 1)); + } else { + coord = reflect_coordinates(coord, -1, 2*size - 1); + } + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } + + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +// Computes the pixel source index value for a grid coordinate +template +__forceinline__ __device__ +scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); + return coord; +} + +// grid_sampler_compute_source_index_set_grad works similarly to +// grid_sampler_compute_source_index except that it also returns the +// `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. 
+template +__forceinline__ __device__ +scalar_t grid_sampler_compute_source_index_set_grad( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners, + scalar_t *grad_in) { + scalar_t grad_clip, grad_refl; + coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl); + } else { + coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl); + } + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +__forceinline__ __device__ +bool within_bounds_2d(int h, int w, int H, int W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +__forceinline__ __device__ +bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +template +__forceinline__ __device__ +scalar_t get_value_bounded( + const scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + +template +__forceinline__ __device__ +void safe_add_2d(scalar_t *data, int h, int w, + int sH, int sW, int H, int W, + scalar_t delta, + const index_t NC_offset, + const index_t 
memory_span) { + if (within_bounds_2d(h, w, H, W)) { + fastAtomicAdd(data, + NC_offset + h * sH + w * sW, + memory_span, + delta, + true); + } +} + +template +__forceinline__ __device__ +void safe_add_3d(scalar_t *data, int d, int h, int w, + int sD, int sH, int sW, int D, int H, int W, + scalar_t delta, + const index_t NC_offset, + const index_t memory_span) { + if (within_bounds_3d(d, h, w, D, H, W)) { + fastAtomicAdd(data, + NC_offset + d * sD + h * sH + w * sW, + memory_span, + delta, + true); + } +} + +template +__forceinline__ __device__ +void add_value_bounded( + scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners, + const index_t NC_offset, + const index_t memory_span) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span); +} + +// Calculate the differential of the cubic convolution, i.e. 
`d coeff / d x` +template +__forceinline__ __device__ +void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GridSampler.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GridSampler.h new file mode 100644 index 0000000000000000000000000000000000000000..4da630876302696cfb611a2c6d1b19904ffd2383 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GridSampler.h @@ -0,0 +1,31 @@ +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +void launch_grid_sampler_2d_forward_kernel( + const TensorBase &output, const TensorBase &input, const TensorBase &grid, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_3d_forward_kernel( + const TensorBase &output, const TensorBase &input, const TensorBase &grid, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_2d_backward_kernel( + const TensorBase &grad_input, const TensorBase &grad_grid, + const TensorBase &grad_output, const TensorBase &input, + const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask); + +void launch_grid_sampler_3d_backward_kernel( + const TensorBase &grad_input, const TensorBase &grad_grid, + const TensorBase &grad_output, 
const TensorBase &input, + const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GroupMM.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GroupMM.h new file mode 100644 index 0000000000000000000000000000000000000000..1fc23207a090ddbc57801b8e8a26213e7b77337d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GroupMM.h @@ -0,0 +1,12 @@ +#pragma once +#include +#include + +namespace at::cuda::detail { +TORCH_API void bf16bf16_grouped_mm( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out); +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GroupMMCommon.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GroupMMCommon.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a4d8a97b6fd890a3ea202b60885c86639afc23f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GroupMMCommon.cuh @@ -0,0 +1,156 @@ +#pragma once +#include + +namespace at::cuda::detail { + +using Strides = std::array; + +template < + typename DtypeA, + typename DtypeB, + typename DtypeOutput, + typename DtypeScale, + typename ProblemShape, + typename StrideA, + typename StrideB, + typename StrideOutput> +__global__ void prepare_grouped_gemm_data( + DtypeA* A, + DtypeB* B, + DtypeOutput* output, + DtypeScale* scale_A, + DtypeScale* scale_B, + DtypeA** A_ptrs, + DtypeB** B_ptrs, + DtypeOutput** output_ptrs, + DtypeScale** inputA_scale_ptrs, + DtypeScale** inputB_scale_ptrs, + ProblemShape* problem_sizes, + // Strides for cutlass, cute::Stride + StrideA* stride_A, + StrideB* stride_B, + StrideOutput* stride_output, + const int32_t* offs, + int32_t M, + int32_t N, + int32_t 
K, + // Original strides of the input tensors + Strides tensor_StrideA, + Strides tensor_StrideB, + Strides tensor_StrideOutput, + int64_t a_scale_stride, + int64_t b_scale_stride, + bool a_row_major = true, + bool b_row_major = false) { + int32_t tid = threadIdx.x; + int32_t delta = 0; + if (offs != nullptr) { + int32_t start = tid == 0 ? 0 : offs[tid - 1]; + delta = offs[tid] - start; + if (K < 0) { + if (!a_row_major && b_row_major) { + CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n"); + } else { + // CUTLASS cannot handle delta=0 here. + CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n"); + } + } + + // TMA transfers require global memory tensor addresses to be + // aligned to 16 bytes. + if (tid < blockDim.x - 1) { + // Check this requirement for input tensors, in case group + // addresses are increased along the dynamic dimension. + if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension + (M < 0 && !a_row_major)) { // 3D/2D: check along N dimension + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); + } + if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension + (N < 0 && b_row_major)) { // 3D/2D: check along N dimension + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"); + } + + // Check the same requirement for output tensor (that is always + // contiguous, and in row-major layout). + if (N < 0) { + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected output tensor dynamic dimension byte size to be non-negative multiple of 16\n"); + } + } + } + int64_t lda, ldb, ldoutput; + if (M < 0) { + // A and output is 2d + M = delta; + lda = a_row_major ? 
tensor_StrideA[0] : tensor_StrideA[1]; + ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; + ldoutput = tensor_StrideOutput[0]; + A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = tid == 0 ? scale_A : scale_A + offs[tid - 1]; + inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; + } + output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput; + B_ptrs[tid] = B + tid * tensor_StrideB[0]; + } else if (N < 0) { + N = delta; + lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; + ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed + ldoutput = tensor_StrideOutput[0]; + A_ptrs[tid] = A + tid * tensor_StrideA[0]; + output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1]; + B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[1]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; + inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1]; + } + } else if (K < 0) { + // A, B is 2d, output is 3d + K = delta; + lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; + ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; + ldoutput = tensor_StrideOutput[1]; + A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[1]; + B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[0]; + output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * M; + inputB_scale_ptrs[tid] = scale_B + tid * N; + } + } else { + // A, B, output are 3D + lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; + ldb = b_row_major ? 
tensor_StrideB[1] : tensor_StrideB[2]; + ldoutput = tensor_StrideOutput[1]; + A_ptrs[tid] = A + tid * tensor_StrideA[0]; + B_ptrs[tid] = B + tid * tensor_StrideB[0]; + output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; + inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; + } + } + problem_sizes[tid] = ProblemShape(M, N, K); + + // make_cute_packed_stride only replaces one of the stride elements with + // one the provided values in the shape arguments + // the indices of the src/dst depend on whether A/B are row-major + // so constructing shape argument with two similar lda values + // while it looks non-sensical (and it is a nonsensical shape) + // is fine for these stride construction purposes - the one that will be used + // for replacement is correct, the other one is ignored, and we don't have to + // branch on whether A/B are row-major + stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {lda, lda, 1}); + stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {ldb, ldb, 1}); + stride_output[tid] = + cutlass::make_cute_packed_stride(StrideOutput{}, {M, ldoutput, 1}); +} +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/IndexKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/IndexKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..49c075b376d4f646dd82a04093df84cd1c602c3b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/IndexKernel.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at::native { +/// @param maskPrefixSum[in,out] +void launch_masked_scatter_kernel( + const TensorBase &self, const TensorBase &mask, + const TensorBase &maskPrefixSum, const TensorBase &source); +} diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/IndexKernelUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/IndexKernelUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..20a67dac3f62673b3afcd46124e0acfe24fd292d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/IndexKernelUtils.h @@ -0,0 +1,35 @@ + +#include +#include +#include + +namespace at::native { + +template +inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const out_ptr, char * const in_ptr, const size_t index_stride_bytes, const size_t element_size) { + using at::native::memory::get_alignment; + const auto index_element_size = iter.element_size(2); + //TensorIterator strides and sizes are ordered fastest moving to slowest moving, + //in contrast to regular sizes + // we need contiguous source and dst slices and aligned pointers and strides and slice size to do vectorized loads + // also we need idx to be expanded in the last dimension so we can copy entire slices + // and we need the src tensor to keep 0 stride from restriding + // (it could have been deleted by dimension collapse, in this case iterator would still be 2d + // but we cannot use fast path) + + return iter.ndim() == 2 && iter.strides(2)[0]==0 && iter.strides(2)[1]==index_element_size && + static_cast(iter.strides(0)[0])==element_size && + static_cast(iter.strides(1)[0])==element_size && static_cast(iter.strides(1)[1] == 0) && + get_alignment(out_ptr) == alignment && get_alignment(in_ptr) == alignment && + get_alignment(static_cast(iter.shape()[0] * element_size)) == alignment && + get_alignment(static_cast(index_stride_bytes)) == alignment && + get_alignment(static_cast(iter.strides(0)[1])) == alignment; +} + +template +void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind, + int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, 
+ bool allow_neg_indices=false); + + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6540342fda5808954e159de06b976cb2f76a8623 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh @@ -0,0 +1,186 @@ +#pragma once + +#include + +#if AT_USE_JITERATOR() + +#include + +#include +#include +#include + +#include + +#include + +namespace at::native { + +/* Note [Jiterator] +The "jiterator" simply just-in-time compiles the same kernels that +Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time, +build size, and initial CUDA context size. + +By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels. +This behavior is controlled with two environment variables: + - USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use + - PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels + +The jiterator currently has some limitations, however. It cannot: + - handle math on complex datatypes + - handle kernels with scalar parameters + +These improvements will likely come soon. + +For examples of how to use the jiterator see the i1 and gcd kernel +implementations, which pass jittable strings implementing their +operations instead of the typical CUDA functors. + +To pass a runtime argument (similar to lambda captures in non-JIT kernels), +we need to pass to additional arguments to `jitted_gpu_kernel` by value. +Currently only primitive C++ types used for computation are valid. +The order of these extra arguments should be same as the order they appear +in kernel's function signature. (look at polygamma for example) + +NOTE: One big restriction being that these arguments should be after the +arguments provided by TensorIterator. Eg. 
While capturing `n`, where +`scalar_t x` and `scalar_t y` are provided by TensorIterator, +* foo(scalar_t x, scalar_t y, int n) works! +* foo(int n, scalar_t x, scalar_y) doesn't work +* foo(scalar_t x, int n, scalar_y) doesn't work + +*/ + +// Entrypoint for jitted GPU kernels. +// Only handles elementwise unary and binary kernels with a +// common dtype and a single output. +// NOTE: this assumes the op's iterator has a common_dtype. +// NOTE: We use std::tuple instead of parameter pack +// for `extra_args` due to following +// bug on older versions of clang +// https://bugs.llvm.org/show_bug.cgi?id=23029 +template < + char const* name, + typename return_type, + typename f_inputs_type, + int arity, + typename... Args> +void jitted_gpu_kernel( + TensorIteratorBase& iter, + const std::string& f, + at::cuda::jit::BinaryFuncVariant scalar_pos = + at::cuda::jit::BinaryFuncVariant::NoScalar, + at::opmath_type scalar_val = 0, + std::tuple extra_args = std::make_tuple()) { + // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel + // Maybe it could be refactored? + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + jitted_gpu_kernel( + sub_iter, f, scalar_pos, scalar_val, extra_args); + } + + return; + } + + // Computes if dynamic casting is needed + // Dynamic casting is needed if an input's dtype differs from the common dtype + // or if the result dtype differs from the output's dtype + // Note: this is intentionally divergent from calling needs_dynamic_casting, + // which is more general and inspects a lambda to determine if dynamic + // casting is needed. 
+ bool needs_dynamic_casting = false; + + // Checks output + const ScalarType return_scalar_type = c10::CppTypeToScalarType::value; + const auto dtype0 = iter.dtype(0); + if (dtype0 != return_scalar_type) { + needs_dynamic_casting = true; + } + + // Checks input(s) + const ScalarType inputs_scalar_type = c10::CppTypeToScalarType::value; + for (auto i = decltype(arity){1}; i < (arity + 1); ++i) { + const auto dtypei = iter.dtype(i); + if (dtypei != inputs_scalar_type) { + needs_dynamic_casting = true; + break; + } + } + if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) { + // NOTE: With `scalar_pos=NoScalar`,`scalar_val` is not used + // for computation in the generated code and hence we pass a dummy + // value of `0`. + jitted_gpu_kernel_impl< + /*name*/ name, + /*return_type=*/return_type, + /*f_inputs_type=*/f_inputs_type, + arity, + at::cuda::jit::BinaryFuncVariant::NoScalar>( + iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args); + } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) { + jitted_gpu_kernel_impl< + /*name*/ name, + /*return_type=*/return_type, + /*f_inputs_type=*/f_inputs_type, + arity, + at::cuda::jit::BinaryFuncVariant::RhsScalar>( + iter, + f, + needs_dynamic_casting, + scalar_val, + extra_args); + + } else { + jitted_gpu_kernel_impl< + /*name*/ name, + /*return_type=*/return_type, + /*f_inputs_type=*/f_inputs_type, + arity, + at::cuda::jit::BinaryFuncVariant::LhsScalar>( + iter, + f, + needs_dynamic_casting, + scalar_val, + extra_args); + } +} + +// TODO: support runtime state capture similar to `jitted_gpu_kernel`. 
+template +void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) { + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + //currently jiterator only handles binary functions where both inputs are of the same type (f_inputs_type) + using opmath_t = at::opmath_type; + if (iter.is_cpu_scalar(1)) { + auto scalar_val = iter.scalar_value(1); + iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + const OptionalDeviceGuard device_guard(iter.device(1)); + jitted_gpu_kernel(iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, scalar_val); + } else if (iter.is_cpu_scalar(2)) { + auto scalar_val = iter.scalar_value(2); + iter.remove_operand(2); + jitted_gpu_kernel(iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, scalar_val); + } else { + jitted_gpu_kernel(iter, f); + } +} + +} // namespace at::native + +#endif // AT_USE_JITERATOR() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1696ee64eac67df0f06c27893075d775995b8ae6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh @@ -0,0 +1,367 @@ +#pragma once +#include + +#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + +// ROCm 6.3 is planned to have these functions, but until then here they are. 
+#if defined(USE_ROCM) && ROCM_VERSION >= 60201 +#include +#include +#include + +__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16) + typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; + static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw)); + union { + __hip_bfloat162_raw bf162_raw; + vec_short2 vs2; + } u{static_cast<__hip_bfloat162_raw>(value)}; + u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2); + return static_cast<__hip_bfloat162>(u.bf162_raw); +#else + static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw)); + union u_hold { + __hip_bfloat162_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while (!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} + +__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16) + // The api expects an ext_vector_type of half + typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162; + static_assert(sizeof(vec_fp162) == sizeof(__half2_raw)); + union { + __half2_raw h2r; + vec_fp162 fp16; + } u {static_cast<__half2_raw>(value)}; + u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16); + return static_cast<__half2>(u.h2r); +#else + static_assert(sizeof(__half2_raw) == sizeof(unsigned int)); + union u_hold { + __half2_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, 
__HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while (!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} +#define ATOMICADD preview_unsafeAtomicAdd +#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f) +#else +#define ATOMICADD atomicAdd +#define NATIVE_ZERO_BF16 __int2bfloat16_rz(0) +#endif + +namespace at:: native { + +__device__ __forceinline__ size_t +idx(const size_t nc, + const size_t height, + const size_t width, + const size_t h, + const size_t w) { + return (nc * height + h) * width + w; +} + +// for channels-last +__device__ __forceinline__ size_t +idx_cl( + const size_t n, const size_t h, const size_t w, const size_t c, + const size_t height, const size_t width, const size_t channel +) { + return ((n * height + h) * width + w) * channel + c; +} + +// fastSpecializedAtomicAdd (and fastAtomicAdd) are an optimization +// that speed up half-precision atomics. The situation with half +// precision atomics is that we have a slow __half atomic, and +// a fast vectored __half2 atomic (this can be worth up to a 6x +// speedup, see https://github.com/pytorch/pytorch/pull/21879). +// We can convert a __half atomic into a __half2 atomic by simply +// pairing the __half with a zero entry on the left/right depending +// on alignment... but only if this wouldn't cause an out of bounds +// access! Thus, you must specify tensor and numel so we can check +// if you would be out-of-bounds and use a plain __half atomic if +// you would be. 
+template < + typename scalar_t, + typename index_t, + typename std::enable_if_t>* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { +#if ( \ + (defined(USE_ROCM) && ROCM_VERSION < 60201) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, + static_cast(value)); +#else + // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned) + __half* target_addr = reinterpret_cast<__half*>(tensor + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__half2) == 0); + + if (low_byte && index < (numel - 1)) { + __half2 value2; + value2.x = static_cast<__half>(value); + value2.y = __int2half_rz(0); + ATOMICADD(reinterpret_cast<__half2*>(target_addr), value2); + + } else if (!low_byte && index > 0) { + __half2 value2; + value2.x = __int2half_rz(0); + value2.y = static_cast<__half>(value); + ATOMICADD(reinterpret_cast<__half2*>(target_addr - 1), value2); + + } else { +#ifdef USE_ROCM + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, static_cast(value)); +#else + atomicAdd( + reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value)); +#endif + } +#endif +} + +template < + typename scalar_t, + typename index_t, + typename std::enable_if_t>* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { +#if ( \ + (defined(USE_ROCM) && ROCM_VERSION < 60201) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, + static_cast(value)); +#else + // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned) + __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__nv_bfloat162) == 0); + + if 
(low_byte && index < (numel - 1)) { + __nv_bfloat162 value2; + value2.x = *reinterpret_cast<__nv_bfloat16*>(&value); + value2.y = NATIVE_ZERO_BF16; + ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr), value2); + + } else if (!low_byte && index > 0) { + __nv_bfloat162 value2; + value2.x = NATIVE_ZERO_BF16; + value2.y = *reinterpret_cast<__nv_bfloat16*>(&value); + ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2); + + } else { +#ifdef USE_ROCM + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, static_cast(value)); +#else + atomicAdd( + reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value)); +#endif + } +#endif +} + + +template < + typename scalar_t, + typename index_t, + typename std::enable_if_t && !std::is_same_v>* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { + gpuAtomicAddNoReturn(tensor + index, value); +} + +template +__device__ __forceinline__ void fastAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value, + bool fast_atomics) { + if (fast_atomics) { + fastSpecializedAtomicAdd(tensor, index, numel, value); + } else { + gpuAtomicAddNoReturn(tensor + index, value); + } +} + +#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)) +// This function implements warp-level opportunistic fastatomics +// To reduce contention on an atomicAdd, this replaces per-thread atomicAdd with a per-warp atomicAdd. +// We identify all the threads within a warp that will perform an atomicAdd on the same destination +// address and perform the addition on the CU. Each warp elects a leader thread which does the +// atomicAdd to the destination address. 
+template +__device__ __forceinline__ void opportunistic_fastAtomicAdd( + scalar_t* self_ptr, + index_t index, + const index_t numel, + scalar_t value) { + + scalar_t* dst = self_ptr + index; + + //pack coalseced bf16 and fp16 + if constexpr (std::is_same::value || std::is_same::value) + { + typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; + union ill { unsigned int i[2]; int64_t il; }; + ill iil_, ill_oneUpDst = {}; + iil_.il = (int64_t)dst; + ill_oneUpDst.i[0] = __builtin_amdgcn_mov_dpp(iil_.i[0], 0x130, 0xf, 0xf, 0); + ill_oneUpDst.i[1] = __builtin_amdgcn_mov_dpp(iil_.i[1], 0x130, 0xf, 0xf, 0); + union bfi {scalar_t bf; short s; } bfi_ = { .bf = value }; bfi bfi_oneUpVal; + + bfi_oneUpVal.s = __builtin_amdgcn_mov_dpp(bfi_.s, 0x130, 0xf, 0xf, 0); + auto oneUpVal = bfi_oneUpVal.bf; + + __half* target_addr = reinterpret_cast<__half*>(self_ptr + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__half2) == 0); + bool canCombnUp = (bool)(__activemask()&(1<<(threadIdx.x+1))) && + (low_byte && index < (numel - 1)) && + (ill_oneUpDst.il - reinterpret_cast(dst) == sizeof(scalar_t)); + bool canCombnDn = (__builtin_amdgcn_mov_dpp(canCombnUp, 0x138, 0xf, 0xf, 0)); + + if (__lane_id()%2==0) + { + if (canCombnUp) { + typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162; + union bfvs { scalar_t bf[2]; vec_short2 vs2; vec_fp162 df16; }; + bfvs bfvs_ = {}; + bfvs_.bf[0] = value; + bfvs_.bf[1] = oneUpVal; + if constexpr (std::is_same::value) + __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)dst, bfvs_.vs2); + else + __builtin_amdgcn_flat_atomic_fadd_v2f16((__half2*)dst, bfvs_.df16); + return; + } + } + else + { + if (canCombnDn) + return; + } + } + + // not coalsced, so now let try to capture lane-matches... 
+ // __activemask() -- finds the set of threads in the warp that are about to perform atomicAdd + // __match_any_sync() -- returns bit mask of the threads that have same dest addr + auto mask = __match_any_sync(__activemask(), (int64_t)dst); + + // select a leader thread + int leader = __ffsll(mask) - 1; + + scalar_t crnt_val = (scalar_t)0; + auto crnt_msk = mask >> (leader); + int crnt_idx = leader; + + // __shfl is limited in the dtypes it accepts + // That's why, we need these if/else to correctly do the addition on the CU + if constexpr(sizeof(scalar_t) <= sizeof(int)) { + union punner { int l; scalar_t s; }; + punner pnr = {}; + pnr.s = value; + while (crnt_msk != 0) { + if (crnt_msk & 1) { + punner add_val = {}; + add_val.l = __shfl(pnr.l ,crnt_idx); + crnt_val += add_val.s; + } + crnt_idx++; + crnt_msk = crnt_msk >> 1; + } + } + else if constexpr(sizeof(scalar_t) <= sizeof(long)) { + union punner { long l; scalar_t s; }; + punner pnr = {}; + pnr.s = value; + while (crnt_msk != 0) { + if (crnt_msk & 1) { + punner add_val = {}; + add_val.l = __shfl(pnr.l ,crnt_idx); + crnt_val += add_val.s; + } + crnt_idx++; + crnt_msk = crnt_msk >> 1; + } + } + else if constexpr(sizeof(scalar_t) <= sizeof(long long)) { + union punner { long long l; scalar_t s; }; + punner pnr = {}; + pnr.s = value; + while (crnt_msk != 0) { + if (crnt_msk & 1) { + punner add_val = {}; + add_val.l = __shfl(pnr.l ,crnt_idx); + crnt_val += add_val.s; + } + crnt_idx++; + crnt_msk = crnt_msk >> 1; + } + } + else { + union punner { long long l[2]; scalar_t s; }; + punner pnr = {}; + pnr.s = value; + while (crnt_msk != 0) { + if (crnt_msk & 1) { + punner add_val = {}; + add_val.l[0] = __shfl(pnr.l[0] ,crnt_idx); + add_val.l[1] = __shfl(pnr.l[1] ,crnt_idx); + crnt_val += add_val.s; + } + crnt_idx++; + crnt_msk = crnt_msk >> 1; + } + } + + + //Once the correct crnt_val is determined, only the leader thread does the update to the dest addr + if (__lane_id() == leader) { + fastAtomicAdd(self_ptr, index, 
numel, crnt_val, true); + } +} +#endif + +#undef ATOMICADD +#undef NATIVE_ZERO_BF16 + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..d10c3fbb446819f01cb0b1e37c51cdb01a79abea --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h @@ -0,0 +1,16 @@ +#pragma once +#include + +namespace at::native { + +// returns 2**floor(log2(n)) +static int lastPow2(unsigned int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Loops.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Loops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3d127af17d641d5fa88f6ceaa764e223b8d2dc64 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Loops.cuh @@ -0,0 +1,330 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +namespace at::native { + +template +static OffsetCalculator make_input_offset_calculator(const TensorIteratorBase& iter) { + // array size can not be 0, this happens when N == 0 + constexpr int array_size = std::max(N, 1); + TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs()); + std::array strides; + int64_t element_sizes[array_size]; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i + iter.noutputs()).data(); + element_sizes[i] = iter.element_size(i + iter.noutputs()); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +static OffsetCalculator make_output_offset_calculator(const TensorIteratorBase& iter) { + 
TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs()); + std::array strides; + int64_t element_sizes[num_outputs]; + for (int i = 0; i < num_outputs; i++) { + strides[i] = iter.strides(i).data(); + element_sizes[i] = iter.element_size(i); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { + using traits = function_traits; + using return_t = typename traits::result_type; + using args_t = typename traits::ArgsTuple; + constexpr int elems_per_thread = policy_t::tws; + + int idx = blockIdx.x; + if constexpr (reverted_idx) + idx = gridDim.x - blockIdx.x - 1; + + return_t results[elems_per_thread]; + args_t args[elems_per_thread]; + + // load + policy.load(args, idx); + + // compute + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) { + if (policy.check_inbounds(i)) { + results[i] = c10::guts::apply(f, args[i]); + } + } + + // store + policy.store(results, idx); +} + +} // namespace at::native + +#include + +namespace at:: native { + +template +void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel_nocast(sub_iter, f); + } + return; + } + + gpu_kernel_impl_nocast(iter, f); +} + +template +void gpu_kernel(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + 
gpu_kernel(sub_iter, f); + } + return; + } + + gpu_kernel_impl(iter, f); +} + +template +struct AUnaryFunctor { + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + __device__ return_t operator()(arg2_t b) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! + AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {} + private: + func_t f; + opmath_arg1_t a; +}; + +template +struct BUnaryFunctor { + using traits = function_traits; + using opmath_arg2_t = typename traits::template arg<1>::type; + __device__ return_t operator()(arg1_t a) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! + BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {} + private: + func_t f; + opmath_arg2_t b; +}; + +// Though seemingly noop, this inserts casts from arg1_t to func_t's type +// (which may be higher precision), as well as casts to return_t +template +struct BinaryFunctor { + __device__ return_t operator()(arg1_t a, arg2_t b) const { + return f(a, b); + } + BinaryFunctor(func_t f_): f(f_) {} + private: + func_t f; +}; + +// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which +// accepts inputs at higher precision (typically opmath_t), but then +// ensure that we load from memory at the correct precision (scalar_t) +// to avoid expensive loads. 
For the whole sordid story see +// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302 +template +void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + using opmath_arg2_t = typename traits::template arg<1>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + + if (iter.is_cpu_scalar(1)) { + AUnaryFunctor af(f, iter.scalar_value(1)); + iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + const OptionalDeviceGuard device_guard(iter.device(1)); + gpu_kernel(iter, af); + } else if (iter.is_cpu_scalar(2)) { + BUnaryFunctor bf(f, iter.scalar_value(2)); + iter.remove_operand(2); + gpu_kernel(iter, bf); + } else { + gpu_kernel(iter, BinaryFunctor(f)); + } +} + +template +void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + // Use symmetric property of the functor to reduce number of kernels, + // requires f(a, b) == f(b, a) + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg_t = typename traits::template arg<0>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + static_assert(std::is_same_v::type>, + "f is not symmetric"); + + OptionalDeviceGuard device_guard; + opmath_arg_t scalar_val{}; + + if (iter.is_cpu_scalar(1)) { + scalar_val = iter.scalar_value(1); + iter.remove_operand(1); + + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to 
structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + device_guard.reset_device(iter.device(1)); + } else if (iter.is_cpu_scalar(2)) { + scalar_val = iter.scalar_value(2); + iter.remove_operand(2); + } + + if (iter.ninputs() == 2) { + gpu_kernel(iter, BinaryFunctor(f)); + } else { + AUnaryFunctor unary_f(f, scalar_val); + gpu_kernel(iter, unary_f); + } +} + +// Legacy variant that assumes that func_t has the correct types +// that we expect to load from memory +template +void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + using arg1_t = typename traits::template arg<0>::type; + using arg2_t = typename traits::template arg<1>::type; + using return_t = typename traits::result_type; + opmath_gpu_kernel_with_scalars(iter, f); +} + +namespace { // functions for `gpu_kernel_multiple_outputs`. + +// check the return type is `thrust::tuple`, not `std::tuple`. 
+template struct is_tuple: std::false_type {}; + +template struct is_tuple>: std::true_type {}; + +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) { + int remaining = N - block_work_size() * blockIdx.x; + elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll(data, remaining, ic, oc)); +} + +template +static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + unrolled_elementwise_kernel_for_multi_outputs<<>>(N, f, data, ic, oc); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using output_t = typename traits::result_type; + static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); + constexpr int num_outputs = thrust::tuple_size::value; + constexpr int num_inputs = traits::arity; + constexpr int ntensors = num_outputs + num_inputs; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors); + + std::array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + if (iter.is_contiguous()) { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator(); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } else { + auto input_calc = make_input_offset_calculator(iter); + auto output_calc = make_output_offset_calculator(iter); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } +} +} // namespace + +template +void 
gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) { + ASSERT_HOST_DEVICE_LAMBDA(func_t); + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda()); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel_multiple_outputs(sub_iter, f); + } + return; + } + + gpu_kernel_multiple_outputs_impl(iter, f); +} + +} //namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Math.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Math.cuh new file mode 100644 index 0000000000000000000000000000000000000000..89308177bfea89f1b525cd76b23e17576f3e1bb5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Math.cuh @@ -0,0 +1,3390 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { +// See note [Jiterator] +// TODO: elaborate in this comment on the structure of math.cuh +#if AT_USE_JITERATOR() + +const auto ndtri_string = jiterator_stringify( + /* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates polynomial of degree N: + * + * 2 N + * y = C + C x + C x +...+ C x + * 0 1 2 N + * + * Coefficients are stored in reverse order: + * + * coef[0] = C , ..., coef[N] = C . + * N 0 + */ + template + T polevl(const T x, const T A[], const int len) { + // NOTE: This `polevl` is different from other `polevl` + // implementation (in PyTorch) which expect the `len` to be + // `len(A) - 1` instead of `len(A)`. + T result = 0; + for (int i = 0; i < len; ++i) { + result = result * x + A[i]; + } + return result; + } + + /* + * This function is derived from the implementation of the i1e function in the Cephes Math Library. 
+ * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes the argument, x, for which the area under the Gaussian probability density function + * (integrated from minus infinity to x) is equal to y. + */ + template + T ndtri(T y0) { + + constexpr T zero = 0; + constexpr T one = 1; + + // Handles special cases + if (y0 == zero) { + return NEG_INFINITY; + } + if (y0 == one) { + return POS_INFINITY; + } + if (y0 < zero || y0 > one) { + return NAN; + } + + bool code = true; + T y = y0; + // Note: the constant 0.135... is equal to exp(-2) + if (y > one - T{0.13533528323661269189}) { + y = one - y; + code = false; + } + + if (y > T{0.13533528323661269189}) { + /* approximation for 0 <= |y - 0.5| <= 3/8 */ + static const T P0[5] = { + -5.99633501014107895267E1, + 9.80010754185999661536E1, + -5.66762857469070293439E1, + 1.39312609387279679503E1, + -1.23916583867381258016E0, + }; + + static const T Q0[9] = { + 1.00000000000000000000E0, + 1.95448858338141759834E0, + 4.67627912898881538453E0, + 8.63602421390890590575E1, + -2.25462687854119370527E2, + 2.00260212380060660359E2, + -8.20372256168333339912E1, + 1.59056225126211695515E1, + -1.18331621121330003142E0, + }; + + /* sqrt(2pi) */ + constexpr T s2pi = 2.50662827463100050242E0; + + y = y - T{0.5}; + const T y2 = y * y; + T x = y + y * (y2 * polevl(y2, P0, int{5}) / polevl(y2, Q0, int{9})); + return x * s2pi; + } + + T x = sqrt(T{-2.} * log(y)); + const T x0 = x - (log(x) / x); + + const T z = one / x; + T x1; + + /* y > exp(-32) = 1.2664165549e-14 */ + if (x < T{8.0}) { + /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8 + * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14. 
+ */ + static const T P1[9] = { + 4.05544892305962419923E0, + 3.15251094599893866154E1, + 5.71628192246421288162E1, + 4.40805073893200834700E1, + 1.46849561928858024014E1, + 2.18663306850790267539E0, + -1.40256079171354495875E-1, + -3.50424626827848203418E-2, + -8.57456785154685413611E-4, + }; + + static const T Q1[9] = { + 1.00000000000000000000E0, + 1.57799883256466749731E1, + 4.53907635128879210584E1, + 4.13172038254672030440E1, + 1.50425385692907503408E1, + 2.50464946208309415979E0, + -1.42182922854787788574E-1, + -3.80806407691578277194E-2, + -9.33259480895457427372E-4, + }; + + x1 = z * polevl(z, P1, int{9}) / polevl(z, Q1, int{9}); + } else { + /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64 + * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890. + */ + static const T P2[9] = { + 3.23774891776946035970E0, + 6.91522889068984211695E0, + 3.93881025292474443415E0, + 1.33303460815807542389E0, + 2.01485389549179081538E-1, + 1.23716634817820021358E-2, + 3.01581553508235416007E-4, + 2.65806974686737550832E-6, + 6.23974539184983293730E-9, + }; + + static const T Q2[9] = { + 1.00000000000000000000E0, + 6.02427039364742014255E0, + 3.67983563856160859403E0, + 1.37702099489081330271E0, + 2.16236993594496635890E-1, + 1.34204006088543189037E-2, + 3.28014464682127739104E-4, + 2.89247864745380683936E-6, + 6.79019408009981274425E-9, + }; + + x1 = z * polevl(z, P2, int{9}) / polevl(z, Q2, int{9}); + } + + x = x0 - x1; + return (!code) ? 
x : -x; + } +); // ndtri_string + +const auto log_ndtr_string = jiterator_stringify( + template + T log_ndtr(T x) { + constexpr T SQRT1_2{0.707106781186547524400844362104849039}; // 1/sqrt(2) + T t = x * SQRT1_2; + if (x < T{-1.0}) { + return log(erfcx(-t) / 2) - t * t; + } else { + return log1p(-erfc(t) / 2); + } + } +); // log_ndtr_string + +const auto gcd_string = jiterator_stringify( + template + T gcd(const T a_in, const T b_in) { + T a = abs(a_in); + T b = abs(b_in); + + while (a != T{0}) { + T c = a; + a = b % a; + b = c; + } + + return b; + } +); // gcd_string + +const auto lcm_string = jiterator_stringify( + template + T gcd(const T a_in, const T b_in) { + T a = abs(a_in); + T b = abs(b_in); + + while (a != T{0}) { + T c = a; + a = b % a; + b = c; + } + + return b; + } + + template + T lcm(const T a, const T b) { + T g = gcd(a, b); + return (g == T{0}) ? T{0} : abs(a / g * b); + } +); // lcm_string + +/* + * For licensing information, please refer to the cpu implementation located in "ATen/native/Math.h". + */ +// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma +const auto digamma_string = jiterator_stringify( + template + T digamma(T x) { + static const double PI_f64 = 3.14159265358979323846; + + // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard + if (x == 0) { + return copysign(POS_INFINITY, -x); + } + + T result = 0; + if (x < 0) { + // Short-circuits if x is a negative integer and returns NaN + // per the C++ standard + const bool x_is_integer = (x == trunc(x)); + if (x_is_integer) { + return NAN; + } + + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). 
+ double q, r; + r = modf(static_cast(x), &q); + result = - PI_f64 / tan(PI_f64 * r); + x = 1 - x; + } + + while (x < T{10}) { + result -= T{1} / x; + x += T{1}; + } + + if (x == T{10}) { + return result + T{2.25175258906672110764}; + } + + T y = 0; + if (x < T{1.0e17}) { + const T A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + + T z = T{1} / (x * x); + + T polevl_result = 0; + for (int i = 0; i <= 6; i++) { + polevl_result = polevl_result * z + A[i]; + } + y = z * polevl_result; + } + + return log(x) - (T{0.5} / x) - y + result; + } +); // digamma_string + +/* + * This function is derived from the implementation of the zeta function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +const auto zeta_string = jiterator_stringify( + template + T zeta(T x, T q) { + const T MACHEP{1.11022302462515654042E-16}; + constexpr T zero{0}; + constexpr T half{0.5}; + constexpr T one{1}; + static const T A[] = { + 12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, /*1.307674368e12/691*/ + 7.47242496e10, + -2.950130727918164224e12, /*1.067062284288e16/3617*/ + 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ + -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ + 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ + -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ + }; + + int i = 0; + T a, b, k, s, t, w; + + // Short-circuits x -> +infty + if (x == one) { + return POS_INFINITY; + } + + // Short-circuits x < 1 -> NaN + if (x < one) { + return NAN; + } + + // Short-circuits negative q integers map to +infty, + // negative q non-integers map to NaN + if (q <= zero) { + if (q == floor(q)) { + return POS_INFINITY; + } + if (x != floor(x)) { + return NAN; + } + } + + s = pow(q, -x); + a = 
q; + i = 0; + b = zero; + while ((i < 9) || (a <= T{9.0})) { + i += 1; + a += one; + b = pow(a, -x); + s += b; + if ((-MACHEP * s < b) && (b < MACHEP * s)) { + return s; + } + }; + + w = a; + s += b * w / (x - one); + s -= half * b; + a = one; + k = zero; + for (int i = 0; i < 12; i++) { + a *= x + k; + b /= w; + t = a * b / A[i]; + s = s + t; + t = fabs(t / s); + + if (t < MACHEP) { + return s; + } + + k += one; + a *= x + k; + b /= w; + k += one; + } + + return s; + } +); // zeta_string + +const auto trigamma_string = jiterator_stringify( + template + T trigamma(T x) { + const T PI{3.14159265358979323846}; + T sign = 1; + T result = 0; + + if (x < T{0.5}) { + sign = -1; + T sin_pi_x = sin(PI * x); + result -= (PI * PI) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + + for (int i = 0; i < 6; ++i) { + result += T{1} / (x * x); + x += 1; + } + + const T one{1}; + const T ixx = one / (x*x); + result += (one + one / (T{2}*x) + ixx * (one/T{6} - ixx * (one/T{30} - ixx * (one/T{42})))) / x; + return sign * result; +} +); // trigamma_string + +const auto lgamma_string = jiterator_stringify( + template + T lgamma_kernel(T a) { + return lgamma(a); + } +); // lgamma_string + +const auto polygamma_string = zeta_string + jiterator_stringify( + template + T polygamma(T x, int n) { + // already blocked if n <= 1 + const auto one = T{1}; + return ((n % 2) ? 
one : -one) * exp(lgamma(static_cast(n) + one)) * + zeta(static_cast(n + 1), x); + } +); // polygamma_string + +const auto exp2_string = jiterator_stringify( + template + T exp2_impl(T a) { + return exp2(a); + } + + namespace std { template class complex; } + template + std::complex exp2_impl(std::complex x) { + // There is no std::exp2 overload for complex, so instead + // use the identity 2^x = e^(ln(2) * x) + const auto ln_2 = static_cast(0.693147180559945309417232121458176); + return exp(ln_2 * x); + } + + template + T exp2_kernel(T a) { + return exp2_impl(a); + } +); // exp2_string + +const auto erfc_string = jiterator_stringify( + template + T erfc_kernel(T a) { + return erfc(a); + } +); // erfc_string + +const auto erfinv_string = jiterator_stringify( + template + T erfinv_kernel(T a) { + return erfinv(a); + } +); // erfinv_string + +const auto entr_string = jiterator_stringify( + template + T entr(T a) { + if (a != a) { + return a; + } + + if (a > 0) { + return -a * log(a); + } + + if (a == 0) { + return 0; + } + + return NEG_INFINITY; + } +); // entr_string + +// NOTE: `kaiser_window_string` depends on `i0_string` +// for its implementation. +const auto i0_string = jiterator_stringify( + template + T chbevl(T x, const T array[], const int len) { + + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + template + T i0(T _x) { + T x = fabs(_x); + + if (x <= T{8.0}) { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. 
+ */ + static const T A[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + T y = (x / T{2.0}) - T{2.0}; + return exp(x) * chbevl(y, A, int{30}); + } + + // Handles x > 8 case + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). 
+ */ + const T B[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return (exp(x) * chbevl(T{32.0} / x - T{2.0}, B, int{25})) / sqrt(x); + } +); // i0_string + +const auto i1_string = jiterator_stringify( + template + T chbevl(const T x, const T array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + template + T i1(T _x) { + const T x = fabs(_x); + + if (x <= T{8.0}) { + // Chebyshev coefficients for exp(-x) i1(x) in the internal [0, 8] + // lim(x->0){ exp(-x) i1(x) / x } = 1/2 + static const T coefficients[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 
4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + const T y = x / T{2.0} - T{2.0}; + const T out = exp(x) * x * chbevl(y, coefficients, int{29}); + return (_x < T{0.0}) ? -out : out; + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval [8, infinity] + // lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi) + static const T coefficients[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + const T out = (exp(x) * chbevl(T{32.} / x - T{2.}, coefficients, int{25})) / sqrt(x); + return (_x < T{0.}) ? 
-out : out; + } +); // i1_string + +const auto i1e_string = jiterator_stringify( + template + T chbevl(const T x, const T array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + // See double and float instantiations below + template + T i1e(T _x) { } + + // Double specialization (uses different coefficients than the float version) + template<> + double i1e(double _x) { + const double x = fabs(_x); + if (x <= double{8.}) { + // Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8]. + // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2. + static const double coefficients[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + const double y = x / double{2.} - double{2.}; + const double out = chbevl(y, coefficients, int{29}) * x; + return (_x < 0.) ? -out : out; + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval (8, infinity]. + // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi). + // TODO: what's an "inverted interval"? Open on the left + // and closed on the right? 
+ static const double coefficients[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + + const double out = chbevl(double{32.} / x - double{2.}, coefficients, int{25}) / sqrt(x); + return (_x < double{0.}) ? -out : out; + } + + // Float specialization (uses different coefficients than the double version) + template<> + float i1e(float _x) { + const float x = fabsf(_x); + if (x <= float{8.}) { + // Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8]. + // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2. + static const float coefficients[] = { + 9.38153738649577178388E-9f, + -4.44505912879632808065E-8f, + 2.00329475355213526229E-7f, + -8.56872026469545474066E-7f, + 3.47025130813767847674E-6f, + -1.32731636560394358279E-5f, + 4.78156510755005422638E-5f, + -1.61760815825896745588E-4f, + 5.12285956168575772895E-4f, + -1.51357245063125314899E-3f, + 4.15642294431288815669E-3f, + -1.05640848946261981558E-2f, + 2.47264490306265168283E-2f, + -5.29459812080949914269E-2f, + 1.02643658689847095384E-1f, + -1.76416518357834055153E-1f, + 2.52587186443633654823E-1f}; + const float y = x / float{2.} - float{2.}; + const float out = chbevl(y, coefficients, int{17}) * x; + return (_x < 0.) ? 
-out : out; + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval (8, infinity]. + // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi). + // TODO: what's an "inverted interval"? Open on the left + // and closed on the right? + static const float coefficients[] = { + -3.83538038596423702205E-9f, + -2.63146884688951950684E-8f, + -2.51223623787020892529E-7f, + -3.88256480887769039346E-6f, + -1.10588938762623716291E-4f, + -9.76109749136146840777E-3f, + 7.78576235018280120474E-1f}; + + const float out = chbevl(float{32.} / x - float{2.}, coefficients, int{7}) / sqrt(x); + return (_x < float{0.}) ? -out : out; + } +); // i1e_string + +const auto kaiser_window_string = i0_string + jiterator_stringify( + template + T kaiser_window(T a, T inv_alpha, T beta, T inv_i0_beta) { + T x = a * inv_alpha - T{1}; + T y = max(T{0}, T{1} - x * x); + return i0(beta * sqrt(y)) * inv_i0_beta; + } +); // kaiser_window_string + +const auto sinc_string = jiterator_stringify( + template + T sinc(T a) { + if (a == T(0)) { + return T(1); + } + constexpr T pi = T(3.14159265358979323846L); + T product = pi * a; + return std::sin(product) / product; + } +); // sinc_string + +const auto erfcx_string = jiterator_stringify( + /* The next function is taken from http://ab-initio.mit.edu/faddeeva */ + + /* Copyright (c) 2012 Massachusetts Institute of Technology + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + + /* erfcx(x) = exp(x^2) erfc(x) function, for real x, written by + Steven G. Johnson, October 2012. + + This function combines a few different ideas. + + First, for x > 50, it uses a continued-fraction expansion (same as + for the Faddeeva function, but with algebraic simplifications for z=i*x). + + Second, for 0 <= x <= 50, it uses Chebyshev polynomial approximations, + but with two twists: + + a) It maps x to y = 4 / (4+x) in [0,1]. This simple transformation, + inspired by a similar transformation in the octave-forge/specfun + erfcx by Soren Hauberg, results in much faster Chebyshev convergence + than other simple transformations I have examined. + + b) Instead of using a single Chebyshev polynomial for the entire + [0,1] y interval, we break the interval up into 100 equal + subintervals, with a switch/lookup table, and use much lower + degree Chebyshev polynomials in each subinterval. This greatly + improves performance in my tests. + + For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x), + with the usual checks for overflow etcetera. + + Performance-wise, it seems to be substantially faster than either + the SLATEC DERFC function [or an erfcx function derived therefrom] + or Cody's CALERF function (from netlib.org/specfun), while + retaining near machine precision in accuracy. + */ + + /* Given y100 = 100 * y, where y = 4 / (4 + x) for x >= 0, compute erfc(x). 
+ + Uses a look-up table of 100 different Chebyshev polynomials + for y intervals [0,0.01], [0.01,0.02], ...., [0.99,1], generated + with the help of Maple and a little shell script. This allows + the Chebyshev polynomials to be of significantly lower degree (about 1/4) + compared to fitting the whole [0,1] interval with a single polynomial. + */ + + // TODO: review if this is computing in double when given a float input + template + T erfcx_y100(T y100) { + switch (static_cast(y100)) { + case 0: { + T t = 2*y100 - 1; + return 0.70878032454106438663e-3 + (0.71234091047026302958e-3 + (0.35779077297597742384e-5 + (0.17403143962587937815e-7 + (0.81710660047307788845e-10 + (0.36885022360434957634e-12 + 0.15917038551111111111e-14 * t) * t) * t) * t) * t) * t; + } + case 1: { + T t = 2*y100 - 3; + return 0.21479143208285144230e-2 + (0.72686402367379996033e-3 + (0.36843175430938995552e-5 + (0.18071841272149201685e-7 + (0.85496449296040325555e-10 + (0.38852037518534291510e-12 + 0.16868473576888888889e-14 * t) * t) * t) * t) * t) * t; + } + case 2: { + T t = 2*y100 - 5; + return 0.36165255935630175090e-2 + (0.74182092323555510862e-3 + (0.37948319957528242260e-5 + (0.18771627021793087350e-7 + (0.89484715122415089123e-10 + (0.40935858517772440862e-12 + 0.17872061464888888889e-14 * t) * t) * t) * t) * t) * t; + } + case 3: { + T t = 2*y100 - 7; + return 0.51154983860031979264e-2 + (0.75722840734791660540e-3 + (0.39096425726735703941e-5 + (0.19504168704300468210e-7 + (0.93687503063178993915e-10 + (0.43143925959079664747e-12 + 0.18939926435555555556e-14 * t) * t) * t) * t) * t) * t; + } + case 4: { + T t = 2*y100 - 9; + return 0.66457513172673049824e-2 + (0.77310406054447454920e-3 + (0.40289510589399439385e-5 + (0.20271233238288381092e-7 + (0.98117631321709100264e-10 + (0.45484207406017752971e-12 + 0.20076352213333333333e-14 * t) * t) * t) * t) * t) * t; + } + case 5: { + T t = 2*y100 - 11; + return 0.82082389970241207883e-2 + (0.78946629611881710721e-3 + 
(0.41529701552622656574e-5 + (0.21074693344544655714e-7 + (0.10278874108587317989e-9 + (0.47965201390613339638e-12 + 0.21285907413333333333e-14 * t) * t) * t) * t) * t) * t; + } + case 6: { + T t = 2*y100 - 13; + return 0.98039537275352193165e-2 + (0.80633440108342840956e-3 + (0.42819241329736982942e-5 + (0.21916534346907168612e-7 + (0.10771535136565470914e-9 + (0.50595972623692822410e-12 + 0.22573462684444444444e-14 * t) * t) * t) * t) * t) * t; + } + case 7: { + T t = 2*y100 - 15; + return 0.11433927298290302370e-1 + (0.82372858383196561209e-3 + (0.44160495311765438816e-5 + (0.22798861426211986056e-7 + (0.11291291745879239736e-9 + (0.53386189365816880454e-12 + 0.23944209546666666667e-14 * t) * t) * t) * t) * t) * t; + } + case 8: { + T t = 2*y100 - 17; + return 0.13099232878814653979e-1 + (0.84167002467906968214e-3 + (0.45555958988457506002e-5 + (0.23723907357214175198e-7 + (0.11839789326602695603e-9 + (0.56346163067550237877e-12 + 0.25403679644444444444e-14 * t) * t) * t) * t) * t) * t; + } + case 9: { + T t = 2*y100 - 19; + return 0.14800987015587535621e-1 + (0.86018092946345943214e-3 + (0.47008265848816866105e-5 + (0.24694040760197315333e-7 + (0.12418779768752299093e-9 + (0.59486890370320261949e-12 + 0.26957764568888888889e-14 * t) * t) * t) * t) * t) * t; + } + case 10: { + T t = 2*y100 - 21; + return 0.16540351739394069380e-1 + (0.87928458641241463952e-3 + (0.48520195793001753903e-5 + (0.25711774900881709176e-7 + (0.13030128534230822419e-9 + (0.62820097586874779402e-12 + 0.28612737351111111111e-14 * t) * t) * t) * t) * t) * t; + } + case 11: { + T t = 2*y100 - 23; + return 0.18318536789842392647e-1 + (0.89900542647891721692e-3 + (0.50094684089553365810e-5 + (0.26779777074218070482e-7 + (0.13675822186304615566e-9 + (0.66358287745352705725e-12 + 0.30375273884444444444e-14 * t) * t) * t) * t) * t) * t; + } + case 12: { + T t = 2*y100 - 25; + return 0.20136801964214276775e-1 + (0.91936908737673676012e-3 + (0.51734830914104276820e-5 + (0.27900878609710432673e-7 + 
(0.14357976402809042257e-9 + (0.70114790311043728387e-12 + 0.32252476000000000000e-14 * t) * t) * t) * t) * t) * t; + } + case 13: { + T t = 2*y100 - 27; + return 0.21996459598282740954e-1 + (0.94040248155366777784e-3 + (0.53443911508041164739e-5 + (0.29078085538049374673e-7 + (0.15078844500329731137e-9 + (0.74103813647499204269e-12 + 0.34251892320000000000e-14 * t) * t) * t) * t) * t) * t; + } + case 14: { + T t = 2*y100 - 29; + return 0.23898877187226319502e-1 + (0.96213386835900177540e-3 + (0.55225386998049012752e-5 + (0.30314589961047687059e-7 + (0.15840826497296335264e-9 + (0.78340500472414454395e-12 + 0.36381553564444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 15: { + T t = 2*y100 - 31; + return 0.25845480155298518485e-1 + (0.98459293067820123389e-3 + (0.57082915920051843672e-5 + (0.31613782169164830118e-7 + (0.16646478745529630813e-9 + (0.82840985928785407942e-12 + 0.38649975768888888890e-14 * t) * t) * t) * t) * t) * t; + } + case 16: { + T t = 2*y100 - 33; + return 0.27837754783474696598e-1 + (0.10078108563256892757e-2 + (0.59020366493792212221e-5 + (0.32979263553246520417e-7 + (0.17498524159268458073e-9 + (0.87622459124842525110e-12 + 0.41066206488888888890e-14 * t) * t) * t) * t) * t) * t; + } + case 17: { + T t = 2*y100 - 35; + return 0.29877251304899307550e-1 + (0.10318204245057349310e-2 + (0.61041829697162055093e-5 + (0.34414860359542720579e-7 + (0.18399863072934089607e-9 + (0.92703227366365046533e-12 + 0.43639844053333333334e-14 * t) * t) * t) * t) * t) * t; + } + case 18: { + T t = 2*y100 - 37; + return 0.31965587178596443475e-1 + (0.10566560976716574401e-2 + (0.63151633192414586770e-5 + (0.35924638339521924242e-7 + (0.19353584758781174038e-9 + (0.98102783859889264382e-12 + 0.46381060817777777779e-14 * t) * t) * t) * t) * t) * t; + } + case 19: { + T t = 2*y100 - 39; + return 0.34104450552588334840e-1 + (0.10823541191350532574e-2 + (0.65354356159553934436e-5 + (0.37512918348533521149e-7 + (0.20362979635817883229e-9 + 
(0.10384187833037282363e-11 + 0.49300625262222222221e-14 * t) * t) * t) * t) * t) * t; + } + case 20: { + T t = 2*y100 - 41; + return 0.36295603928292425716e-1 + (0.11089526167995268200e-2 + (0.67654845095518363577e-5 + (0.39184292949913591646e-7 + (0.21431552202133775150e-9 + (0.10994259106646731797e-11 + 0.52409949102222222221e-14 * t) * t) * t) * t) * t) * t; + } + case 21: { + T t = 2*y100 - 43; + return 0.38540888038840509795e-1 + (0.11364917134175420009e-2 + (0.70058230641246312003e-5 + (0.40943644083718586939e-7 + (0.22563034723692881631e-9 + (0.11642841011361992885e-11 + 0.55721092871111111110e-14 * t) * t) * t) * t) * t) * t; + } + case 22: { + T t = 2*y100 - 45; + return 0.40842225954785960651e-1 + (0.11650136437945673891e-2 + (0.72569945502343006619e-5 + (0.42796161861855042273e-7 + (0.23761401711005024162e-9 + (0.12332431172381557035e-11 + 0.59246802364444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 23: { + T t = 2*y100 - 47; + return 0.43201627431540222422e-1 + (0.11945628793917272199e-2 + (0.75195743532849206263e-5 + (0.44747364553960993492e-7 + (0.25030885216472953674e-9 + (0.13065684400300476484e-11 + 0.63000532853333333334e-14 * t) * t) * t) * t) * t) * t; + } + case 24: { + T t = 2*y100 - 49; + return 0.45621193513810471438e-1 + (0.12251862608067529503e-2 + (0.77941720055551920319e-5 + (0.46803119830954460212e-7 + (0.26375990983978426273e-9 + (0.13845421370977119765e-11 + 0.66996477404444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 25: { + T t = 2*y100 - 51; + return 0.48103121413299865517e-1 + (0.12569331386432195113e-2 + (0.80814333496367673980e-5 + (0.48969667335682018324e-7 + (0.27801515481905748484e-9 + (0.14674637611609884208e-11 + 0.71249589351111111110e-14 * t) * t) * t) * t) * t) * t; + } + case 26: { + T t = 2*y100 - 53; + return 0.50649709676983338501e-1 + (0.12898555233099055810e-2 + (0.83820428414568799654e-5 + (0.51253642652551838659e-7 + (0.29312563849675507232e-9 + (0.15556512782814827846e-11 + 
0.75775607822222222221e-14 * t) * t) * t) * t) * t) * t; + } + case 27: { + T t = 2*y100 - 55; + return 0.53263363664388864181e-1 + (0.13240082443256975769e-2 + (0.86967260015007658418e-5 + (0.53662102750396795566e-7 + (0.30914568786634796807e-9 + (0.16494420240828493176e-11 + 0.80591079644444444445e-14 * t) * t) * t) * t) * t) * t; + } + case 28: { + T t = 2*y100 - 57; + return 0.55946601353500013794e-1 + (0.13594491197408190706e-2 + (0.90262520233016380987e-5 + (0.56202552975056695376e-7 + (0.32613310410503135996e-9 + (0.17491936862246367398e-11 + 0.85713381688888888890e-14 * t) * t) * t) * t) * t) * t; + } + case 29: { + T t = 2*y100 - 59; + return 0.58702059496154081813e-1 + (0.13962391363223647892e-2 + (0.93714365487312784270e-5 + (0.58882975670265286526e-7 + (0.34414937110591753387e-9 + (0.18552853109751857859e-11 + 0.91160736711111111110e-14 * t) * t) * t) * t) * t) * t; + } + case 30: { + T t = 2*y100 - 61; + return 0.61532500145144778048e-1 + (0.14344426411912015247e-2 + (0.97331446201016809696e-5 + (0.61711860507347175097e-7 + (0.36325987418295300221e-9 + (0.19681183310134518232e-11 + 0.96952238400000000000e-14 * t) * t) * t) * t) * t) * t; + } + case 31: { + T t = 2*y100 - 63; + return 0.64440817576653297993e-1 + (0.14741275456383131151e-2 + (0.10112293819576437838e-4 + (0.64698236605933246196e-7 + (0.38353412915303665586e-9 + (0.20881176114385120186e-11 + 0.10310784480000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 32: { + T t = 2*y100 - 65; + return 0.67430045633130393282e-1 + (0.15153655418916540370e-2 + (0.10509857606888328667e-4 + (0.67851706529363332855e-7 + (0.40504602194811140006e-9 + (0.22157325110542534469e-11 + 0.10964842115555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 33: { + T t = 2*y100 - 67; + return 0.70503365513338850709e-1 + (0.15582323336495709827e-2 + (0.10926868866865231089e-4 + (0.71182482239613507542e-7 + (0.42787405890153386710e-9 + (0.23514379522274416437e-11 + 0.11659571751111111111e-13 * t) * t) * t) * t) * t) * 
t; + } + case 34: { + T t = 2*y100 - 69; + return 0.73664114037944596353e-1 + (0.16028078812438820413e-2 + (0.11364423678778207991e-4 + (0.74701423097423182009e-7 + (0.45210162777476488324e-9 + (0.24957355004088569134e-11 + 0.12397238257777777778e-13 * t) * t) * t) * t) * t) * t; + } + case 35: { + T t = 2*y100 - 71; + return 0.76915792420819562379e-1 + (0.16491766623447889354e-2 + (0.11823685320041302169e-4 + (0.78420075993781544386e-7 + (0.47781726956916478925e-9 + (0.26491544403815724749e-11 + 0.13180196462222222222e-13 * t) * t) * t) * t) * t) * t; + } + case 36: { + T t = 2*y100 - 73; + return 0.80262075578094612819e-1 + (0.16974279491709504117e-2 + (0.12305888517309891674e-4 + (0.82350717698979042290e-7 + (0.50511496109857113929e-9 + (0.28122528497626897696e-11 + 0.14010889635555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 37: { + T t = 2*y100 - 75; + return 0.83706822008980357446e-1 + (0.17476561032212656962e-2 + (0.12812343958540763368e-4 + (0.86506399515036435592e-7 + (0.53409440823869467453e-9 + (0.29856186620887555043e-11 + 0.14891851591111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 38: { + T t = 2*y100 - 77; + return 0.87254084284461718231e-1 + (0.17999608886001962327e-2 + (0.13344443080089492218e-4 + (0.90900994316429008631e-7 + (0.56486134972616465316e-9 + (0.31698707080033956934e-11 + 0.15825697795555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 39: { + T t = 2*y100 - 79; + return 0.90908120182172748487e-1 + (0.18544478050657699758e-2 + (0.13903663143426120077e-4 + (0.95549246062549906177e-7 + (0.59752787125242054315e-9 + (0.33656597366099099413e-11 + 0.16815130613333333333e-13 * t) * t) * t) * t) * t) * t; + } + case 40: { + T t = 2*y100 - 81; + return 0.94673404508075481121e-1 + (0.19112284419887303347e-2 + (0.14491572616545004930e-4 + (0.10046682186333613697e-6 + (0.63221272959791000515e-9 + (0.35736693975589130818e-11 + 0.17862931591111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 41: { + T t = 2*y100 - 83; + return 
0.98554641648004456555e-1 + (0.19704208544725622126e-2 + (0.15109836875625443935e-4 + (0.10567036667675984067e-6 + (0.66904168640019354565e-9 + (0.37946171850824333014e-11 + 0.18971959040000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 42: { + T t = 2*y100 - 85; + return 0.10255677889470089531e0 + (0.20321499629472857418e-2 + (0.15760224242962179564e-4 + (0.11117756071353507391e-6 + (0.70814785110097658502e-9 + (0.40292553276632563925e-11 + 0.20145143075555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 43: { + T t = 2*y100 - 87; + return 0.10668502059865093318e0 + (0.20965479776148731610e-2 + (0.16444612377624983565e-4 + (0.11700717962026152749e-6 + (0.74967203250938418991e-9 + (0.42783716186085922176e-11 + 0.21385479360000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 44: { + T t = 2*y100 - 89; + return 0.11094484319386444474e0 + (0.21637548491908170841e-2 + (0.17164995035719657111e-4 + (0.12317915750735938089e-6 + (0.79376309831499633734e-9 + (0.45427901763106353914e-11 + 0.22696025653333333333e-13 * t) * t) * t) * t) * t) * t; + } + case 45: { + T t = 2*y100 - 91; + return 0.11534201115268804714e0 + (0.22339187474546420375e-2 + (0.17923489217504226813e-4 + (0.12971465288245997681e-6 + (0.84057834180389073587e-9 + (0.48233721206418027227e-11 + 0.24079890062222222222e-13 * t) * t) * t) * t) * t) * t; + } + case 46: { + T t = 2*y100 - 93; + return 0.11988259392684094740e0 + (0.23071965691918689601e-2 + (0.18722342718958935446e-4 + (0.13663611754337957520e-6 + (0.89028385488493287005e-9 + (0.51210161569225846701e-11 + 0.25540227111111111111e-13 * t) * t) * t) * t) * t) * t; + } + case 47: { + T t = 2*y100 - 95; + return 0.12457298393509812907e0 + (0.23837544771809575380e-2 + (0.19563942105711612475e-4 + (0.14396736847739470782e-6 + (0.94305490646459247016e-9 + (0.54366590583134218096e-11 + 0.27080225920000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 48: { + T t = 2*y100 - 97; + return 0.12941991566142438816e0 + (0.24637684719508859484e-2 + 
(0.20450821127475879816e-4 + (0.15173366280523906622e-6 + (0.99907632506389027739e-9 + (0.57712760311351625221e-11 + 0.28703099555555555556e-13 * t) * t) * t) * t) * t) * t; + } + case 49: { + T t = 2*y100 - 99; + return 0.13443048593088696613e0 + (0.25474249981080823877e-2 + (0.21385669591362915223e-4 + (0.15996177579900443030e-6 + (0.10585428844575134013e-8 + (0.61258809536787882989e-11 + 0.30412080142222222222e-13 * t) * t) * t) * t) * t) * t; + } + case 50: { + T t = 2*y100 - 101; + return 0.13961217543434561353e0 + (0.26349215871051761416e-2 + (0.22371342712572567744e-4 + (0.16868008199296822247e-6 + (0.11216596910444996246e-8 + (0.65015264753090890662e-11 + 0.32210394506666666666e-13 * t) * t) * t) * t) * t) * t; + } + case 51: { + T t = 2*y100 - 103; + return 0.14497287157673800690e0 + (0.27264675383982439814e-2 + (0.23410870961050950197e-4 + (0.17791863939526376477e-6 + (0.11886425714330958106e-8 + (0.68993039665054288034e-11 + 0.34101266222222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 52: { + T t = 2*y100 - 105; + return 0.15052089272774618151e0 + (0.28222846410136238008e-2 + (0.24507470422713397006e-4 + (0.18770927679626136909e-6 + (0.12597184587583370712e-8 + (0.73203433049229821618e-11 + 0.36087889048888888890e-13 * t) * t) * t) * t) * t) * t; + } + case 53: { + T t = 2*y100 - 107; + return 0.15626501395774612325e0 + (0.29226079376196624949e-2 + (0.25664553693768450545e-4 + (0.19808568415654461964e-6 + (0.13351257759815557897e-8 + (0.77658124891046760667e-11 + 0.38173420035555555555e-13 * t) * t) * t) * t) * t) * t; + } + case 54: { + T t = 2*y100 - 109; + return 0.16221449434620737567e0 + (0.30276865332726475672e-2 + (0.26885741326534564336e-4 + (0.20908350604346384143e-6 + (0.14151148144240728728e-8 + (0.82369170665974313027e-11 + 0.40360957457777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 55: { + T t = 2*y100 - 111; + return 0.16837910595412130659e0 + (0.31377844510793082301e-2 + (0.28174873844911175026e-4 + 
(0.22074043807045782387e-6 + (0.14999481055996090039e-8 + (0.87348993661930809254e-11 + 0.42653528977777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 56: { + T t = 2*y100 - 113; + return 0.17476916455659369953e0 + (0.32531815370903068316e-2 + (0.29536024347344364074e-4 + (0.23309632627767074202e-6 + (0.15899007843582444846e-8 + (0.92610375235427359475e-11 + 0.45054073102222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 57: { + T t = 2*y100 - 115; + return 0.18139556223643701364e0 + (0.33741744168096996041e-2 + (0.30973511714709500836e-4 + (0.24619326937592290996e-6 + (0.16852609412267750744e-8 + (0.98166442942854895573e-11 + 0.47565418097777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 58: { + T t = 2*y100 - 117; + return 0.18826980194443664549e0 + (0.35010775057740317997e-2 + (0.32491914440014267480e-4 + (0.26007572375886319028e-6 + (0.17863299617388376116e-8 + (0.10403065638343878679e-10 + 0.50190265831111111110e-13 * t) * t) * t) * t) * t) * t; + } + case 59: { + T t = 2*y100 - 119; + return 0.19540403413693967350e0 + (0.36342240767211326315e-2 + (0.34096085096200907289e-4 + (0.27479061117017637474e-6 + (0.18934228504790032826e-8 + (0.11021679075323598664e-10 + 0.52931171733333333334e-13 * t) * t) * t) * t) * t) * t; + } + case 60: { + T t = 2*y100 - 121; + return 0.20281109560651886959e0 + (0.37739673859323597060e-2 + (0.35791165457592409054e-4 + (0.29038742889416172404e-6 + (0.20068685374849001770e-8 + (0.11673891799578381999e-10 + 0.55790523093333333334e-13 * t) * t) * t) * t) * t) * t; + } + case 61: { + T t = 2*y100 - 123; + return 0.21050455062669334978e0 + (0.39206818613925652425e-2 + (0.37582602289680101704e-4 + (0.30691836231886877385e-6 + (0.21270101645763677824e-8 + (0.12361138551062899455e-10 + 0.58770520160000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 62: { + T t = 2*y100 - 125; + return 0.21849873453703332479e0 + (0.40747643554689586041e-2 + (0.39476163820986711501e-4 + (0.32443839970139918836e-6 + 
(0.22542053491518680200e-8 + (0.13084879235290858490e-10 + 0.61873153262222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 63: { + T t = 2*y100 - 127; + return 0.22680879990043229327e0 + (0.42366354648628516935e-2 + (0.41477956909656896779e-4 + (0.34300544894502810002e-6 + (0.23888264229264067658e-8 + (0.13846596292818514601e-10 + 0.65100183751111111110e-13 * t) * t) * t) * t) * t) * t; + } + case 64: { + T t = 2*y100 - 129; + return 0.23545076536988703937e0 + (0.44067409206365170888e-2 + (0.43594444916224700881e-4 + (0.36268045617760415178e-6 + (0.25312606430853202748e-8 + (0.14647791812837903061e-10 + 0.68453122631111111110e-13 * t) * t) * t) * t) * t) * t; + } + case 65: { + T t = 2*y100 - 131; + return 0.24444156740777432838e0 + (0.45855530511605787178e-2 + (0.45832466292683085475e-4 + (0.38352752590033030472e-6 + (0.26819103733055603460e-8 + (0.15489984390884756993e-10 + 0.71933206364444444445e-13 * t) * t) * t) * t) * t) * t; + } + case 66: { + T t = 2*y100 - 133; + return 0.25379911500634264643e0 + (0.47735723208650032167e-2 + (0.48199253896534185372e-4 + (0.40561404245564732314e-6 + (0.28411932320871165585e-8 + (0.16374705736458320149e-10 + 0.75541379822222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 67: { + T t = 2*y100 - 135; + return 0.26354234756393613032e0 + (0.49713289477083781266e-2 + (0.50702455036930367504e-4 + (0.42901079254268185722e-6 + (0.30095422058900481753e-8 + (0.17303497025347342498e-10 + 0.79278273368888888890e-13 * t) * t) * t) * t) * t) * t; + } + case 68: { + T t = 2*y100 - 137; + return 0.27369129607732343398e0 + (0.51793846023052643767e-2 + (0.53350152258326602629e-4 + (0.45379208848865015485e-6 + (0.31874057245814381257e-8 + (0.18277905010245111046e-10 + 0.83144182364444444445e-13 * t) * t) * t) * t) * t) * t; + } + case 69: { + T t = 2*y100 - 139; + return 0.28426714781640316172e0 + (0.53983341916695141966e-2 + (0.56150884865255810638e-4 + (0.48003589196494734238e-6 + (0.33752476967570796349e-8 + 
(0.19299477888083469086e-10 + 0.87139049137777777779e-13 * t) * t) * t) * t) * t) * t; + } + case 70: { + T t = 2*y100 - 141; + return 0.29529231465348519920e0 + (0.56288077305420795663e-2 + (0.59113671189913307427e-4 + (0.50782393781744840482e-6 + (0.35735475025851713168e-8 + (0.20369760937017070382e-10 + 0.91262442613333333334e-13 * t) * t) * t) * t) * t) * t; + } + case 71: { + T t = 2*y100 - 143; + return 0.30679050522528838613e0 + (0.58714723032745403331e-2 + (0.62248031602197686791e-4 + (0.53724185766200945789e-6 + (0.37827999418960232678e-8 + (0.21490291930444538307e-10 + 0.95513539182222222221e-13 * t) * t) * t) * t) * t) * t; + } + case 72: { + T t = 2*y100 - 145; + return 0.31878680111173319425e0 + (0.61270341192339103514e-2 + (0.65564012259707640976e-4 + (0.56837930287837738996e-6 + (0.40035151353392378882e-8 + (0.22662596341239294792e-10 + 0.99891109760000000000e-13 * t) * t) * t) * t) * t) * t; + } + case 73: { + T t = 2*y100 - 147; + return 0.33130773722152622027e0 + (0.63962406646798080903e-2 + (0.69072209592942396666e-4 + (0.60133006661885941812e-6 + (0.42362183765883466691e-8 + (0.23888182347073698382e-10 + 0.10439349811555555556e-12 * t) * t) * t) * t) * t) * t; + } + case 74: { + T t = 2*y100 - 149; + return 0.34438138658041336523e0 + (0.66798829540414007258e-2 + (0.72783795518603561144e-4 + (0.63619220443228800680e-6 + (0.44814499336514453364e-8 + (0.25168535651285475274e-10 + 0.10901861383111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 75: { + T t = 2*y100 - 151; + return 0.35803744972380175583e0 + (0.69787978834882685031e-2 + (0.76710543371454822497e-4 + (0.67306815308917386747e-6 + (0.47397647975845228205e-8 + (0.26505114141143050509e-10 + 0.11376390933333333333e-12 * t) * t) * t) * t) * t) * t; + } + case 76: { + T t = 2*y100 - 153; + return 0.37230734890119724188e0 + (0.72938706896461381003e-2 + (0.80864854542670714092e-4 + (0.71206484718062688779e-6 + (0.50117323769745883805e-8 + (0.27899342394100074165e-10 + 
0.11862637614222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 77: { + T t = 2*y100 - 155; + return 0.38722432730555448223e0 + (0.76260375162549802745e-2 + (0.85259785810004603848e-4 + (0.75329383305171327677e-6 + (0.52979361368388119355e-8 + (0.29352606054164086709e-10 + 0.12360253370666666667e-12 * t) * t) * t) * t) * t) * t; + } + case 78: { + T t = 2*y100 - 157; + return 0.40282355354616940667e0 + (0.79762880915029728079e-2 + (0.89909077342438246452e-4 + (0.79687137961956194579e-6 + (0.55989731807360403195e-8 + (0.30866246101464869050e-10 + 0.12868841946666666667e-12 * t) * t) * t) * t) * t) * t; + } + case 79: { + T t = 2*y100 - 159; + return 0.41914223158913787649e0 + (0.83456685186950463538e-2 + (0.94827181359250161335e-4 + (0.84291858561783141014e-6 + (0.59154537751083485684e-8 + (0.32441553034347469291e-10 + 0.13387957943111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 80: { + T t = 2*y100 - 161; + return 0.43621971639463786896e0 + (0.87352841828289495773e-2 + (0.10002929142066799966e-3 + (0.89156148280219880024e-6 + (0.62480008150788597147e-8 + (0.34079760983458878910e-10 + 0.13917107176888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 81: { + T t = 2*y100 - 163; + return 0.45409763548534330981e0 + (0.91463027755548240654e-2 + (0.10553137232446167258e-3 + (0.94293113464638623798e-6 + (0.65972492312219959885e-8 + (0.35782041795476563662e-10 + 0.14455745872000000000e-12 * t) * t) * t) * t) * t) * t; + } + case 82: { + T t = 2*y100 - 165; + return 0.47282001668512331468e0 + (0.95799574408860463394e-2 + (0.11135019058000067469e-3 + (0.99716373005509038080e-6 + (0.69638453369956970347e-8 + (0.37549499088161345850e-10 + 0.15003280712888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 83: { + T t = 2*y100 - 167; + return 0.49243342227179841649e0 + (0.10037550043909497071e-1 + (0.11750334542845234952e-3 + (0.10544006716188967172e-5 + (0.73484461168242224872e-8 + (0.39383162326435752965e-10 + 0.15559069118222222222e-12 * t) * t) * t) * t) * t) * 
t; + } + case 84: { + T t = 2*y100 - 169; + return 0.51298708979209258326e0 + (0.10520454564612427224e-1 + (0.12400930037494996655e-3 + (0.11147886579371265246e-5 + (0.77517184550568711454e-8 + (0.41283980931872622611e-10 + 0.16122419680000000000e-12 * t) * t) * t) * t) * t) * t; + } + case 85: { + T t = 2*y100 - 171; + return 0.53453307979101369843e0 + (0.11030120618800726938e-1 + (0.13088741519572269581e-3 + (0.11784797595374515432e-5 + (0.81743383063044825400e-8 + (0.43252818449517081051e-10 + 0.16692592640000000000e-12 * t) * t) * t) * t) * t) * t; + } + case 86: { + T t = 2*y100 - 173; + return 0.55712643071169299478e0 + (0.11568077107929735233e-1 + (0.13815797838036651289e-3 + (0.12456314879260904558e-5 + (0.86169898078969313597e-8 + (0.45290446811539652525e-10 + 0.17268801084444444444e-12 * t) * t) * t) * t) * t) * t; + } + case 87: { + T t = 2*y100 - 175; + return 0.58082532122519320968e0 + (0.12135935999503877077e-1 + (0.14584223996665838559e-3 + (0.13164068573095710742e-5 + (0.90803643355106020163e-8 + (0.47397540713124619155e-10 + 0.17850211608888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 88: { + T t = 2*y100 - 177; + return 0.60569124025293375554e0 + (0.12735396239525550361e-1 + (0.15396244472258863344e-3 + (0.13909744385382818253e-5 + (0.95651595032306228245e-8 + (0.49574672127669041550e-10 + 0.18435945564444444444e-12 * t) * t) * t) * t) * t) * t; + } + case 89: { + T t = 2*y100 - 179; + return 0.63178916494715716894e0 + (0.13368247798287030927e-1 + (0.16254186562762076141e-3 + (0.14695084048334056083e-5 + (0.10072078109604152350e-7 + (0.51822304995680707483e-10 + 0.19025081422222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 90: { + T t = 2*y100 - 181; + return 0.65918774689725319200e0 + (0.14036375850601992063e-1 + (0.17160483760259706354e-3 + (0.15521885688723188371e-5 + (0.10601827031535280590e-7 + (0.54140790105837520499e-10 + 0.19616655146666666667e-12 * t) * t) * t) * t) * t) * t; + } + case 91: { + T t = 2*y100 - 183; + return 
0.68795950683174433822e0 + (0.14741765091365869084e-1 + (0.18117679143520433835e-3 + (0.16392004108230585213e-5 + (0.11155116068018043001e-7 + (0.56530360194925690374e-10 + 0.20209663662222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 92: { + T t = 2*y100 - 185; + return 0.71818103808729967036e0 + (0.15486504187117112279e-1 + (0.19128428784550923217e-3 + (0.17307350969359975848e-5 + (0.11732656736113607751e-7 + (0.58991125287563833603e-10 + 0.20803065333333333333e-12 * t) * t) * t) * t) * t) * t; + } + case 93: { + T t = 2*y100 - 187; + return 0.74993321911726254661e0 + (0.16272790364044783382e-1 + (0.20195505163377912645e-3 + (0.18269894883203346953e-5 + (0.12335161021630225535e-7 + (0.61523068312169087227e-10 + 0.21395783431111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 94: { + T t = 2*y100 - 189; + return 0.78330143531283492729e0 + (0.17102934132652429240e-1 + (0.21321800585063327041e-3 + (0.19281661395543913713e-5 + (0.12963340087354341574e-7 + (0.64126040998066348872e-10 + 0.21986708942222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 95: { + T t = 2*y100 - 191; + return 0.81837581041023811832e0 + (0.17979364149044223802e-1 + (0.22510330592753129006e-3 + (0.20344732868018175389e-5 + (0.13617902941839949718e-7 + (0.66799760083972474642e-10 + 0.22574701262222222222e-12 * t) * t) * t) * t) * t) * t; + } + case 96: { + T t = 2*y100 - 193; + return 0.85525144775685126237e0 + (0.18904632212547561026e-1 + (0.23764237370371255638e-3 + (0.21461248251306387979e-5 + (0.14299555071870523786e-7 + (0.69543803864694171934e-10 + 0.23158593688888888889e-12 * t) * t) * t) * t) * t) * t; + } + case 97: { + T t = 2*y100 - 195; + return 0.89402868170849933734e0 + (0.19881418399127202569e-1 + (0.25086793128395995798e-3 + (0.22633402747585233180e-5 + (0.15008997042116532283e-7 + (0.72357609075043941261e-10 + 0.23737194737777777778e-12 * t) * t) * t) * t) * t) * t; + } + case 98: { + T t = 2*y100 - 197; + return 0.93481333942870796363e0 + (0.20912536329780368893e-1 
+ (0.26481403465998477969e-3 + (0.23863447359754921676e-5 + (0.15746923065472184451e-7 + (0.75240468141720143653e-10 + 0.24309291271111111111e-12 * t) * t) * t) * t) * t) * t; + } + case 99: { + T t = 2*y100 - 199; + return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682383001e-3 + (0.25153688325245314530e-5 + (0.16514019547822821453e-7 + (0.78191526829368231251e-10 + 0.24873652355555555556e-12 * t) * t) * t) * t) * t) * t; + } + } + + // we only get here if y = 1, i.e. |x| < 4*eps, in which case + // erfcx is within 1e-15 of 1.. + return 1.; + } + + template + T erfcx(T x) { + // Short-circuits on NaN (returning NaN) + if (x != x) { + return x; + } + + if (x >= 0) { + if (x > T{50}) { // continued-fraction expansion is faster + const T ispi = 0.56418958354775628694807945156; // 1 / sqrt(pi) + + if (x > T{5e7}) { // 1-term expansion, important to avoid overflow + return ispi / x; + } + + /* 5-term expansion (rely on compiler for CSE), simplified from: + ispi / (x+0.5/(x+1/(x+1.5/(x+2/x)))) */ + return ispi * ((x*x) * (x*x+T{4.5}) + T{2}) / (x * ((x*x) * (x*x+T{5}) + T{3.75})); + } + + // x >= 0 x <= 50 + return erfcx_y100(T{400} / (T{4} + x)); + } + + // x < 0 + if (x < T{-26.7}) { + return POS_INFINITY; + } else if (x < T{-6.1}) { + return T{2} * exp(x * x); + } + + // x < 0 and x >= -6.1 + return T{2} * exp(x * x) - erfcx_y100(T{400} / (T{4} - x)); + } +); // erfcx_string + +const auto airy_ai_string = jiterator_stringify( + template + T airy_ai_forward(T x) { + static const T AN[] = { + +3.46538101525629032477e-01, + +1.20075952739645805542e+01, + +7.62796053615234516538e+01, + +1.68089224934630576269e+02, + +1.59756391350164413639e+02, + +7.05360906840444183113e+01, + +1.40264691163389668864e+01, + +9.99999999999999995305e-01, + }; + + static const T AD[] = { + +5.67594532638770212846e-01, + +1.47562562584847203173e+01, + +8.45138970141474626562e+01, + +1.77318088145400459522e+02, + +1.64234692871529701831e+02, + 
+7.14778400825575695274e+01, + +1.40959135607834029598e+01, + +1.00000000000000000470e+00, + }; + + static const T AFN[] = { + -1.31696323418331795333e-01, + -6.26456544431912369773e-01, + -6.93158036036933542233e-01, + -2.79779981545119124951e-01, + -4.91900132609500318020e-02, + -4.06265923594885404393e-03, + -1.59276496239262096340e-04, + -2.77649108155232920844e-06, + -1.67787698489114633780e-08, + }; + + static const T AFD[] = { + +1.33560420706553243746e+01, + +3.26825032795224613948e+01, + +2.67367040941499554804e+01, + +9.18707402907259625840e+00, + +1.47529146771666414581e+00, + +1.15687173795188044134e-01, + +4.40291641615211203805e-03, + +7.54720348287414296618e-05, + +4.51850092970580378464e-07, + }; + + static const T AGN[] = { + +1.97339932091685679179e-02, + +3.91103029615688277255e-01, + +1.06579897599595591108e+00, + +9.39169229816650230044e-01, + +3.51465656105547619242e-01, + +6.33888919628925490927e-02, + +5.85804113048388458567e-03, + +2.82851600836737019778e-04, + +6.98793669997260967291e-06, + +8.11789239554389293311e-08, + +3.41551784765923618484e-10, + }; + + static const T AGD[] = { + +9.30892908077441974853e+00, + +1.98352928718312140417e+01, + +1.55646628932864612953e+01, + +5.47686069422975497931e+00, + +9.54293611618961883998e-01, + +8.64580826352392193095e-02, + +4.12656523824222607191e-03, + +1.01259085116509135510e-04, + +1.17166733214413521882e-06, + +4.91834570062930015649e-09, + }; + + int domain_flag = 0; + + T ai; + + if (isinf(x)) { + return NAN; + } + + if (x > T(103.892)) { + return T(0.0); + } + + T f; + T g; + T k; + + if (x < T(-2.09)) { + T z = T(1.0) / (T(-2.0) * x * sqrt(-x) / T(3.0)); + + T afn = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afn = afn * (z * z) + AFN[index]; + } + + T afd = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afd = afd * (z * z) + AFD[index]; + } + + T agn = 0.0; + + for (uint8_t index = 0; index <= 10 + 0; index++) { + agn = agn * (z * z) + AGN[index]; + } + + T agd 
= 0.0; + + for (uint8_t index = 0; index <= 10 - 1; index++) { + agd = agd * (z * z) + AGD[index]; + } + + T t = T(-2.0) * x * sqrt(-x) / T(3.0) + T(0.25) * T(3.14159265358979323846); + + return T(5.64189583547756286948e-01) / sqrt(sqrt(-x)) * (sin(t) * (T(1.0) + z * z * afn / afd) - cos(t) * (z * agn / agd)); + } + + if (x >= T(2.09)) { + domain_flag = 5; + + T zeta = T(2.0) * x * sqrt(x) / T(3.0); + + T an = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + an = an * (T(1.0) / zeta) + AN[index]; + } + + T ad = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + ad = ad * (T(1.0) / zeta) + AD[index]; + } + + ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * sqrt(sqrt(x)) * exp(zeta)); + + if (x > T(8.3203353)) { + return ai; + } + } + + f = 1.0; + g = x; + k = 1.0; + + T m = 1.0; + T n = x; + T t = 1.0; + T z = x * x * x; + + while (t > T(1.11022302462515654042e-16)) { + m *= z; + k += T(1.0); + m /= k; + n *= z; + k += T(1.0); + n /= k; + m /= k; + f += m; + k += T(1.0); + n /= k; + g += n; + + t = abs(m / f); + } + + if ((domain_flag & 1) == 0) { + return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g; + } + + return ai; + } // T airy_ai(T x) +); // airy_ai_string + +const auto bessel_j0_string = jiterator_stringify( + template + T bessel_j0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + 
-1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T RP[] = { + -4.79443220978201773821e+09, + +1.95617491946556577543e+12, + -2.49248344360967716204e+14, + +9.70862251047306323952e+15, + }; + + static const T RQ[] = { + +4.99563147152651017219e+02, + +1.73785401676374683123e+05, + +4.84409658339962045305e+07, + +1.11855537045356834862e+10, + +2.11277520115489217587e+12, + +3.10518229857422583814e+14, + +3.18121955943204943306e+16, + +1.71086294081043136091e+18, + }; + + if (x < T(0)) { + x = -x; + } + + if (x <= T(5.0)) { + if (x < T(0.00001)) { + return T(1.0) - x * x / T(4.0); + } + + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq; + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_j0_forward(T x) +); // bessel_j0_string + +const auto 
bessel_y0_string = bessel_j0_string + jiterator_stringify( + template + T bessel_y0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T YP[] = { + +1.55924367855235737965e+04, + -1.46639295903971606143e+07, + +5.43526477051876500413e+09, + -9.82136065717911466409e+11, + +8.75906394395366999549e+13, + -3.46628303384729719441e+15, + +4.42733268572569800351e+16, + -1.84950800436986690637e+16, + }; + + static const T YQ[] = { + +1.04128353664259848412e+03, + +6.26107330137134956842e+05, + +2.68919633393814121987e+08, + +8.64002487103935000337e+10, + +2.02979612750105546709e+13, + +3.17157752842975028269e+15, + +2.50596256172653059228e+17, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return NEG_INFINITY; + } + + if (x < T(0.0)) { + NAN; + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return yp / yq + 
(T(0.636619772367581343075535053490057448) * log(x) * bessel_j0_forward(x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_y0_forward(T x) +); // bessel_y0_string + +const auto bessel_j1_string = jiterator_stringify( + template + T bessel_j1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T RP[] = { + -8.99971225705559398224e+08, + 
+4.52228297998194034323e+11, + -7.27494245221818276015e+13, + +3.68295732863852883286e+15, + }; + + static const T RQ[] = { + +6.20836478118054335476e+02, + +2.56987256757748830383e+05, + +8.35146791431949253037e+07, + +2.21511595479792499675e+10, + +4.74914122079991414898e+12, + +7.84369607876235854894e+14, + +8.95222336184627338078e+16, + +5.32278620332680085395e+18, + }; + + if (x < T(0.0)) { + return -bessel_j1_forward(-x); + } + + if (x <= T(5.0)) { + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_j1_forward(T x) +); // bessel_j1_string + +const auto bessel_y1_string = bessel_j1_string + jiterator_stringify( + template + T bessel_y1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + 
+1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T YP[] = { + +1.26320474790178026440e+09, + -6.47355876379160291031e+11, + +1.14509511541823727583e+14, + -8.12770255501325109621e+15, + +2.02439475713594898196e+17, + -7.78877196265950026825e+17, + }; + + static const T YQ[] = { + +5.94301592346128195359e+02, + +2.35564092943068577943e+05, + +7.34811944459721705660e+07, + +1.87601316108706159478e+10, + +3.88231277496238566008e+12, + +6.20557727146953693363e+14, + +6.87141087355300489866e+16, + +3.97270608116560655612e+18, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return NEG_INFINITY; + } + + if (x <= T(0.0)) { + return NAN; + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 5; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * log(x) - T(1.0) / x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; 
+ } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x); + } // bessel_y1_forward(T x) +); // bessel_y1_string + +const auto chebyshev_polynomial_t_string = jiterator_stringify( + template + T chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (abs(x) < T(1.0))) { + return cos(n * acos(x)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_t_forward(T x, int64_t n) + + template + T chebyshev_polynomial_t_forward(T x, T n) { + return chebyshev_polynomial_t_forward(x, static_cast(n)); + } // chebyshev_polynomial_t_forward(T x, T n) +); // chebyshev_polynomial_t_string + +const auto chebyshev_polynomial_u_string = jiterator_stringify( + template + T chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 8) && (abs(x) < T(1.0))) { + if (sin(acos(x)) != T(0.0)) { + return sin((n + 1) * acos(x)) / sin(acos(x)); + } + + return (n + 1) * cos((n + 1) * acos(x)) / x; + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + T p = T(1.0); + T q = x + x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_u_forward(T x, int64_t n) + + template + T chebyshev_polynomial_u_forward(T x, T n) { + return 
chebyshev_polynomial_u_forward(x, static_cast(n)); + } // chebyshev_polynomial_u_forward(T x, T n) +); // chebyshev_polynomial_u_string + +const auto chebyshev_polynomial_v_string = jiterator_stringify( + template + T chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0)) { + return T(1.0); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if ((n > 8) && (abs(x) < T(1.0))) { + if (sin(acos(x) / T(2.0)) != T(1.0)) { + return cos((n + T(0.5)) * acos(x)) / cos(acos(x) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_v_forward(T x, int64_t n) + + template + T chebyshev_polynomial_v_forward(T x, T n) { + return chebyshev_polynomial_v_forward(x, static_cast(n)); + } // chebyshev_polynomial_v_forward(T x, T n) +); // chebyshev_polynomial_v_string + +const auto chebyshev_polynomial_w_string = jiterator_stringify( + template + T chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 8) && (abs(x) < T(1.0))) { + if (cos(acos(x) / T(2.0)) != T(1.0)) { + return sin((n + T(0.5)) * acos(x)) / sin(acos(x) / T(2.0)); + } + + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x + T(1.0); + } + + T p = T(1.0); + T q = x + x + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; + } // chebyshev_polynomial_w_forward(T 
x, int64_t n) + + template + T chebyshev_polynomial_w_forward(T x, T n) { + return chebyshev_polynomial_w_forward(x, static_cast(n)); + } // chebyshev_polynomial_w_forward(T x, T n) +); // chebyshev_polynomial_w_string + +const auto hermite_polynomial_h_string = jiterator_stringify( + template + unsigned short getHermitianLimit() { + if (sizeof(T) <= sizeof(float)) { + return 128; + } else if (sizeof(T) <= sizeof(double)) { + return 512; + } else { + return 1024; + } + } + + template + T hermite_polynomial_h_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + if (n > getHermitianLimit()) { + return NAN; + } + + T p = T(1.0); + T q = x + x; + T r = T(0.0); + + for (int64_t k = 2; k < n + n; k += 2) { + r = (x + x) * q - k * p; + p = q; + q = r; + } + + return r; + } // hermite_polynomial_h_forward(T x, int64_t n) + + template + T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, static_cast(n)); + } // hermite_polynomial_h_forward(T x, T n) +); // hermite_polynomial_h_string + +const auto hermite_polynomial_he_string = jiterator_stringify( + template + unsigned short getHermitianLimit() { + if (sizeof(T) <= sizeof(float)) { + return 128; + } else if (sizeof(T) <= sizeof(double)) { + return 512; + } else { + return 1024; + } + } + + template + T hermite_polynomial_he_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + if (n > getHermitianLimit()) { + return NAN; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = x * q - k * p; + p = q; + q = r; + } + + return r; + } // hermite_polynomial_he_forward(T x, int64_t n) + + template + T hermite_polynomial_he_forward(T x, T n) { + return hermite_polynomial_he_forward(x, static_cast(n)); + } // hermite_polynomial_he_forward(T x, T n) +); // hermite_polynomial_he_string + +const auto 
laguerre_polynomial_l_string = jiterator_stringify( + template + T laguerre_polynomial_l_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(0.0)) { + return T(1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return T(1.0) - x; + } + + T p = T(1.0); + T q = T(1.0) - x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; + } // laguerre_polynomial_l_forward(T x, int64_t n) + + template + T laguerre_polynomial_l_forward(T x, T n) { + return laguerre_polynomial_l_forward(x, static_cast(n)); + } // laguerre_polynomial_l_forward(T x, T n) +); // laguerre_polynomial_l_string + +const auto legendre_polynomial_p_string = jiterator_stringify( + template + T legendre_polynomial_p_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = ((k + k + 1) * x * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; + } // legendre_polynomial_p_forward(T x, int64_t n) + + template + T legendre_polynomial_p_forward(T x, T n) { + return legendre_polynomial_p_forward(x, static_cast(n)); + } // legendre_polynomial_p_forward(T x, T n) +); // legendre_polynomial_p_string + +const auto modified_bessel_i0_string = jiterator_stringify( + template + T modified_bessel_i0_forward(T x) { + static const T A[] = { + -4.41534164647933937950e-18, + +3.33079451882223809783e-17, + -2.43127984654795469359e-16, + +1.71539128555513303061e-15, + -1.16853328779934516808e-14, + +7.67618549860493561688e-14, + -4.85644678311192946090e-13, + +2.95505266312963983461e-12, + -1.72682629144155570723e-11, + +9.67580903537323691224e-11, + -5.18979560163526290666e-10, + +2.65982372468238665035e-09, + 
-1.30002500998624804212e-08, + +6.04699502254191894932e-08, + -2.67079385394061173391e-07, + +1.11738753912010371815e-06, + -4.41673835845875056359e-06, + +1.64484480707288970893e-05, + -5.75419501008210370398e-05, + +1.88502885095841655729e-04, + -5.76375574538582365885e-04, + +1.63947561694133579842e-03, + -4.32430999505057594430e-03, + +1.05464603945949983183e-02, + -2.37374148058994688156e-02, + +4.93052842396707084878e-02, + -9.49010970480476444210e-02, + +1.71620901522208775349e-01, + -3.04682672343198398683e-01, + +6.76795274409476084995e-01, + }; + + static const T B[] = { + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + +4.46562142029675999901e-17, + +3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + +1.77256013305652638360e-15, + +3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + +1.54008621752140982691e-14, + +3.85277838274214270114e-13, + +7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + +1.18891471078464383424e-11, + +4.94060238822496958910e-10, + +3.39623202570838634515e-09, + +2.26666899049817806459e-08, + +2.04891858946906374183e-07, + +2.89137052083475648297e-06, + +6.88975834691682398426e-05, + +3.36911647825569408990e-03, + +8.04490411014108831608e-01, + }; + + T p; + T q = 0.0; + + if (abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 30; index++) { + p = q; + q = a; + a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + return exp(abs(x)) * (T(0.5) * (a - p)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index]; + } + + return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x)); + } // modified_bessel_i0_forward(T x) +); // modified_bessel_i0_string + +const auto modified_bessel_i1_string = jiterator_stringify( + template + T modified_bessel_i1_forward(T x) { + 
static const T A[] = { + +2.77791411276104639959e-18, + -2.11142121435816608115e-17, + +1.55363195773620046921e-16, + -1.10559694773538630805e-15, + +7.60068429473540693410e-15, + -5.04218550472791168711e-14, + +3.22379336594557470981e-13, + -1.98397439776494371520e-12, + +1.17361862988909016308e-11, + -6.66348972350202774223e-11, + +3.62559028155211703701e-10, + -1.88724975172282928790e-09, + +9.38153738649577178388e-09, + -4.44505912879632808065e-08, + +2.00329475355213526229e-07, + -8.56872026469545474066e-07, + +3.47025130813767847674e-06, + -1.32731636560394358279e-05, + +4.78156510755005422638e-05, + -1.61760815825896745588e-04, + +5.12285956168575772895e-04, + -1.51357245063125314899e-03, + +4.15642294431288815669e-03, + -1.05640848946261981558e-02, + +2.47264490306265168283e-02, + -5.29459812080949914269e-02, + +1.02643658689847095384e-01, + -1.76416518357834055153e-01, + +2.52587186443633654823e-01, + }; + + static const T B[] = { + +7.51729631084210481353e-18, + +4.41434832307170791151e-18, + -4.65030536848935832153e-17, + -3.20952592199342395980e-17, + +2.96262899764595013876e-16, + +3.30820231092092828324e-16, + -1.88035477551078244854e-15, + -3.81440307243700780478e-15, + +1.04202769841288027642e-14, + +4.27244001671195135429e-14, + -2.10154184277266431302e-14, + -4.08355111109219731823e-13, + -7.19855177624590851209e-13, + +2.03562854414708950722e-12, + +1.41258074366137813316e-11, + +3.25260358301548823856e-11, + -1.89749581235054123450e-11, + -5.58974346219658380687e-10, + -3.83538038596423702205e-09, + -2.63146884688951950684e-08, + -2.51223623787020892529e-07, + -3.88256480887769039346e-06, + -1.10588938762623716291e-04, + -9.76109749136146840777e-03, + +7.78576235018280120474e-01, + }; + + T p; + T q = 0.0; + + if (abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 29; index++) { + p = q; + q = a; + a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + if (x < T(0.0)) { + return -(T(0.5) * (a - p) * abs(x) * 
exp(abs(x))); + } + + return T(0.5) * (a - p) * abs(x) * exp(abs(x)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index]; + } + + if (x < T(0.0)) { + return -(exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x))); + } + + return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x)); + } // modified_bessel_i1_forward(T x) +); // modified_bessel_i1_string + +const auto modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify( + template + T modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return 
T(0.5) * (a - p) - log(0.5 * x) * modified_bessel_i0_forward(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return exp(-x) * (T(0.5) * (b - p)) / sqrt(x); + } // modified_bessel_k0_forward(T x) +); // modified_bessel_k0_string + +const auto scaled_modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify( + template + T scaled_modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (T(0.5) * (a - p) - log(T(0.5) * x) * modified_bessel_i0_forward(x)) * exp(x); + } + + T b = 
B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return T(0.5) * (b - p) / sqrt(x); + } // T scaled_modified_bessel_k0_forward(T x) +); // scaled_modified_bessel_k0_string + +const auto modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify( + template + T modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x; + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q 
= b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return exp(-x) * (T(0.5) * (b - p)) / sqrt(x); + } // modified_bessel_k1_forward(T x) +); // modified_bessel_k1_string + +const auto scaled_modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify( + template + T scaled_modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * exp(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + 
B[index]; + } + + return (T(0.5) * (b - p) / sqrt(x)); + } // T scaled_modified_bessel_k1_forward(T x) +); // scaled_modified_bessel_k1_string + +const auto shifted_chebyshev_polynomial_t_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) { + return cos(n * acos(x + x - T(1.0))); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_t_forward(T x, T n) { + return shifted_chebyshev_polynomial_t_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_t_forward(T x, T n) +); // shifted_chebyshev_polynomial_t_string + +const auto shifted_chebyshev_polynomial_u_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) { + if (sin(acos(x + x - T(1.0))) != T(0.0)) { + return sin((n + 1) * acos(x + x - T(1.0))) / sin(acos(x + x - T(1.0))); + } + + return (n + 1) * cos((n + 1) * acos(x + x - T(1.0))) / (x + x - T(1.0)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // 
shifted_chebyshev_polynomial_u_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_u_forward(T x, T n) { + return shifted_chebyshev_polynomial_u_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_u_forward(T x, T n) +); // shifted_chebyshev_polynomial_u_string + +const auto shifted_chebyshev_polynomial_v_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return (n + n + 1); + } + + return -(n + n + 1); + } + + if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) { + if (sin(acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return cos(((n) + T(0.5)) * acos(x + x - T(1.0))) / cos(acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_v_forward(T x, T n) { + return shifted_chebyshev_polynomial_v_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_v_forward(T x, T n) +); // shifted_chebyshev_polynomial_v_string + +const auto shifted_chebyshev_polynomial_w_string = jiterator_stringify( + template + T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 4) && (abs(x + x - T(1.0)) < T(1.0))) { + if (cos(acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return sin((n + T(0.5)) * acos(x + x - T(1.0))) / sin(acos(x + 
x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; + } // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) + + template + T shifted_chebyshev_polynomial_w_forward(T x, T n) { + return shifted_chebyshev_polynomial_w_forward(x, static_cast(n)); + } // shifted_chebyshev_polynomial_w_forward(T x, T n) +); // shifted_chebyshev_polynomial_w_string + +const auto spherical_bessel_j0_string = jiterator_stringify( + template + T spherical_bessel_j0_forward(T x) { + if (isinf(x)) { + return T(0.0); + } + + if (abs(x) < T(0.5)) { + return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0))))))); + } + + return sin(x) / x; + } // T spherical_bessel_j0_forward(T x) +); // spherical_bessel_j0_string + +#else // !AT_USE_JITERATOR() -- kernels must be precompiled + +template +static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) { + scalar_t a = ::abs(a_in); + scalar_t b = ::abs(b_in); + while (a != 0) { + scalar_t c = a; + a = b % a; + b = c; + } + return b; +} + +/* + * For licensing information, please refer to the cpu implementation located in "ATen/native/Math.h". 
+ */ +template +static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { + // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma + using accscalar_t = at::acc_type; + static const double PI_f64 = 3.14159265358979323846; + const accscalar_t PSI_10 = 2.25175258906672110764; + const accscalar_t A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + accscalar_t x = static_cast(in); + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(static_cast(INFINITY), -x); + } + + bool x_is_integer = x == ::trunc(x); + accscalar_t result = 0; + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return static_cast(NAN); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). 
+ double q, r; + r = ::modf(static_cast(x), &q); + result = static_cast(- PI_f64 / ::tan(PI_f64 * r)); + x = 1 - x; + } + + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return static_cast(result + PSI_10); + } + + accscalar_t y = 0; + if (x < 1.0e17) { + accscalar_t z = 1 / (x * x); + + accscalar_t polevl_result = 0; + for (int i = 0; i <= 6; i++) { + polevl_result = polevl_result * z + A[i]; + } + y = z * polevl_result; + } + + return static_cast(::log(x) - (static_cast(0.5) / x) - y + result); +} + +template +static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) { + using accscalar_t = at::acc_type; + const accscalar_t PI = 3.14159265358979323846; + accscalar_t x = static_cast(in); + accscalar_t sign = +1; + accscalar_t result = 0; + if (x < 0.5f) { + sign = -1; + accscalar_t sin_pi_x = ::sin(PI * x); + result -= (PI * PI) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const accscalar_t one = static_cast(1); + const accscalar_t ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (one/6 - ixx * (one/30 - ixx * (one/42)))) / x; + return static_cast(sign * result); +} + +/* + * For licensing information and documentation, please refer to the cpu implementation located in "ATen/native/Math.h". + */ +template +static inline C10_HOST_DEVICE scalar_t +chbevl(scalar_t _x, const scalar_t array[], size_t len) { + static_assert(!std::is_same() && !std::is_same(), "don't instantiate with low precision type"); + + scalar_t b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (size_t i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = _x * b1 - b2 + array[i]; + } + + return (0.5 * (b0 - b2)); +} + +/* + * For licensing information and documentation, please refer to the cpu implementation located in "ATen/native/Math.h". + */ +template +C10_HOST_DEVICE inline std::tuple chebyshev_coefficients_i0e_A() { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. 
+ * + * lim(x->0){ exp(-x) I0(x) } = 1. + */ + static const T coefficients[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + return std::make_tuple(coefficients, 30); +} + +template +C10_HOST_DEVICE inline std::tuple chebyshev_coefficients_i0e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). 
+ */ + static const T coefficients[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return std::make_tuple(coefficients, 25); +} + +template +static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) { + static_assert(!std::is_same() && !std::is_same(), "don't instantiate with low precision type"); + // Upcast input for numerical accuracy purposes + // Needed for accurate results if input is bfloat16 or float16 + scalar_t x = ::abs(_x); + + if (x <= scalar_t{8.0}) { + auto [A, len] = chebyshev_coefficients_i0e_A(); + scalar_t y = (x / scalar_t{2.0}) - scalar_t{2.0}; + return (::exp(x) * chbevl(y, A, len)); + } + + auto [B, len] = chebyshev_coefficients_i0e_B(); + return (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x)); +} + +template +C10_HOST_DEVICE inline + typename std::enable_if_t, std::tuple> + chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. 
+ */ + static const T coefficients[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + + return std::make_tuple(coefficients, 29); +} + +template +C10_HOST_DEVICE inline + typename std::enable_if_t, std::tuple> + chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. + */ + static const T coeff[] = { + 9.38153738649577178388E-9f, + -4.44505912879632808065E-8f, + 2.00329475355213526229E-7f, + -8.56872026469545474066E-7f, + 3.47025130813767847674E-6f, + -1.32731636560394358279E-5f, + 4.78156510755005422638E-5f, + -1.61760815825896745588E-4f, + 5.12285956168575772895E-4f, + -1.51357245063125314899E-3f, + 4.15642294431288815669E-3f, + -1.05640848946261981558E-2f, + 2.47264490306265168283E-2f, + -5.29459812080949914269E-2f, + 1.02643658689847095384E-1f, + -1.76416518357834055153E-1f, + 2.52587186443633654823E-1f}; + return std::make_tuple(coeff, 17); +}; + +template +C10_HOST_DEVICE inline + typename std::enable_if_t, std::tuple> + chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. 
+ * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). + */ + static const T coefficients[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + + return std::make_tuple(coefficients, 25); +} + +template +C10_HOST_DEVICE inline + typename std::enable_if_t, std::tuple> + chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). + */ + static const T coeff[] = { + -3.83538038596423702205E-9f, + -2.63146884688951950684E-8f, + -2.51223623787020892529E-7f, + -3.88256480887769039346E-6f, + -1.10588938762623716291E-4f, + -9.76109749136146840777E-3f, + 7.78576235018280120474E-1f}; + + return std::make_tuple(coeff, 7); +}; + +template +static inline C10_HOST_DEVICE scalar_t calc_i1(scalar_t _x) { + const auto x = ::abs(_x); + if (x <= scalar_t{8.0}) { + auto [A, len] = chebyshev_coefficients_i1e_A(); + scalar_t y = x / scalar_t{2.0} - scalar_t{2.0}; + const scalar_t out = ::exp(x) * x * chbevl(y, A, len); + return (_x < scalar_t{0.0}) ? -out : out; + } + + auto [B, len] = chebyshev_coefficients_i1e_B(); + const scalar_t out = (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len)) / ::sqrt(x); + return (_x < scalar_t{0.0}) ? 
-out : out; +} + +template +static inline C10_HOST_DEVICE scalar_t calc_i1e(scalar_t _x) { + const auto x = ::abs(_x); + if (x <= scalar_t{8.0}) { + auto [A, len] = chebyshev_coefficients_i1e_A(); + const scalar_t y = x / scalar_t{2.0} - scalar_t{2.0}; + const scalar_t out = chbevl(y, A, len) * x; + return (_x < scalar_t{0.0}) ? -out : out; + } + + auto [B, len] = chebyshev_coefficients_i1e_B(); + const scalar_t out = chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x); + return (_x < scalar_t{0.0}) ? -out : out; +} + +#endif // AT_USE_JITERATOR() (this closes the "else" branch of a if/else preprocessor directive) + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d29ba35393a08617125f543365e64a51e425a9c8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh @@ -0,0 +1,682 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// References: +// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/ + +namespace at::native::memory { + +namespace detail { + +// What does the `static_unroll` do? +// +// We want to do something like: +// +// using args_t = typename traits::ArgsTuple; +// args_t args; +// #pragma unroll +// for (int i = 0; i < traits::arity; i++) { +// std::get(args) = .... +// } +// +// but unfortunately the above code does not work because +// the template argument has to be a compile time constant +// so `static_unroll` is created to simulate `#pragma unroll` +// using template metaprogramming. + +template typename func, int end, int current=0> +struct static_unroll { + template + static inline C10_HOST_DEVICE void with_args(Args&&... 
args) { + func::apply(std::forward(args)...); + static_unroll::with_args(args...); + } +}; + +template typename func, int end> +struct static_unroll { + template + static inline C10_HOST_DEVICE void with_args(Args... /*args*/) {} +}; + +// helper structs to be used with static_unroll to load arguments +// one by one + +template +struct vectorized_load_helper { + template + static __device__ void apply(policy_t &self, args_t *args, int idx, int block_work_size) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + auto ptr = reinterpret_cast(self.data[arg_index + 1]) + block_work_size * idx; + auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get(args[thread_unroll_idx]); }; + self.load_single_arg(args_accessor, ptr); + } +}; + +#ifdef USE_ROCM +// Templated version of vectorized load helper. +// It can be used on heterogeneous input tensor element types. +template +struct vectorized_templated_load_helper { + template + static __device__ void apply(policy_t& self, args_t* args, int idx) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + + // Delay pointer arithmetic to the policy loader where we know the actual + // type of the current argument. 
+ char* ptr = (self.data[arg_index + 1]); + auto args_accessor = [&args] __device__(int thread_unroll_idx) -> arg_t& { + return std::get(args[thread_unroll_idx]); + }; + self.template load_single_arg(args_accessor, ptr, idx); + } +}; +#endif + +template +struct unroll_load_helper { + template + static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + std::get(args[j]) = loader.template load(self.data[arg_index + num_outputs], offset[arg_index], arg_index); + } +}; + +template +struct multi_outputs_store_helper { + template + C10_HOST_DEVICE static void apply( + const data_t& data, + const offsets_t& offsets, + thrust::tuple ret) { + using T = typename thrust::tuple_element>::type; + T *to = reinterpret_cast(data[current]) + offsets[current]; + *to = thrust::get(ret); + } +}; + +} // namespace detail + +struct LoadWithoutCast { + template + __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) { + return c10::load(reinterpret_cast(base_ptr) + offset); + } +}; + +template +struct LoadWithCast { + using array_t = std::array(N, 1)>; + using size_array_t = std::array(N, 1)>; + + array_t dtypes; + size_array_t element_sizes; + + LoadWithCast(const TensorIteratorBase& iter) { + CUDA_KERNEL_ASSERT(iter.ninputs() == N); + #pragma unroll + for (auto i = 0; i < N; ++i) { + this->dtypes[i] = iter.dtype(i + iter.noutputs()); + element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs())); + } + } + + template + __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) { + void *ptr = base_ptr + element_sizes[arg] * offset; + return c10::fetch_and_cast(dtypes[arg], ptr); + } +}; + +struct StoreWithoutCast { + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) { + *(reinterpret_cast(base_ptr) 
+ offset) = value; + } +}; + +template +struct StoreWithCast { + using array_t = std::array(N, 1)>; + using size_array_t = std::array(N, 1)>; + + array_t dtypes; + size_array_t element_sizes; + + StoreWithCast(const TensorIteratorBase& iter) { + CUDA_KERNEL_ASSERT(iter.noutputs() == N); + #pragma unroll + for (auto i = 0; i < N; ++i) { + this->dtypes[i] = iter.dtype(i); + element_sizes[i] = c10::elementSize(iter.dtype(i)); + } + } + + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) { + void *ptr = base_ptr + element_sizes[arg] * offset; + c10::cast_and_store(dtypes[arg], ptr, value); + } +}; + +// aligned vector generates vectorized load/store on CUDA +template +struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; +}; + +template +__device__ aligned_vector load_vector(const scalar_t *base_ptr, uint32_t offset) { + using vec_t = aligned_vector; + auto *from = reinterpret_cast(base_ptr); +#if defined(USE_ROCM) && defined(__gfx942__) + using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int; + if constexpr (sizeof(vec_t) == sizeof(int)) { + union { + vec_t v; + int i; + } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast(&(from[offset]))) }; + return tmpt.v; + } + else if constexpr (sizeof(vec_t) == sizeof(long)) { + union { + vec_t v; + long i; + } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast(&(from[offset]))) }; + return tmpt.v; + } + else if constexpr (sizeof(vec_t) == sizeof(longx2)) { + union { + vec_t v; + longx2 i; + } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast(&(from[offset]))) }; + return tmpt.v; + } +#endif + return from[offset]; +} + +template +__device__ aligned_vector load_vector(const bool *base_ptr, uint32_t offset) { + // See NOTE [Loading boolean values] + auto tmp = load_vector(reinterpret_cast(base_ptr), offset); + aligned_vector ret; + for (int i = 0; i < vec_size; ++i) { + ret.val[i] = bool(tmp.val[i]); + } + return ret; 
+} + +namespace policies { + +template < + int num_threads, + typename data_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + int elems_per_thread, + int num_outputs = 1> +struct unroll_base { + data_t data; + int remaining; + inp_calc_t input_offset_calculator; + out_calc_t output_offset_calculator; + loader_t loader; + storer_t storer; + static constexpr int tws = elems_per_thread; + static constexpr int block_work_size = elems_per_thread * num_threads; + + __device__ unroll_base( + data_t data, + int remaining, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) + : data(data), + remaining(remaining), + input_offset_calculator(ic), + output_offset_calculator(oc), + loader(l), + storer(s) {} + + __device__ inline bool check_inbounds(int thread_work_elem) { + return ((int)(threadIdx.x + thread_work_elem * num_threads) < remaining); + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size_v; + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) { + if (thread_idx < remaining) { + int linear_idx = thread_idx + block_work_size * idx; + auto offset = input_offset_calculator.get(linear_idx); + detail::static_unroll::with_args( + *this, args, offset, loader, i, num_outputs); + thread_idx += num_threads; + } + } + } + + template + __device__ inline void store(scalar_t *from, int idx) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) { + if (thread_idx < remaining) { + int linear_idx = thread_idx + block_work_size * idx; + int offset = output_offset_calculator.get(linear_idx)[0]; + storer.store(from[i], data[0], offset); + thread_idx += num_threads; + } + } + } +}; + +// Utility type for all users of unroll that extract the num_threads value from +// the caller scope. 
+template < + typename data_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + int elems_per_thread, + int num_outputs = 1> +using unroll = unroll_base< + num_threads(), + data_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + elems_per_thread, + num_outputs>; + +template // vec_size: number of scalars, can be 1, 2, or 4. +struct vectorized { + + static_assert(elems_per_thread % vec_size == 0, "The workload per thread must be a multiple of vec_size"); + static constexpr int loop_size = elems_per_thread / vec_size; + static constexpr int tws = elems_per_thread; + + data_t data; + + __device__ vectorized(data_t data) : data(data) {} + + __device__ inline constexpr bool check_inbounds(int thread_work_elem) { + return true; + } + + template + __device__ inline void load_single_arg(accessor_t to, scalar_t *from) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads(); + auto v = load_vector(from, index); + #pragma unroll + for (int j = 0; j < vec_size; j++) { + to(vec_size * i + j) = v.val[j]; + } + } + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size_v; + detail::static_unroll::with_args(*this, args, idx, elems_per_thread * num_threads()); + } + + template + __device__ inline void store(scalar_t *from, int idx) { + using vec_t = aligned_vector; + scalar_t *to = reinterpret_cast(data[0]) + elems_per_thread * num_threads() * idx; + vec_t *to_ = reinterpret_cast(to); + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads(); + vec_t v; + for (int j = 0; j < vec_size; j++) { + v.val[j] = from[vec_size * i + j]; + } + to_[index] = v; + } + } +}; + +#ifdef USE_ROCM +// This is similar to vectorized policy above, but this one supports +// heterogenous input tensor types as templated parameters. 
+// Its use should be limited to frequently used heterogeneous data types +// as each instantiation will generate a separate kernel, leading to code +// bloating if applied to all combinations supported in PyTorch. Assumption: all +// tensors are contiguous, that is: stride == sizeof(type) for all tensors. +template < + int vec_size, + typename data_t, + int elems_per_thread, + int num_threads, + typename CastToT, + typename... CastFromTs> // vec_size: number of scalars, can be 1, 2, or 4. +struct vectorized_templated { + static_assert( + elems_per_thread % vec_size == 0, + "The workload per thread must be a multiple of vec_size"); + static constexpr int loop_size = elems_per_thread / vec_size; + static constexpr int tws = elems_per_thread; + static constexpr int block_work_size = elems_per_thread * num_threads; + data_t data; + + __device__ vectorized_templated(data_t data) : data(data) {} + + __device__ inline constexpr bool check_inbounds(int thread_work_elem) { + return true; + } + + template + __device__ inline void load_single_arg(accessor_t to, char* ptr, int idx) { + // extract the arg_index-th input tensor element type from the + // variadic template argument. + using CastFromT = + std::tuple_element_t>; + // Delayed pointer arithmetic from the caller: this is the place + // where we know the type of the argument. 
+ CastFromT* block_ptr = + reinterpret_cast(ptr) + block_work_size * idx; + int thread_idx = threadIdx.x; +#pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads; + auto v = load_vector(block_ptr, index); +#pragma unroll + for (int j = 0; j < vec_size; j++) { + to(vec_size * i + j) = c10::convert(v.val[j]); + } + } + } + + template + __device__ inline void load(args_t* args, int idx) { + constexpr int arity = std::tuple_size::value; + detail::static_unroll:: + with_args(*this, args, idx); + } + + // Assume for now that from (temporary array per thread) is of the same + // type as to (destination tensor), which is the case for + // float(float,bfloat16) and functor add on float(float,float). + template + __device__ inline void store(scalar_t* from, int idx) { + using vec_t = aligned_vector; + CastToT* to = reinterpret_cast(data[0]) + block_work_size * idx; + vec_t* to_ = reinterpret_cast(to); + int thread_idx = threadIdx.x; +#pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads; + vec_t v; + for (int j = 0; j < vec_size; j++) { + v.val[j] = from[vec_size * i + j]; + } + to_[index] = v; + } + } +}; +#endif + +template +struct multi_outputs_unroll { + //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct + //we don't use inheritance because of compiler bug in cuda 10.2+ + data_t data; + int remaining; + inp_calc_t input_offset_calculator; + out_calc_t output_offset_calculator; + LoadWithoutCast loader; + StoreWithoutCast storer; + static constexpr int tws = thread_work_size(); + + __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc): + data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {} + + __device__ inline bool check_inbounds(int thread_work_elem) { + return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining); + } + + template + 
__device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size_v; + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offset = input_offset_calculator.get(linear_idx); + detail::static_unroll::with_args(*this, args, offset, loader, i, num_outputs); + thread_idx += num_threads(); + } + } + + + template + __device__ inline void store(return_t *from, int idx) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= this->remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offsets = this->output_offset_calculator.get(linear_idx); + memory::detail::static_unroll::with_args(this->data, offsets, from[i]); + thread_idx += num_threads(); + } + } +}; + +} // namespace policies + +// This is only used in host, but we will wrap this into some templates +// which is C10_HOST_DEVICE, so we have to make this C10_HOST_DEVICE +// in order to compile +template +inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { + uint64_t address = reinterpret_cast(pointer); + constexpr int vec2_alignment = std::alignment_of_v>; + constexpr int vec4_alignment = std::alignment_of_v>; + constexpr int vec8_alignment = std::alignment_of_v>; +#ifdef USE_ROCM + constexpr int vec16_alignment = std::alignment_of_v>; + constexpr int type_size = sizeof(scalar_t); + if (type_size == 1 && (address % vec16_alignment == 0)) { + return 16; + } else if (type_size <= 2 && (address % vec8_alignment == 0)) { + return 8; + } else +#else + if (address % vec8_alignment == 0) { + return 8; + } else +#endif + if (address % vec4_alignment == 0) { + return 4; + } else if (address % vec2_alignment == 0) { + return 2; + } + return 1; +} + +template +inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) { + return 
can_vectorize_up_to(static_cast(pointer)); +} + +template +struct can_vectorize_up_to_helper { + template + static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits /*_*/) { + using arg_t = typename traits::template arg::type; + // `pointers` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + result = std::min(result, can_vectorize_up_to(pointers[i + 1])); + } +}; + +template +inline int can_vectorize_up_to(array_t pointers) { + using traits = function_traits; + using return_t = typename traits::result_type; + constexpr int arity = traits::arity; + int result = can_vectorize_up_to(pointers[0]); + // We need to get the type for each argument of `func_t`, this can only + // be done at compile time. + detail::static_unroll::with_args(result, pointers, traits()); + return result; +} + + + +template +__inline__ size_t get_alignment(T ptr_or_size) { + auto val = reinterpret_cast(ptr_or_size); + if (val % 16 == 0) { + return 16; + } else if (val % 8 == 0) { + return 8; + } else if (val % 4 == 0) { + return 4; + } else if (val % 2 == 0) { + return 2; + } else { + return 1; + } +} + +template <> +__inline__ size_t get_alignment(size_t size) { + return get_alignment(reinterpret_cast(size)); +} + +template +inline constexpr bool dependent_bool_value = Value; + +template +inline constexpr bool dependent_false = dependent_bool_value; + +template +union Vec; + +template <> +union Vec<4> { + uint16_t u16[2]; + uint32_t u32, as_scalar; + float f32; +}; + +template <> +union Vec<8> { + uint16_t u16[4]; + uint32_t u32[2]; + uint64_t u64, as_scalar; + float f32[2]; +}; + +template <> +union alignas(16) Vec<16> { + uint16_t u16[8]; + uint32_t u32[4]; + uint64_t u64[2]; + uint4 u128, as_scalar; + float f32[4]; +}; + +template +__device__ __inline__ Vec ld_vec(const T* addr) { + Vec vec; + if constexpr (Alignment == 16) { +#if defined(USE_ROCM) + vec.u128 = *reinterpret_cast(addr); + } else if constexpr (Alignment 
== 8) { + vec.u64 = *reinterpret_cast(addr); + } else if constexpr (Alignment == 4) { + vec.u32 = *reinterpret_cast(addr); +#else + asm("ld.global.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(vec.u32[0]), "=r"(vec.u32[1]), "=r"(vec.u32[2]), "=r"(vec.u32[3]) + : "l"(addr) + : "memory"); + } else if constexpr (Alignment == 8) { + asm("ld.global.v2.u32 {%0,%1}, [%2];" + : "=r"(vec.u32[0]), "=r"(vec.u32[1]) + : "l"(addr) + : "memory"); + } else if constexpr (Alignment == 4) { + asm("ld.global.u32 %0, [%1];" : "=r"(vec.u32) : "l"(addr) : "memory"); +#endif + } else { + static_assert(dependent_false); + } + return vec; +} + +template +__device__ __inline__ void st_vec(T* addr, const Vec& vec) { + if constexpr (Alignment == 16) { +#if defined(USE_ROCM) + reinterpret_cast(addr)[0] = vec.u64[0]; + reinterpret_cast(addr)[1] = vec.u64[1]; + } else if constexpr (Alignment == 8) { + *reinterpret_cast(addr) = vec.u64; + } else if constexpr (Alignment == 4) { + *reinterpret_cast(addr) = vec.u32; +#else + asm("st.global.v4.u32 [%0], {%1,%2,%3,%4};" + : + : "l"(addr), + "r"(vec.u32[0]), + "r"(vec.u32[1]), + "r"(vec.u32[2]), + "r"(vec.u32[3]) + : "memory"); + } else if constexpr (Alignment == 8) { + asm("st.global.v2.u32 [%0], {%1,%2};" + : + : "l"(addr), "r"(vec.u32[0]), "r"(vec.u32[1]) + : "memory"); + } else if constexpr (Alignment == 4) { + asm("st.global.u32 [%0], %1;" : : "l"(addr), "r"(vec.u32) : "memory"); +#endif + } else { + static_assert(dependent_false); + } +} + + + +} // namespace at::native::memory diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/MiscUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/MiscUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..f733f3a38099c2b2a6e11bd98a3f8a95c6e1b10a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/MiscUtils.h @@ -0,0 +1,31 @@ +#pragma once +#include +#include +#include +#include + + +namespace at::native { + 
+static inline int cuda_int_cast(int64_t value, const char* varname) { + auto result = static_cast(value); + TORCH_CHECK(static_cast(result) == value, + "cuda_int_cast: The value of ", varname, "(", (long long)value, + ") is too large to fit into a int (", sizeof(int), " bytes)"); + return result; +} + +// Creates an array of size elements of type T, backed by pinned memory +// wrapped in a Storage +template +static inline Storage pin_memory(int64_t size) { + auto* allocator = cuda::getPinnedMemoryAllocator(); + int64_t adjusted_size = size * sizeof(T); + return Storage( + Storage::use_byte_size_t(), + adjusted_size, + allocator, + /*resizable=*/false); +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2fe431f778b1a299e104e0bd79499096fc47a4dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh @@ -0,0 +1,382 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace at::native { + +namespace { + +static constexpr int64_t kILP = 4; +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kBlockSize = 512; + +// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy` +// TensorListMetadata has to be < 4KB - the limit for kernel launch argument +static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; +static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; +static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { + 72, + 60}; + +template +__device__ __forceinline__ bool is_aligned(T* p) { + return ((uint64_t)p) % (kILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void 
load_store( + T* dst, + T* src, + int64_t dst_offset, + int64_t src_offset) { + using LT = at::native::memory::aligned_vector; + ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; +} + +template +struct TensorListMetadata { + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; + int start_tensor_this_launch; +}; + +template +struct TensorListScalarListMetadata { + const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]]; + scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; +}; + +// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of +// 4kb with `c10::complex` +template <> +struct TensorListScalarListMetadata, 1> { + const void* addresses[1] + [depth_to_max_tensors_scalarlist_of_complex_double[0]]; + int64_t + numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]]; + c10::complex + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]]; + unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]]; + int block_to_chunk[depth_to_max_blocks[1 - 1]]; +}; + +template <> +struct TensorListScalarListMetadata, 2> { + const void* addresses[2] + [depth_to_max_tensors_scalarlist_of_complex_double[1]]; + int64_t + numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]]; + c10::complex + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]]; + unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]]; + int block_to_chunk[depth_to_max_blocks[2 - 1]]; +}; + +// NOTE(crcrpar): This is a conservative resolution to handle `state_steps` +// whose each element is `at::Tensor` of 1 element representing the number of +// `step`s called so far. 
+template +struct FusedOptimizerTensorListMetadata { + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; + int start_tensor_this_launch; +}; + +template +C10_LAUNCH_BOUNDS_1(kBlockSize) +__global__ void multi_tensor_apply_kernel( + T tensorListMeta, + U callable, + ArgTypes... args) { + // Hand the chunk information to the user-supplied functor to process however + // it likes. + callable(kChunkSize, tensorListMeta, args...); +} + +} // namespace + +// multi_tensor_apply enables horizontal fusion across lists of tensors. +// For example, whereas you once had a for-loop of a + b = c, where a, b, +// and c are individual tensors in lists as, bs, and cs, you can now with +// fewer kernel launches compute as + bs = cs. +// +// You can also imagine bs to be a scalar list vs a tensor list. +// +// The function below takes in tensor lists, scalars, and a callable and +// chunks up the computation to launch as few kernels as possible by iterating +// through every "chunk" in every tensor (thus the nested for loops). In the +// simplest case, everything gets bundled into just one kernel launch, but +// due to blocksize constraints, we may need to launch multiple kernels. +// Each kernel launch is defined by one tensorListMeta construct, which we +// use to track and reset the necessary metadata for each launch. +template +void multi_tensor_apply( + std::vector>& tensor_lists, + at::ArrayRef scalars, + T callable, + ArgTypes... 
args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth."); + const size_t n_tensors = tensor_lists[0].size(); + using scalar_vals_t = typename T::opmath_t; + TensorListScalarListMetadata tensorListMeta; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (size_t t = 0; t < n_tensors; t++) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][t].numel() == 0) { + continue; + } + tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][t].const_data_ptr(); + } + loc_tensor_info++; + + // now we enter [chunking territory]. + // we will launch a kernel when EITHER the blocks get filled up OR + // the tensors get filled up. There will always be at least one block + // per tensor since the zero-sized ones will not enter the loop, so + // the nested forloop within represents iterating through the chunks + // of a single tensor. + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + for (auto chunk = 0; chunk < chunks; chunk++) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + // a tensor is not considered full unless all its chunks have been + // processed + const bool tensors_full = + (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] && + chunk == chunks - 1); + const bool blocks_full = + (loc_block_info == depth_to_max_blocks[depth - 1]); + + if (tensors_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>( + tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. 
+ loc_block_info = 0; + // all chunks have already been handled in the kernel + if (chunk == chunks - 1) { + loc_tensor_info = 0; + } else { // blocks were full and tensor chunks remain + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + tensorListMeta.scalar_vals[0] = + tensorListMeta.scalar_vals[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + } + } + } + } + + // note: [finishing what we started] + // if there's remaining work to be done but the tensors/blocks aren't full + // yet we are at the end, submit the kernel to do the work! + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +template +void multi_tensor_apply( + std::vector>& tensor_lists, + T callable, + ArgTypes... args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth."); + const size_t n_tensors = tensor_lists[0].size(); + TensorListMetadata tensorListMeta; + tensorListMeta.start_tensor_this_launch = 0; + + int loc_block_info = 0; + int loc_tensor_info = 0; + int processed = 0; + + for (size_t t = 0; t < n_tensors; t++) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][t].numel() == 0) { + continue; + } + processed++; + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][t].const_data_ptr(); + } + loc_tensor_info++; + + // see note: [chunking territory]. 
+ const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + for (auto chunk = 0; chunk < chunks; chunk++) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + const bool tensors_full = + (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks - 1); + const bool blocks_full = + (loc_block_info == depth_to_max_blocks[depth - 1]); + + if (tensors_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>( + tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. + loc_block_info = 0; + if (chunk == chunks - 1) { + loc_tensor_info = 0; + tensorListMeta.start_tensor_this_launch = processed; + } else { + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + tensorListMeta.start_tensor_this_launch = processed - 1; + } + } + } + } + + // see note: [finishing what we started] + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +template +void multi_tensor_apply_for_fused_optimizer( + std::vector>& tensor_lists, + at::TensorList state_steps, + T callable, + ArgTypes... 
args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth"); + const auto num_tensors = tensor_lists[0].size(); + FusedOptimizerTensorListMetadata tensorListMeta; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (const auto& tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + tensorListMeta.state_steps_addresses[loc_tensor_info] = + state_steps[tensor_index].const_data_ptr(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][tensor_index].numel(); + for (const auto& d : c10::irange(depth)) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][tensor_index].const_data_ptr(); + } + loc_tensor_info++; + + // see above note: [chunking territory] + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + for (const auto& chunk : c10::irange(chunks)) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + const auto tensor_full = + (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks - 1); + const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1]; + + if (tensor_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>( + tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. 
+ loc_block_info = 0; + if (chunk == chunks - 1) { + loc_tensor_info = 0; + } else { + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + tensorListMeta.state_steps_addresses[0] = + tensorListMeta.state_steps_addresses[loc_tensor_info - 1]; + for (const auto& d : c10::irange(depth)) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + } + } + } + } + + // see above note: [finishing what we've started] + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Normalization.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Normalization.cuh new file mode 100644 index 0000000000000000000000000000000000000000..088aa517aa23a27739faaf65c8c0ad5904c66e65 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Normalization.cuh @@ -0,0 +1,1742 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +namespace at::native { + +// The maximum number of threads in a block +#if defined(USE_ROCM) +constexpr int MAX_BLOCK_SIZE = 256; +#else +constexpr int MAX_BLOCK_SIZE = 512; +#endif + +constexpr unsigned MAX_GRID_SIZE = 65535u; + +// Number of threads in a block given an input size up to MAX_BLOCK_SIZE +static int getNumThreads(int nElem) { +#if defined(USE_ROCM) + int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; +#else + int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; +#endif + for (int i = 0; i != 5; ++i) { + if (nElem <= threadSizes[i]) { + return threadSizes[i]; + } + } + return 
MAX_BLOCK_SIZE; +} + +// Returns the index of the most significant 1 bit in `val`. +__device__ __forceinline__ int getMSB(int val) { + return 31 - __clz(val); +} + +template +struct Float2 { + accscalar_t v1, v2; + __device__ Float2() {} + __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast(v1)), v2(static_cast(v2)) {} + __device__ Float2(int v) : v1(static_cast(v)), v2(static_cast(v)) {} + __device__ Float2& operator+=(const Float2& a) { + v1 += a.v1; + v2 += a.v2; + return *this; + } + __device__ friend Float2 operator+(Float2 a, const Float2& b) { + a += b; + return a; + } +}; + +template +struct GradOp { + __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g) + : mean(m), input(i), grad_output(g) {} + __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) { + accscalar_t g = grad_output[batch][plane][n]; + accscalar_t c = static_cast(input[batch][plane][n]) - mean; + return Float2(g, g * c); + } + const accscalar_t mean; + const PTA& input; + const PTA& grad_output; +}; + +template +struct SumReduceOp { + __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } + + __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +}; + +template +struct SumReduceOp> { + using acc_t = Float2; + + __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } + + __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { + return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)}; + } +}; + +// Sum across (batch, x/y/z) applying Op() pointwise +// this works by first having each thread sum it's part +// of the data. Then there is a double-shuffling reduction. +// First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its +// data to the "warp leader", who writes its value into shared memory. 
+// Then a single warp reads the remaining (at most C10_WARP_SIZE) items +// and reduces them using another warpSum. +// The implicit assumption is that there are no more +// than C10_WARP_SIZE**2 threads. +template +__device__ scalar_t reduce(Op op, PTA tensor, int plane) { + // first the reductions each thread does separately + scalar_t sum = static_cast(0); + for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { + sum += op(batch, plane, x); + } + } + __shared__ scalar_t shared[C10_WARP_SIZE]; + SumReduceOp reduce_op; + sum = cuda_utils::BlockReduce, cuda_utils::Block2D>(sum, reduce_op, 0, shared); + if (threadIdx.x == 0 && threadIdx.y == 0) { + shared[0] = sum; + } + __syncthreads(); + // Everyone picks it up, should be broadcast into the whole grad_input + return shared[0]; +} + +constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency +constexpr int ELEMENTS_PER_THREAD = 16; +constexpr int OPTIMAL_TILE_W = 32; +constexpr int MAX_H_BLOCK = 128; + +__host__ void flexible_launch_configs( + const int reduction, + const int stride, + dim3 &block, + dim3 &grid, + const bool coop_flag = false) { + int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W); + int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)), + MAX_BLOCK_SIZE / block_x); + if (block_x * block_y != MAX_BLOCK_SIZE) { + block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y); + } + + int grid_x = at::ceil_div(stride, block_x); + int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK); + if (coop_flag) { + // it's not worth having a grid reduction if the reduction dimension is not big enough + grid_y = grid_y < 8 ? 
1 : grid_y; + } + + block.x = block_x; + block.y = block_y; + block.z = 1; + grid.x = grid_x; + grid.y = grid_y; + grid.z = 1; +} + +template +__device__ __forceinline__ void welford_merge_element(C& count, + T& mean, + T& m2n, + const C& count_new, + const T& mean_new, + const T& m2n_new) { + T factor = T(1.0) / ::max(1, (count + count_new)); + T delta0 = mean - mean_new; + mean = (mean_new * count_new + mean * count) * factor; + m2n += m2n_new + delta0 * delta0 * count_new * count * factor; + count += count_new; +} + +// merge mean/m2n among threadIdx.y within block +template +__device__ __forceinline__ void welford_merge_block_vertical(C& count, + T& mean, + T& m2n, + C* shmem_count, + T* shmem_mean, + T* shmem_m2n) { + // write to shared memory + auto address_base = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + shmem_mean[address_base] = mean; + shmem_m2n[address_base] = m2n; + shmem_count[address_base] = count; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + auto address = address_base + offset * blockDim.x; + // read shared memory back to register for reduction + auto count_new = shmem_count[address]; + auto mean_new = shmem_mean[address]; + auto m2n_new = shmem_m2n[address]; + + welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new); + } + } +} + +template +__global__ void batch_norm_transform_input_kernel( + const GenericPackedTensorAccessor input, + GenericPackedTensorAccessor output, + const GenericPackedTensorAccessor, 1, RestrictPtrTraits, index_t> mean_, + const GenericPackedTensorAccessor, 1, RestrictPtrTraits, index_t> var_or_invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor bias, + stat_accscalar_t epsilon) { + + index_t plane = blockIdx.x; + + if (plane >= input.size(1)) { + return; + } + + stat_accscalar_t gamma = weight.size(0) > 0 ? 
static_cast(weight[plane]) : static_cast(1); + stat_accscalar_t beta = bias.size(0) > 0 ? static_cast(bias[plane]) : static_cast(0); + stat_accscalar_t mean = static_cast(mean_[plane]); + stat_accscalar_t invstd; + if (train) { + invstd = var_or_invstd[plane]; + } else { + invstd = static_cast(1) / device_sqrt(static_cast(var_or_invstd[plane]) + epsilon); + } + + index_t bs = input.size(0); + index_t fs = input.size(2); + + index_t bstep = blockDim.y * gridDim.y; + for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { + auto o = output[batch][plane]; + auto i = input[batch][plane]; + for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { + o[feature] = static_cast(gamma * (i[feature] - mean) * invstd + beta); + } + } +} + +struct InvStd { + template + __device__ __forceinline__ T operator()(T var, double epsilon) const { + T invstd = 0; + if (var != static_cast(0) || epsilon != static_cast(0)) { + invstd = static_cast(1) / device_sqrt(var + epsilon); + } + return invstd; + } +}; + +struct Var { + template + __device__ __forceinline__ T operator()(T var, double epsilon) const { + return var; + } +}; + +template +__global__ void batch_norm_collect_statistics_kernel( + const GenericPackedTensorAccessor input, + const stat_accscalar_t epsilon, + const stat_accscalar_t momentum, + GenericPackedTensorAccessor save_mean, + GenericPackedTensorAccessor save_transformed_var) { + + __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE]; + + int plane = blockIdx.x; + int N = input.size(0) * input.size(2); + int tid = threadIdx.x + threadIdx.y * blockDim.x; + + // Compute the mean and variance across (batch, x/y/z) + // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block) + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm + // and the parallel algorithm on the same page. + // We use two shuffles to reduce across the entire block. 
+ // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description. + stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE]; + + // first the reductions each thread does separately + stat_accscalar_t avg = 0; + stat_accscalar_t var_n = 0; + int n = 0; + for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { + stat_accscalar_t v = input[batch][plane][x]; + stat_accscalar_t d1 = v - avg; + n++; + avg += d1 / n; + var_n += d1 * (v - avg); + } + } + + // first warpSum to get one value per thread to + // one value per warp + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); + stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // this writes each warps item into shared memory + // there are at most C10_WARP_SIZE items left because + // there are at most C10_WARP_SIZE**2 threads at the beginning + __syncthreads(); + if (tid % C10_WARP_SIZE == 0) { + shared_n[tid / C10_WARP_SIZE] = n; + shared_avg_var[tid / C10_WARP_SIZE * 2] = avg; + shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n; + } + __syncthreads(); + // now have a second warpSum to reduce the intermediate values + // from shared memory to a single number. The very first + // thread writes it to shared memory. + + if (tid < C10_WARP_SIZE) { + n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0); + avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0)); + var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? 
shared_avg_var[2 * tid + 1] : stat_accscalar_t(0)); + } + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); + stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // Save the mean, variance, and moving averages + if (tid == 0) { + if (save_mean.data() != NULL) { + save_mean[plane] = avg; + } + if (save_transformed_var.data() != NULL) { + save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon); + } + } + +} + +template +__global__ void batch_norm_backward_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor grad_input, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor running_mean, + const GenericPackedTensorAccessor running_var, + const GenericPackedTensorAccessor save_mean, + const GenericPackedTensorAccessor save_invstd, + bool train, + stat_accscalar_t epsilon) { + + index_t plane = blockIdx.x; + index_t N = grad_output.size(0) * grad_output.size(2); + + stat_accscalar_t mean, invstd; + if (train) { + mean = save_mean[plane]; + invstd = save_invstd[plane]; + } else { + mean = static_cast(running_mean[plane]); + invstd = static_cast(1) / device_sqrt(static_cast(running_var[plane]) + epsilon); + } + + stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast(weight[plane]) : stat_accscalar_t(1); + stat_accscalar_t norm = stat_accscalar_t(1) / N; + + // Compute two values across (batch, x/y/z) in one pass: + // 1. Sum(grad_output) + // 2. 
DotProduct(input - mean, grad_output) + GradOp> g(mean, input, grad_output); + auto res = reduce>(g, grad_output, plane); + + stat_accscalar_t grad_output_sum = res.v1; + stat_accscalar_t dot_p = res.v2; + + stat_accscalar_t grad_mean = grad_output_sum * norm; + stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd; + stat_accscalar_t grad_scale = invstd * weight_val; + + if (grad_input.data() != NULL) { + for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) { + input_scalar_t go = grad_output[batch][plane][x]; + if (train) { + stat_accscalar_t inp = input[batch][plane][x]; + stat_accscalar_t proj = (inp - mean) * proj_scale; + grad_input[batch][plane][x] = static_cast((go - proj - grad_mean) * grad_scale); + } else { + grad_input[batch][plane][x] = static_cast(go * grad_scale); + } + } + } + } + + if (grad_weight.size(0) > 0) { + if (threadIdx.x == 0) { + grad_weight[plane] = static_cast(dot_p * invstd); + } + } + + if (grad_bias.size(0) > 0) { + if (threadIdx.x == 0) { + grad_bias[plane] = static_cast(grad_output_sum); + } + } +} + +template +__global__ void batch_norm_reduce_statistics_kernel( + const GenericPackedTensorAccessor vec_mean, + const GenericPackedTensorAccessor vec_invstd, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor running_mean, + GenericPackedTensorAccessor running_var, + const accscalar_t epsilon, + const accscalar_t momentum, + const GenericPackedTensorAccessor counts) { + + int feature_size = vec_mean.size(1); + int world_size = vec_mean.size(0); + + int bid = blockIdx.x; + int tid = threadIdx.x; + + // first the reductions each thread does separately + for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) { + accscalar_t avg = 0; + accscalar_t var_n = 0; + index_t n = 0; + for (int j = 0; j < world_size; j++) { + scalar_t count = counts[j]; + accscalar_t m 
= vec_mean[j][i]; + accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]); + v = (v * v - epsilon) * count; + accscalar_t factor = 1.0 / (n + count); + var_n += v + (avg - m) * (avg - m) * n * count * factor; + avg = n * factor * avg + count * factor * m; + n += count; + } + mean[i] = avg; + invstd[i] = static_cast(1) / device_sqrt(var_n / n + epsilon); + if (running_mean.data() != NULL) { + running_mean[i] = static_cast((1 - momentum) * running_mean[i] + momentum * avg); + } + accscalar_t unbiasedVar = var_n / (n - 1); + if (running_var.data() != NULL) { + running_var[i] = static_cast((1 - momentum) * running_var[i] + momentum * unbiasedVar); + } + } + +} + +template +__global__ void batch_norm_backward_reduce_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor sum_dy, + GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias) { + + index_t plane = blockIdx.x; + + stat_accscalar_t r_mean = mean[plane]; + stat_accscalar_t factor = invstd[plane]; + + GradOp> g(r_mean, input, grad_output); + auto res = reduce>(g, grad_output, plane); + + if (threadIdx.x == 0) { + if (grad_weight.size(0) > 0) { + grad_weight[plane] = static_cast(res.v2 * factor); + } + if (grad_bias.size(0) > 0) { + grad_bias[plane] = static_cast(res.v1); + } + if (sum_dy.size(0) > 0) { + sum_dy[plane] = static_cast(res.v1); + } + if (sum_dy_xmu.size(0) > 0) { + sum_dy_xmu[plane] = static_cast(res.v2); + } + } +} + +template +__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const 
GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const stat_accscalar_t norm_fct) { + index_t plane = blockIdx.x; + + if (plane >= input.size(1)) { + return; + } + + stat_accscalar_t m_c = mean[plane]; + stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct; + stat_accscalar_t factor_1_c = invstd[plane]; + stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast(weight[plane]) : stat_accscalar_t(1); + factor_2_c *= factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct; + + index_t bs = input.size(0); + index_t fs = input.size(2); + + index_t bstep = blockDim.y * gridDim.y; + for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { + auto g_i = grad_input[batch][plane]; + auto g_o = grad_output[batch][plane]; + auto i = input[batch][plane]; + for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { + g_i[feature] = static_cast((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c); + } + } +} + +template +__global__ void batch_norm_backward_elemt_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const int* __restrict__ numel, const int world_size) { + int64_t total_numel = 0; + for (int i = 0; i < world_size; i ++) { + total_numel += numel[i]; + } + + const stat_accscalar_t norm_fct = + static_cast(1) / static_cast(total_numel); + batch_norm_backward_elemt_kernel_impl( + input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); +} + +template +__global__ void batch_norm_backward_elemt_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const 
GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const stat_accscalar_t norm_fct) { + batch_norm_backward_elemt_kernel_impl( + input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); +} + +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +static GenericPackedTensorAccessor get_packed_accessor( + const Tensor& t, std::string_view var_name) { + constexpr auto expect_type = c10::CppTypeToScalarType>::value; + const auto actual_type = t.scalar_type(); + TORCH_CHECK(actual_type == expect_type, "Expected ", var_name, + " to have type ", expect_type, " but got ", actual_type); + return t.generic_packed_accessor(); +} + +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +static GenericPackedTensorAccessor packed_accessor_or_dummy( + const Tensor& t, std::string_view var_name) { + if (!t.defined()) { + const std::array zeros{{0}}; + return GenericPackedTensorAccessor(nullptr, zeros.data(), zeros.data()); + } + return get_packed_accessor(t, var_name); +} + +template +std::tuple batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_, + const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_, + bool train, double epsilon, std::array grad_input_mask) { + + using accscalar_t = at::acc_type; + Tensor grad_input_; + Tensor grad_input_reshaped; + Tensor grad_weight_; + Tensor grad_bias_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (grad_input_mask[0]) { + grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + grad_input_reshaped = grad_input_.view(input_reshaped.sizes()); 
+ } + if (grad_input_mask[1]) { + grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[2]) { + grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + + auto input = get_packed_accessor< + const input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_output = get_packed_accessor< + const input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto grad_input = packed_accessor_or_dummy< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto weight = packed_accessor_or_dummy< + const stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto grad_weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); + auto grad_bias = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); + auto running_mean = packed_accessor_or_dummy< + const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean"); + auto running_var = packed_accessor_or_dummy< + const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var"); + auto save_mean = packed_accessor_or_dummy< + const accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean"); + auto save_invstd = packed_accessor_or_dummy< + const accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd"); + + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 blocks(input.size(1)); + int tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + + batch_norm_backward_kernel <<>> + (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var, + save_mean, save_invstd, train, epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(grad_input_, grad_weight_, grad_bias_); +} + +template +void batch_norm_stats_cuda_template( + const Tensor& out_mean, 
const Tensor& out_invstd, const Tensor& input_, double epsilon) { + + using accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + Tensor dummy_mean_; + Tensor dummy_var_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + + resize_output(out_mean, {n_input}); + resize_output(out_invstd, {n_input}); + auto input = get_packed_accessor< + const scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); + TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + auto mean = packed_accessor_or_dummy< + accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean"); + auto invstd = packed_accessor_or_dummy< + accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(input.size(1)); + int tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + batch_norm_collect_statistics_kernel <<>> + (input, epsilon, 0.0, mean, invstd); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_, + const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1}); + + auto input = get_packed_accessor< + const input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); + auto output = get_packed_accessor< + input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output"); + auto weight = packed_accessor_or_dummy< + const stat_scalar_t, 1, 
RestrictPtrTraits, index_t>(weight_, "weight"); + auto bias = packed_accessor_or_dummy< + const stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + // NOTE: We use transform_input_kernel in training mode, which ignores epsilon + const double dummy_epsilon = 1e-5; + + // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. + int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + batch_norm_transform_input_kernel <<>> + (input, output, mean, invstd, weight, bias, dummy_epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +std::tuple batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_, + const Tensor& running_mean_, const Tensor& running_var_, + double momentum, double epsilon, const Tensor& counts_) { + + Tensor save_mean_; + Tensor save_invstd_; + + auto features = mean_.size(1); + auto input_options = mean_.options(); + if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) { + input_options = input_options.dtype(ScalarType::Float); + } + save_mean_ = at::empty({features}, input_options); + save_invstd_ = at::empty({features}, 
input_options); + + auto mean = packed_accessor_or_dummy< + accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd"); + auto running_mean = packed_accessor_or_dummy< + scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean"); + auto running_var = packed_accessor_or_dummy< + scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_mean"); + auto counts = packed_accessor_or_dummy< + scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts"); + + auto save_mean = get_packed_accessor< + accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean"); + auto save_invstd = get_packed_accessor< + accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + int block = getNumThreads(features); + int grid = std::max(1, features/block); + batch_norm_reduce_statistics_kernel <<>> + (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(save_mean_, save_invstd_); +} + +template +std::tuple batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_, + const bool input_g, const bool weight_g, const bool bias_g) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + Tensor sum_dy_; + Tensor sum_dy_xmu_; + Tensor grad_weight_; + Tensor grad_bias_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (input_g) { + sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (weight_g) { + grad_weight_ = at::empty({n_input}, 
weight_.options()); + } + if (bias_g) { + grad_bias_ = at::empty({n_input}, weight_.options()); + } + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto grad_weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); + auto grad_bias = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto batch_size = input_reshaped.size(0); + auto feature_size = input_reshaped.size(2); + auto stream = at::cuda::getCurrentCUDAStream(); + + int warp_size = at::cuda::warp_size(); + int block_y = std::min(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size); + // We want block_x to be at least a warp width + int block_x = std::min(std::max(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y); + const dim3 block(block_x, block_y); + const dim3 grid(n_input); + + batch_norm_backward_reduce_kernel <<>> + (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_); +} + +template +Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, + const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) { + + using 
stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + auto reduction_size = input_.numel() / n_input; + auto norm_fct = static_cast(1.0 / reduction_size); + batch_norm_backward_elemt_kernel + <<>> + (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return grad_input_reshaped.view(input_.sizes()); +} + +template +Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, + const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto sum_dy = 
packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. + int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + batch_norm_backward_elemt_kernel <<>> + (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.const_data_ptr(), count.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return grad_input_reshaped.view(input_.sizes()); +} + +// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance +// original apex name: welford_kernel_c_last +template + +__global__ void +batch_norm_collect_statistics_channels_last_kernel( + const scalar_t* __restrict__ input, + accscalar_t* __restrict__ out_mean, + accscalar_t* __restrict__ out_invstd, + volatile accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride, + accscalar_t epsilon) { + // hide latency with concurrency + accscalar_t x_mean[PARALLEL_LOADS]; + accscalar_t m_2_n[PARALLEL_LOADS]; + int count[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + x_mean[i] = accscalar_t(0); + m_2_n[i] = accscalar_t(0); + count[i] = accscalar_t(0); + } + // tensor dimension (m,c) + + // 
loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_math[PARALLEL_LOADS]; + accscalar_t x_count_inv[PARALLEL_LOADS]; + accscalar_t is_valid[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + x_math[j] = input[address_base]; + count[j]++; + x_count_inv[j] = accscalar_t(1) / count[j]; + is_valid[j] = accscalar_t(1); + } else { + x_math[j] = accscalar_t(0); + x_count_inv[j] = accscalar_t(0); + is_valid[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate mean/m2n with welford +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + accscalar_t delta0 = x_math[j] - x_mean[j]; + x_mean[j] += delta0 * x_count_inv[j]; + accscalar_t delta1 = x_math[j] - x_mean[j]; + m_2_n[j] += delta0 * delta1 * is_valid[j]; + } + } + + // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]); + } + + // release x_mean / m_2_n + auto mean_th = x_mean[0]; + auto m2_th = m_2_n[0]; + auto count_th = count[0]; + + // block-wise reduction with shared memory (since reduction cannot be done within a warp) + static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE]; + static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE]; + static __shared__ int shmem_count[MAX_BLOCK_SIZE]; + + welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, 
shmem_mean, shmem_m2n); + + if (gridDim.y > 1) { + volatile accscalar_t* staging_mean = staging_data; + volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y]; + volatile int* staging_count = reinterpret_cast(&staging_m2n[stride*gridDim.y]); + + address_base = c_offset + blockIdx.y * stride; + // write data to staging_data; + if (threadIdx.y == 0 && c_offset < stride) { + staging_mean[address_base] = mean_th; + staging_m2n[address_base] = m2_th; + staging_count[address_base] = count_th; + } + + __threadfence(); + __syncthreads(); // ensuring writes to staging_ is visible to all blocks + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y-1)); + } + + __syncthreads(); + + // check that all data is now available in global memory + if (is_last_block_done) { + count_th = 0; + mean_th = accscalar_t(0.0); + m2_th = accscalar_t(0.0); + + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + address_base = c_offset + y * stride; + int count_new = c_offset < stride ? staging_count[address_base] : 0; + accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0); + accscalar_t m2n_new = c_offset < stride ? 
staging_m2n[address_base] : accscalar_t(0.0); + + welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new); + } + + welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); + if (threadIdx.y == 0 && c_offset < stride) { + out_mean[c_offset] = static_cast(mean_th); + out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); + } + } + } else { + if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { + out_mean[c_offset] = static_cast(mean_th); + out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); + } + } +} + +// elementwise BN kernel +// original apex name: batchnorm_forward_c_last_kernel +template < + typename scalar_t, + typename accscalar_t, + typename layerscalar_t, + int PARALLEL_LOADS> +__global__ void batch_norm_transform_input_channels_last_kernel( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ z, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const layerscalar_t* __restrict__ shift, + scalar_t* __restrict__ out, + const int reduction_size, + const int stride, + const bool fuse_relu) { + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + auto m_c = mean[c_offset]; + auto inv_std_c = static_cast(inv_std[c_offset]); + auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast(weight[c_offset]); + auto s_c = shift == nullptr ? 
accscalar_t(0.0) : static_cast(shift[c_offset]); + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + auto tmp = w_c * (static_cast(input[address_base]) - m_c ) * inv_std_c + s_c; + if (z != nullptr) { + tmp += z[address_base]; + } + out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast(tmp)); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } +} + +template +__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy, + T& sum_dy_xmu, + T* shmem_sum_dy, + T* shmem_sum_dy_xmu) { + // write to shared memory + auto address_base = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + shmem_sum_dy[address_base] = sum_dy; + shmem_sum_dy_xmu[address_base] = sum_dy_xmu; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + auto address = address_base + offset * blockDim.x; + + sum_dy += shmem_sum_dy[address]; + sum_dy_xmu += shmem_sum_dy_xmu[address]; + } + } +} + +// batchnorm backward kernel for c last tensor +// original apex name: reduce_bn_c_last_kernel +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_reduce_channels_last_kernel( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ grad_output, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + accscalar_t* __restrict__ sum_dy_o, + accscalar_t* __restrict__ sum_dy_xmu_o, + layerscalar_t* __restrict__ grad_weight, + layerscalar_t* __restrict__ grad_bias, + volatile 
accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride) { + + // hide latency with concurrency + accscalar_t sum_dy[PARALLEL_LOADS]; + accscalar_t sum_dy_xmu[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + sum_dy[i] = accscalar_t(0); + sum_dy_xmu[i] = accscalar_t(0); + } + // tensor dimension (m,c) + + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + auto r_mean = mean[c_offset]; + auto factor = inv_std[c_offset]; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_input[PARALLEL_LOADS]; + accscalar_t x_grad_output[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + x_input[j] = input[address_base]; + x_grad_output[j] = grad_output[address_base]; + } else { + x_input[j] = accscalar_t(0); + x_grad_output[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate sum_dy / sum_dy_xmu +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + sum_dy[j] += x_grad_output[j]; + sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean); + } + } + + // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + sum_dy[0] += sum_dy[j]; + sum_dy_xmu[0] += sum_dy_xmu[j]; + } + + // release array of registers + auto sum_dy_th = sum_dy[0]; + auto sum_dy_xmu_th = sum_dy_xmu[0]; + + // block-wise reduction with shared memory 
(since reduction cannot be done within a warp) + static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE]; + static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE]; + + merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); + + if (gridDim.y > 1) { + volatile accscalar_t* staging_sum_dy = staging_data; + volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y]; + + address_base = c_offset + blockIdx.y * stride; + // write data to staging_data; + if (threadIdx.y == 0 && c_offset < stride) { + staging_sum_dy[address_base] = sum_dy_th; + staging_sum_dy_xmu[address_base] = sum_dy_xmu_th; + } + + __threadfence(); + __syncthreads(); // ensuring writes to staging_ is visible to all blocks + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y-1)); + } + + __syncthreads(); + + // check that all data is now available in global memory + if (is_last_block_done) { + sum_dy_th = accscalar_t(0.0); + sum_dy_xmu_th = accscalar_t(0.0); + + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + address_base = c_offset + y * stride; + sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0)); + sum_dy_xmu_th += (c_offset < stride ? 
staging_sum_dy_xmu[address_base] : accscalar_t(0.0)); + } + + merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); + if (threadIdx.y == 0 && c_offset < stride) { + if (grad_bias != nullptr) { + grad_bias[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight != nullptr) { + grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); + } + //mean_dy[c_offset] = sum_dy_th / reduction_size; + //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; + sum_dy_o[c_offset] = sum_dy_th; + sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; + } + } + } else { + if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { + if (grad_bias != nullptr) { + grad_bias[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight != nullptr) { + grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); + } + //mean_dy[c_offset] = sum_dy_th / reduction_size; + //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; + sum_dy_o[c_offset] = sum_dy_th; + sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; + } + } +} + +// elementwise BN kernel +// original apex name: batchnorm_backward_c_last_kernel +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + scalar_t* __restrict__ grad_input, + const accscalar_t norm_fct, + const int reduction_size, + const int stride) { + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride 
|| m_offset >= reduction_size) { + return; + } + + auto m_c = mean[c_offset]; + auto m_dy_c = sum_dy[c_offset] * norm_fct; + auto factor_1_c = inv_std[c_offset]; + auto factor_2_c = (weight == nullptr? accscalar_t(1.0) : static_cast(weight[c_offset])) * factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct; + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + grad_input[address_base] = static_cast( + (static_cast(grad_output[address_base]) - m_dy_c - + (static_cast(input[address_base]) - m_c) * factor_1_c) + * factor_2_c); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_elemt_channels_last_kernel( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + const int* __restrict__ numel, + scalar_t* __restrict__ grad_input, + const int64_t world_size, + const int reduction_size, + const int stride) { + + int64_t total_numel = 0; + for (int i = 0; i < world_size; i++) { + total_numel += numel[i]; + } + + auto norm_fct = static_cast(1) / static_cast(total_numel); + batch_norm_backward_elemt_channels_last_kernel_impl( + grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, + grad_input, norm_fct, reduction_size, stride); +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename 
layerscalar_t> +__global__ void batch_norm_backward_elemt_channels_last_kernel( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + scalar_t* __restrict__ grad_input, + const accscalar_t norm_fct, + const int reduction_size, + const int stride) { + batch_norm_backward_elemt_channels_last_kernel_impl( + grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, + grad_input, norm_fct, reduction_size, stride); +} + +template +void batch_norm_stats_channels_last_cuda_template( + const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) { + using accscalar_t = at::acc_type; + + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + resize_output(out_mean, {stride}); + resize_output(out_invstd, {stride}); + TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid, true); + + at::Tensor staging_data; + at::Tensor semaphores; + if (grid.y > 1) { + staging_data = at::empty({4*stride*grid.y}, out_mean.options()); + semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); + } + + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? 
semaphores.mutable_data_ptr() : nullptr; + batch_norm_collect_statistics_channels_last_kernel + <<>>( + input.const_data_ptr(), + out_mean.mutable_data_ptr(), + out_invstd.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride, + epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void batch_norm_elemt_channels_last_cuda_template( + const at::Tensor& output, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& shift, // bias of BN + const at::Tensor& mean, + const at::Tensor& inv_std, + const std::optional& z = std::nullopt, // bias after BN + const bool fuse_relu = false) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + const auto second_dtype = weight.defined() ? weight.scalar_type() : + (shift.defined() ? shift.scalar_type() : input.scalar_type()); + + if (input.scalar_type() != second_dtype) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { + using accscalar_t = at::acc_type; + batch_norm_transform_input_channels_last_kernel + <<>>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? 
shift.const_data_ptr() : nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()){ + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { + using accscalar_t = at::acc_type; + batch_norm_transform_input_channels_last_kernel + <<>>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? shift.const_data_ptr(): nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} + +std::tuple +batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const bool input_g, const bool weight_g, const bool bias_g) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + at::Tensor sumn_dy = at::empty({stride}, mean.options()); + at::Tensor sum_dy_xmu = at::empty({stride}, mean.options()); + + at::Tensor grad_weight; + at::Tensor grad_bias; + if (weight.defined()) { + grad_weight = at::empty({stride}, weight.options()); + grad_bias = at::empty({stride}, weight.options()); + } else { + // because I cannot return an uninitialized at::Tensor + grad_weight = at::empty({0}, mean.options()); + grad_bias = at::empty({0}, mean.options()); + } + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid, true); + + at::Tensor staging_data; + at::Tensor semaphores; + if (grid.y > 1) { + staging_data = 
at::empty({2*stride*grid.y}, mean.options()); + semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); + } + auto stream = at::cuda::getCurrentCUDAStream(); + + if (weight.defined() && input.scalar_type() != weight.scalar_type()) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr() : nullptr; + batch_norm_backward_reduce_channels_last_kernel + <<>>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + grad_weight.mutable_data_ptr(), + grad_bias.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()) { + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr() : nullptr; + batch_norm_backward_reduce_channels_last_kernel + <<>>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + weight.defined() ? grad_weight.mutable_data_ptr() : nullptr, + weight.defined() ? 
grad_bias.mutable_data_ptr() : nullptr, + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } + + return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias); +} + +at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const at::Tensor& sum_dy, + const at::Tensor& sum_dy_xmu, + const at::Tensor& count) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + + if (weight.defined() && weight.scalar_type() != input.scalar_type()) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { + using accscalar_t = at::acc_type; + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.const_data_ptr(), + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + count.const_data_ptr(), + grad_input.mutable_data_ptr(), + count.numel(), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()) { + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { + using accscalar_t = at::acc_type; + batch_norm_backward_elemt_channels_last_kernel + <<>>( + 
grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + count.const_data_ptr(), + grad_input.mutable_data_ptr(), + count.numel(), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } + + return grad_input; +} + +at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const at::Tensor& sum_dy, + const at::Tensor& sum_dy_xmu) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + auto norm_fct = 1.0 / reduction_size; + + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { + using accscalar_t = at::acc_type; + + if (weight.defined() && weight.scalar_type() != input.scalar_type()) { + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.const_data_ptr(), + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + static_cast(norm_fct), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? 
weight.const_data_ptr() : nullptr, + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + static_cast(norm_fct), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + + return grad_input; +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/PersistentSoftmax.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/PersistentSoftmax.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f0871fa0ead6fd19f31b45623864393741080b3e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/PersistentSoftmax.cuh @@ -0,0 +1,402 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace { + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension. +// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024. +// The template arguments have the following meaning: +// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples. +// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small. 
+// A "WARP" contains "C10_WARPS_SIZE" threads, these treads are guaranteed to belong to the same warp. +// This is important because it means only __shfl_ instructions are required for reductions. +// Note that this means WARP_SIZE must be a power of two and <= architecture warp size. +// CUDA warp size is 32 for all existing GPU architectures, but there is no guarantee this will not change for future arch. +// ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future archs. +// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed. +// is_masked is a flag indicating whether SoftMax or MaskedSoftMax should be computed. +// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t. +// This allows SoftMax to be fused with a cast immediately following the SoftMax. +// The mask should have the same shape as input, with a boolean indicate if the value is masked. +// The head_chunk_size is only used for transformer mask softmax, equals to H * D * D. +// For instance: +// input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor. +// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor. +// input_t_float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor. + +template +__global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr, const int head_chunk_size = -1, bool is_transformer_mask = false) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + int idx_offset = first_batch * stride + local_idx; + + src += idx_offset; + dst += idx_offset; + + if (is_transformer_mask) { + mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx; + } else { + mask += idx_offset; + } + // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, + // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep + // the nested loops. + // This should have no impact on performance because the loops are unrolled anyway. + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + elements[i][it] = src[i*element_count+it*WARP_SIZE]; + } else { + elements[i][it] = -std::numeric_limits::infinity(); + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : element_count; + bool is_meaningful_max = false; + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (is_masked) { + int idx = it*WARP_SIZE; + if ((idx + local_idx) < batch_element_count) { + if (!is_transformer_mask) { + idx += i*element_count; + } + if (!mask[idx]) { + max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + is_meaningful_max = true; + } + } + } else { + max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it]; + } + } + if (is_masked) { + if (!is_meaningful_max) { + max_value[i] = -std::numeric_limits::infinity(); + } + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (!is_masked) { + if (is_log_softmax) { + sum[i] += std::exp(elements[i][it] - max_value[i]); + } else { + elements[i][it] = std::exp(elements[i][it] - max_value[i]); + sum[i] += elements[i][it]; + } + } else { + int idx = it*WARP_SIZE; + bool valid = (idx + local_idx) < batch_element_count; + if (!is_transformer_mask) { + idx += i*element_count; + } + if (valid) { + if (!mask[idx]) { + if (is_log_softmax) { + sum[i] += std::exp(elements[i][it] - max_value[i]); + } else { + elements[i][it] = std::exp(elements[i][it] - max_value[i]); + sum[i] += elements[i][it]; + } + } else { + if (!is_log_softmax) { + // Masked values are treated as -infinity, and std::exp(-infinity) is 0. 
+ elements[i][it] = 0; + } + } + } else { + if (!is_log_softmax) { + elements[i][it] = 0.; + } + } + } + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + if (is_log_softmax) sum[i] = std::log(sum[i]); + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + if (is_log_softmax) { + dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i]; + } else if (sum[i] == 0) { + dst[i*element_count+it*WARP_SIZE] = std::numeric_limits::quiet_NaN(); + } else { + dst[i*element_count+it*WARP_SIZE] = elements[i][it] / sum[i]; + } + } else { + break; + } + } + } +} + +template +__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count, const bool *mask = nullptr) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. 
compute the index within the batch + int local_idx = threadIdx.x % WARP_SIZE; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + if (is_masked) { + mask += thread_offset; + } + + // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop, + // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep + // the nested loops. + // This should have no impact on performance because the loops are unrolled anyway. + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE]; + output_reg[i][it] = output[i*element_count+it*WARP_SIZE]; + } else { + grad_reg[i][it] = acc_t(0); + output_reg[i][it] = acc_t(0); + } + } + } + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (!is_masked || !mask[i*element_count+it*WARP_SIZE]) { + sum[i] += grad_reg[i][it]; + } + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + if (is_masked && mask[i*element_count+it*WARP_SIZE]) { + gradInput[i*element_count+it*WARP_SIZE] = 0; + } + // compute gradients + else if (is_log_softmax) { + gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - 
std::exp(output_reg[i][it]) * sum[i]); + } else { + gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]); + } + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = at::cuda::warp_size(); + warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \ + softmax_warp_forward \ + <<>>(dst, \ + src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024 + LAUNCH_SOFTMAX_WARP_FORWARD(11); // 2048 + default: + break; + } + } +} + +template +void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = at::cuda::warp_size(); + warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. 
+ int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \ + softmax_warp_backward \ + <<>> \ + (grad_input, grad, output, batch_count, softmax_elements_stride, \ + softmax_elements, mask); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; + + LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024 + default: + break; + } + } +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Pow.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Pow.cuh new file mode 100644 index 0000000000000000000000000000000000000000..dc9faf77f22a35960eb3ddd9f3c903af239206ed --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Pow.cuh @@ -0,0 +1,58 @@ +#pragma once +#include +#include + +namespace at::native { + +namespace { + + +// SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt. +// So we need to define the functions with the explicit function signatures. 
+// As for pow, the following signatures are defined as the device function: +// pow(float, int) +// pow(double, int) +// pow(float, float) +// pow(double, double) +#ifdef _MSC_VER +// Functions for pow +// pow for at::Half +static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow for at::BFloat16 +static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow (floating, floating/int) +template +static inline __host__ __device__ typename std::enable_if_t && (std::is_same_v || std::is_same_v), Base_type> + pow_(Base_type base, Exp_type exp) { + return std::pow(base, exp); +} +// pow (Otherwise) +template +static inline __host__ __device__ typename std::enable_if_t && !std::is_same_v, Base_type> + pow_(Base_type base, Exp_type exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +#else +template +static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) { + return ::pow(base, exp); +} +#endif + +template +static inline __host__ __device__ std::enable_if_t, T> pow_( + T base, T exp) { + return at::native::powi(base, exp); +} + +template +static inline __host__ __device__ c10::complex pow_(c10::complex base, c10::complex exp) { + return c10_complex_math::pow(base, exp); +} + +} // namespace +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Randperm.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Randperm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bdcc5a576be593f9defebd95c0675a23ffc7e10f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Randperm.cuh @@ -0,0 +1,58 @@ +#include +#include +#include + +#include +#include +#include + +namespace { + +// See note [Algorithm of 
randperm] +template +__global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int n, at::PhiloxCudaState philox_args) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + + // find the beginning of islands + if (tid >= n - 1) return; // out of range + if ((keys[tid] & mask) != (keys[tid + 1] & mask)) return; // not in an island + if (tid != 0 && (keys[tid] & mask) == (keys[tid - 1] & mask)) return; // not the beginning of an island + + // find the size of islands + int island_size = 0; + do { island_size++; } + while ((tid + island_size < n) && (keys[tid + island_size] & mask) == (keys[tid] & mask)); + + // do random permutation inside each island. + data += tid; + const auto [seed, offset] = at::cuda::philox::unpack(philox_args); + curandStatePhilox4_32_10_t state; + curand_init(seed, tid, offset, &state); + for (int i = island_size - 1; i > 0; i--) { + unsigned int r = curand(&state) % (i + 1); + if (i != r) { + scalar_t tmp = data[i]; + data[i] = data[r]; + data[r] = tmp; + } + } +} + +// See note [Algorithm of randperm] +template +void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, std::optional &gen_) { + auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); + int64_t counter_offset = n; + at::PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); + } + T mask = static_cast((1UL << bits) - 1); + randperm_handle_duplicate_keys_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>( + keys, data, mask, n, rng_engine_inputs); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Reduce.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Reduce.cuh new file mode 100644 index 
0000000000000000000000000000000000000000..15a572804af5fccd8b2a181c930f6579ffdafa3f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Reduce.cuh @@ -0,0 +1,1402 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { + +static inline int64_t div_up(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +// returns floor(log2(n)) +static inline int last_pow2(int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +// returns reduced fraction numerator & denominator +C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) { + // get GCD of num and denom using Euclid's algorithm. + // Can replace this with std::gcd if we ever support c++17. + size_t a = denominator; + size_t b = numerator; + while (b != 0) { + a %= b; + // swap(a,b) + size_t tmp = a; + a = b; + b = tmp; + } + + // a is now the GCD + numerator /= a; + denominator /= a; +} + +//template for changing MAX_NUM_THREADS based on op dtype +template +struct mnt_wrapper { + static constexpr int MAX_NUM_THREADS = 512; +}; + +template <> +struct mnt_wrapper >{ + static constexpr int MAX_NUM_THREADS = 256; +}; + +constexpr int max_reduce_threads(c10::ScalarType type) { + return type == kComplexDouble ? 
256 : 512; +} + +struct ReduceConfig { + static constexpr int BLOCK_X = 0; + static constexpr int BLOCK_Y = 1; + static constexpr int CTA = 2; + + ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs) + : element_size_bytes(element_size_bytes) + , num_inputs(num_inputs) + , num_outputs(num_outputs) {} + int element_size_bytes; + int num_inputs; + int num_outputs; + int step_input = 1; + int step_output = 1; + int ctas_per_output = 1; + int input_mult[3] = {0, 0, 0}; + int output_mult[2] = {0, 0}; + + int block_width; + int block_height; + int num_threads; + + bool vectorize_input = false; + int output_vec_size = 1; + + template + void set_block_dimension(int64_t dim0, int64_t dim1) { + const int max_num_threads = mnt_wrapper::MAX_NUM_THREADS / output_vec_size; + int dim0_pow2 = dim0 < max_num_threads ? static_cast(last_pow2(dim0)) : max_num_threads; + int dim1_pow2 = dim1 < max_num_threads ? static_cast(last_pow2(dim1)) : max_num_threads; + block_width = std::min(dim0_pow2, int(at::cuda::warp_size())); + block_height = std::min(dim1_pow2, int(max_num_threads / block_width)); + block_width = std::min(dim0_pow2, int(max_num_threads / block_height)); + num_threads = block_width * block_height; + } + + int split_input(int parallelism) { + int step = step_input; + step_input *= parallelism; + return step; + } + + int split_output(int parallelism) { + int step = step_output; + step_output *= parallelism; + return step; + } + + dim3 block() const { + return dim3(block_width, block_height); + } + + dim3 grid() const { + return dim3(div_up(num_outputs / output_vec_size, step_output), ctas_per_output); + } + + C10_HOST_DEVICE bool should_block_x_reduce() const { + return input_mult[BLOCK_X] != 0; + } + + C10_HOST_DEVICE bool should_block_y_reduce() const { + return input_mult[BLOCK_Y] != 0; + } + + C10_HOST_DEVICE bool should_global_reduce() const { + return input_mult[CTA] != 0; + } + + C10_DEVICE bool should_store(int output_idx) const { + return output_idx 
< num_outputs && + (!should_block_x_reduce() || threadIdx.x == 0) && + (!should_block_y_reduce() || threadIdx.y == 0); + } + + C10_DEVICE bool should_reduce_tail() const { + return (!should_block_y_reduce() || threadIdx.y == 0) && + (!should_global_reduce() || blockIdx.y == 0); + } + + C10_HOST_DEVICE int input_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta2 = blockIdx.y; + return (lane * input_mult[BLOCK_X] + + warp * input_mult[BLOCK_Y] + + cta2 * input_mult[CTA]); + } + + template + C10_HOST_DEVICE int output_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta1 = blockIdx.x; + return (lane * output_mult[BLOCK_X] + + warp * output_mult[BLOCK_Y] + + cta1 * step_output) * output_vec_size; + } + + C10_DEVICE int shared_memory_offset(int offset) const { + return threadIdx.x + (threadIdx.y + offset) * blockDim.x; + } + + C10_DEVICE int staging_memory_offset(int cta2) const { + int offset = cta2 + blockIdx.x * gridDim.y; + if (!should_block_x_reduce()) { + offset = threadIdx.x + offset * blockDim.x; + } + return offset; + } + + int shared_memory_size() const { + if (!should_block_y_reduce() && + (!should_block_x_reduce() || + block_width <= at::cuda::warp_size())) { + return 0; + } + return element_size_bytes * num_threads * output_vec_size; + } + + int64_t global_memory_size() const { + if (!should_global_reduce()) { + return 0; + } + auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output; + if (!should_block_x_reduce()) { + size *= block().x * output_vec_size; + } + return size; + } + + int semaphore_size() const { + if (!should_global_reduce()) { + return 0; + } + return sizeof(int) * grid().x; + } + + int values_per_thread() const { + return div_up(num_inputs, step_input); + } +}; + +std::ostream& operator<<(std::ostream& out, const ReduceConfig& config); + +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void reduce_kernel(R reduction) { + reduction.template run(); +} + +template +static 
OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) { + int num_reduce_dims = iter.num_reduce_dims(); + int num_output_dims = iter.ndim() - num_reduce_dims; + int input_index = iter.ntensors() - 1; + int output_index = 0; + std::array strides = { + iter.strides(output_index).data() + num_reduce_dims, + iter.strides(input_index).data() + num_reduce_dims, + }; + auto shape = iter.shape().data() + num_reduce_dims; + return OffsetCalculator<2, index_t>(num_output_dims, shape, strides.data()); +} + +template +static OffsetCalculator<1, index_t> make_input_calculator(const TensorIterator& iter) { + int num_reduce_dims = iter.num_reduce_dims(); + int input_index = iter.ntensors() - 1; + std::array strides = { + iter.strides(input_index).data(), + }; + return OffsetCalculator<1, index_t>(num_reduce_dims, iter.shape().data(), strides.data()); +} + +template +struct func_wrapper_t { + using arg_t = typename binary_function_traits::arg1_t; + using scalar_t = typename binary_function_traits::arg2_t; + + func_t combine; + static inline __device__ out_scalar_t project(arg_t arg) { + return (out_scalar_t) arg; + } + static inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) { + return WARP_SHFL_DOWN(arg, offset); + } + + static __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) { + return acc; + } + + func_wrapper_t(const func_t& op) : combine(op) { + } + + // wrap a normal reduction that ignores the index + __device__ arg_t reduce(arg_t acc, scalar_t val, int64_t idx) const { + return combine(acc, val); + } +}; + +template +func_wrapper_t func_wrapper(const func_t& op) { + return func_wrapper_t { op }; +} + +template +struct ReduceJitOp { +//ReduceJitOp is almost like ReduceOp, but it doesn't have ops functor that specifies reduction operations +//Maybe we can find a way to unify ReduceOp and ReduceJitOp + using InputCalculator = OffsetCalculator<1, uint32_t>; + using OutputCalculator = OffsetCalculator<2, uint32_t>; + //TODO for now 
arg_t is always opmath_t of the input, later we'll need to change it + using arg_t = at::opmath_type; + + //TODO - ReduceJitOp will probably need to be changed for reductions that need full functor, + //not just wrapper + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + ReduceJitOp( + ReduceConfig config, + InputCalculator input_calc, + OutputCalculator output_calc, + const void* src, + char* dst0, + std::optional dst1, + void* acc_buf, + void* cta_buf, + int* semaphores, + arg_t ident, + int noutputs, + int64_t base_idx) + : ident(ident), + config(config), + input_calc(input_calc), + output_calc(output_calc), + src(src), + acc_buf(acc_buf), + cta_buf(cta_buf), + semaphores(semaphores), + base_idx(base_idx), + noutputs(noutputs) { + dst[0] = dst0; + if (dst1.has_value()) { + dst[1] = dst1.value(); + } + } +}; + +template +struct ReduceOp { + using traits = function_traits; + using arg_t = typename std::decay::type>::type; + + using InputCalculator = OffsetCalculator<1, index_t>; + using OutputCalculator = OffsetCalculator<2, index_t>; + + static constexpr bool can_accumulate_in_output = + std::is_convertible_v + && std::is_convertible_v; + + ops_t ops; + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global 
reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + ReduceOp( + ops_t ops, + ReduceConfig config, + InputCalculator input_calc, + OutputCalculator output_calc, + const void* src, + char* dst0, + std::optional dst1, + void* acc_buf, + void* cta_buf, + int* semaphores, + arg_t ident, + int noutputs, + int64_t base_idx) + : ops(ops), + ident(ident), + config(config), + input_calc(input_calc), + output_calc(output_calc), + src(src), + acc_buf(acc_buf), + cta_buf(cta_buf), + semaphores(semaphores), + base_idx(base_idx), + noutputs(noutputs) { + dst[0] = dst0; + if (dst1.has_value()) { + dst[1] = dst1.value(); + } + } + + template + C10_DEVICE void run() const { + extern __shared__ char shared_memory[]; + index_t output_idx = config.output_idx(); + index_t input_idx = config.input_idx(); + auto base_offsets1 = output_calc.get(output_idx)[1]; + + using arg_vec_t = std::array; + arg_vec_t value; + + if (output_idx < config.num_outputs && input_idx < config.num_inputs) { + const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1); + value = thread_reduce(input_slice); + } + + if (config.should_block_y_reduce()) { + value = block_y_reduce(value, shared_memory); + } + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + + using out_ptr_vec_t = std::array; + using offset_vec_t = std::array; + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + arg_vec_t* acc = nullptr; + if (acc_buf != nullptr) { + size_t numerator = sizeof(arg_t); + size_t denominator = sizeof(out_scalar_t); + reduce_fraction(numerator, denominator); + acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator)); + } + + if (config.should_global_reduce()) { + value = 
global_reduce(value, acc, shared_memory); + } else if (config.should_store(output_idx)) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + template + C10_DEVICE std::array thread_reduce(const scalar_t* data) const { + if (config.vectorize_input) { + CUDA_KERNEL_ASSERT(output_vec_size == 1); + // reduce at the header of input_slice where memory is not aligned, + // so that thread_reduce will have an aligned memory to work on. 
+ return {input_vectorized_thread_reduce_impl(data)}; + } else { + index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t); + bool is_contiguous = (input_calc.dims == 1 && element_stride == 1); + if (is_contiguous) { + return thread_reduce_impl(data, [](index_t idx) { return idx; }); + } else if (input_calc.dims == 1) { + return thread_reduce_impl(data, [&](index_t idx) { return idx * element_stride; }); + } else { + return thread_reduce_impl(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); }); + } + } + } + + C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const { + index_t end = config.num_inputs; + + // Handle the head of input slice where data is not aligned + arg_t value = ident; + constexpr int align_bytes = alignof(at::native::memory::aligned_vector); + constexpr int align_elements = align_bytes / sizeof(scalar_t); + int shift = ((uint64_t)data) % align_bytes / sizeof(scalar_t); + if (shift > 0) { + data -= shift; + end += shift; + if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){ + value = ops.reduce(value, c10::load(data + threadIdx.x), threadIdx.x - shift); + } + end -= align_elements; + data += align_elements; + shift = align_elements - shift; + } + + // Do the vectorized reduction + using load_t = at::native::memory::aligned_vector; + + index_t idx = config.input_idx(); + const index_t stride = config.step_input; + + // Multiple accumulators to remove dependency between unrolled loops. 
+ arg_t value_list[input_vec_size]; + value_list[0] = value; + + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[i] = ident; + } + + while (idx * input_vec_size + input_vec_size - 1 < end) { + const auto values_vec = memory::load_vector(data, idx); + #pragma unroll + for (index_t i = 0; i < input_vec_size; i++) { + value_list[i] = ops.reduce(value_list[i], values_vec.val[i], shift + idx * input_vec_size + i); + } + idx += stride; + } + + // tail + index_t tail_start = end - end % input_vec_size; + if (config.should_reduce_tail()) { + int idx = tail_start + threadIdx.x; + if (idx < end) { + const auto value = c10::load(data + idx); + value_list[0] = ops.reduce(value_list[0], value, idx + shift); + } + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[0] = ops.combine(value_list[0], value_list[i]); + } + return value_list[0]; + } + + template + C10_DEVICE std::array thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const { + index_t idx = config.input_idx(); + const index_t end = config.num_inputs; + const index_t stride = config.step_input; + + using arg_vec_t = std::array; + using load_t = at::native::memory::aligned_vector; + + // Multiple accumulators to remove dependency between unrolled loops. 
+ arg_vec_t value_list[vt0]; + + #pragma unroll + for (int i = 0; i < vt0; i++) { + #pragma unroll + for (int j = 0; j < output_vec_size; j++) { + value_list[i][j] = ident; + } + } + + load_t values[vt0]; + + while (idx + (vt0 - 1) * stride < end) { + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + const auto offset = calc(idx + i * stride) / output_vec_size; + values[i] = memory::load_vector(data_, offset); + } + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + #pragma unroll + for (index_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride); + } + } + idx += stride * vt0; + } + + // tail + int idx_ = idx; + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + const auto offset = calc(idx) / output_vec_size; + values[i] = memory::load_vector(data_, offset); + idx += stride; + } + idx = idx_; + #pragma unroll + for (index_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + #pragma unroll + for (index_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx); + } + idx += stride; + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < vt0; i++) { + #pragma unroll + for (index_t j = 0; j < output_vec_size; j++) { + value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]); + } + } + return value_list[0]; + } + + template + C10_DEVICE std::array block_x_reduce(std::array value, char* shared_memory) const { + using args_vec_t = std::array; + int dim_x = blockDim.x; + args_vec_t* shared = (args_vec_t*)shared_memory; + if (dim_x > warpSize) { + int address_base = threadIdx.x + threadIdx.y*blockDim.x; + shared[address_base] = value; + for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) { + __syncthreads(); + if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) { + args_vec_t other = shared[address_base + offset]; + #pragma unroll + for (int i = 0; i < 
output_vec_size; i++) { + value[i] = ops.combine(value[i], other[i]); + } + shared[address_base] = value; + } + } + dim_x = warpSize; + } + + __syncthreads(); + + for (int offset = 1; offset < dim_x; offset <<= 1) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + arg_t other = ops.warp_shfl_down(value[i], offset); + value[i] = ops.combine(value[i], other); + } + } + return value; + } + + template + C10_DEVICE std::array block_y_reduce(std::array value, char* shared_memory) const { + using args_vec_t = std::array; + args_vec_t* shared = (args_vec_t*)shared_memory; + shared[config.shared_memory_offset(0)] = value; + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + args_vec_t other = shared[config.shared_memory_offset(offset)]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], other[i]); + } + shared[config.shared_memory_offset(0)] = value; + } + } + return value; + } + + C10_DEVICE bool mark_block_finished() const { + __shared__ bool is_last_block_done_shared; + + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0) { + int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1); + } + + __syncthreads(); + + return is_last_block_done_shared; + } + + template + C10_DEVICE std::array accumulate_in_output( + std::array out, + std::array value, + typename std::enable_if_t* = nullptr + ) const { + std::array ret; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + ret[i] = ops.combine(*(out[i]), value[i]); + } + return ret; + } + + template + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value, + typename std::enable_if_t* = nullptr + ) const { + CUDA_KERNEL_ASSERT(!final_output); + return (out_scalar_t)value; + } + + // This function should never be called -- + // it's the version of 
`accumulate_in_output` + // when accumulation in the output is not possible. + template + C10_DEVICE std::array accumulate_in_output( + std::array, + std::array, + typename std::enable_if_t* = nullptr + ) const { + CUDA_KERNEL_ASSERT(false); + return {arg_t{}}; + } + + // This function should never be called -- + // it's the version of `get_accumulated_output` + // when accumulation in the output is not possible. + template + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value, + typename std::enable_if_t* = nullptr + ) const { + CUDA_KERNEL_ASSERT(false); + return *out; + } + + template + C10_DEVICE void set_results(const T x, const index_t base_offset) const { + CUDA_KERNEL_ASSERT(noutputs == 1); + auto res = (out_scalar_t*)((char*)dst[0] + base_offset); + *res = x; + } + + //Currently implemented for max of two outputs + template + C10_DEVICE void set_results(const thrust::pair x, const index_t base_offset) const { + if (noutputs >= 1) { + auto res0 = (T1*)((char*)dst[0] + base_offset); + *res0 = x.first; + } + if (noutputs >= 2) { + // base offset is computed assuming element size being sizeof(T1), so we need to make a + // correction to obtain the correct base offset + auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2)); + *res1 = x.second; + } + } + + template + C10_DEVICE void set_results_to_output(std::array value, std::array base_offset) const { + CUDA_KERNEL_ASSERT(final_output); + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + set_results(ops.project(value[i]), base_offset[i]); + } + } + + template + C10_DEVICE std::array global_reduce(std::array value, std::array *acc, char* shared_memory) const { + using arg_vec_t = std::array; + using out_ptr_vec_t = std::array; + using offset_vec_t = std::array; + + arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf; + index_t output_idx = config.output_idx(); + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < 
output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + bool should_store = config.should_store(output_idx); + if (should_store) { + index_t offset = config.staging_memory_offset(blockIdx.y); + reduce_buffer[offset] = value; + } + + __threadfence(); // make sure writes are globally visible + __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done + bool is_last_block_done = mark_block_finished(); + + if (is_last_block_done) { + __threadfence(); // complete the acquire pattern after atomic + for (auto &v : value) { + v = ident; + } + if (config.should_block_x_reduce()) { + index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; + index_t step = blockDim.x * blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + index_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], next[i]); + } + } + } else { + index_t input_offset = threadIdx.y; + index_t step = blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + index_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine(value[i], next[i]); + } + } + } + value = block_y_reduce(value, shared_memory); + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + if (should_store) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < 
output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = ops.combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + return value; + } +}; + +template +static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) { + dim3 block = config.block(); + dim3 grid = config.grid(); + + auto stream = at::cuda::getCurrentCUDAStream(); + int shared_memory = config.shared_memory_size(); + + switch(config.output_vec_size) { + case 4: + reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + default: + reduce_kernel<<>>(reduction); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +inline void launch_jitted_reduce_kernel( + std::mutex &jiterator_mutex, + std::array &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, + int vt0, const ReduceConfig& config, const void *reduction) { + dim3 block = config.block(); + dim3 grid = config.grid(); + + int shared_memory = config.shared_memory_size(); + at::cuda::jit::NvrtcFunction* fn_ptr; + switch(config.output_vec_size) { + case 4: + fn_ptr = &fn_cache[0]; + break; + case 2: + fn_ptr = &fn_cache[1]; + break; + default: + fn_ptr = &fn_cache[2]; + } + if (!fn_ptr->function) { + int max_threads_codegen = + max_reduce_threads(desc.f_inputs_type) / config.output_vec_size; + auto code = at::cuda::jit::generate_reduction_code( + desc, vt0, true, false, config.output_vec_size, max_threads_codegen); + + *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_" + desc.name); + } + constexpr int kernel_args = 1; + const void* args[kernel_args]; + args[0] = reduction; + at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory); +} + + +class AccumulationBuffer { + 
public: + AccumulationBuffer() {} + + AccumulationBuffer(size_t acc_t_size, size_t out_t_size, char* out_ptr, int64_t size) { + out_ptr_ = (char*)out_ptr; + if (out_t_size >= acc_t_size) { + // reusing output buffer for accumulation. + acc_ptr_ = (char*)out_ptr; + numerator_ = 1; + denominator_ = 1; + } else { + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + buffer_ = allocator.allocate(size); + acc_ptr_ = (char*)buffer_.get(); + numerator_ = acc_t_size; + denominator_ = out_t_size; + reduce_fraction(numerator_, denominator_); + } + } + + char* get_acc_slice(char* out_ptr) { + if (acc_ptr_ == nullptr) { + return nullptr; + } + return acc_ptr_ + ((out_ptr - out_ptr_) * numerator_ / denominator_); + } + + private: + char* acc_ptr_ = nullptr; + char* out_ptr_ = nullptr; + size_t numerator_; + size_t denominator_; + at::DataPtr buffer_; +}; + +template +int get_output_vec_size(const TensorIterator &iter) { + int vec_size = 4; + auto update_vec_size = [&vec_size](uint64_t n) { + while(n % vec_size != 0) { + vec_size /= 2; + } + }; + + uint64_t base_address = reinterpret_cast(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t); + update_vec_size(base_address); + + const int output_index = iter.num_reduce_dims(); + update_vec_size(iter.shape()[output_index]); + + int j = 0; + for(auto i : iter.strides(iter.noutputs())) { + if (j != output_index) { + update_vec_size(i / sizeof(scalar_t)); + } + j++; + } + return vec_size; +} + +template +ReduceConfig setReduceConfig(const TensorIterator& iter){ + // Start by assuming that each thread handles a single output and all + // the inputs for that output. 
+ int64_t num_outputs = iter.num_output_elements(); + int64_t inputs_per_output = iter.numel() / num_outputs; + int input_index = iter.ntensors() - 1; + + auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output); + + int64_t dim0; + int64_t dim1; + int64_t fastest_moving_stride; + bool reduction_on_fastest_striding_dimension; + + if (iter.ndim() > 0) { + // Adjust block size to map block width to fastest changing dimension of input + // tensor. This grants the best possible memory accessing pattern, given that + // for non-contiguous tensor with space in between, we cannot have perfect + // memory coalescing. + reduction_on_fastest_striding_dimension = + (iter.num_reduce_dims() == iter.ndim()) || + (iter.strides(/*arg=*/input_index)[0] < + iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]); + // Notice that dim0 & dim1 does NOT guarantee any launch configuration here! + // dim0 & dim1 are more like the upper bound of the block dimension. The + // actual launch config and reduction scheme is determined by setting values + // to `config.input_mult` and `config.output_mult`. + // We try to max out dim1 so that we have enough threads per CTA to deliver + // performance for larger problem size. + if (reduction_on_fastest_striding_dimension) { + // Map block.x to the fastest reducing dimension. It implies: + // 1. block_x_reduce is required. + // 2. block.y now max out to num_outputs. + dim0 = inputs_per_output; + dim1 = num_outputs; + fastest_moving_stride = iter.strides(/*arg=*/input_index)[0]; + } else { + // Map block.x to the fastest non reducing dimension. It implies: + // 1. block_x_reduce is turned off. + // 2. block.y now max out to inputs_per_output. 
+ dim0 = num_outputs; + dim1 = inputs_per_output; + fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]; + } + } else { + reduction_on_fastest_striding_dimension = true; + fastest_moving_stride = sizeof(scalar_t); + dim0 = 1; + dim1 = 1; + } + + // We do vectorization to gain better memory access, there are two cases which we call + // "vectorize along input" and "vectorize along output". Note that the "input/output" + // here does not mean we are vectorizing load/store instructions. We always only vectorize + // load instructions. + // + // Case 1: "vectorize along input" + // This case happens when we are reducing along fastest moving dimesion. In such case, threads + // with the same threadIdx.y works on the same reduction cooperatively and will produce results + // for the same output. In such case, values in each loaded vector always correspond to the same output. + // + // Case 2: "vectorize along output" + // This case happens when the fastest moving dimesion is not the dimension of reduction. In such case, + // threads with different threadIdx.x are independent and will produce results for different outputs. + // In such case, values in each loaded vector always correspond to different outputs. + if (fastest_moving_stride == sizeof(scalar_t)) { +#ifdef USE_ROCM + if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1) { +#else + if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) { +#endif + // Case 1: "vectorize along input" + // Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case, + // we should avoid vectorization. 
+ config.vectorize_input = true; + dim0 /= input_vec_size; + } else if (!reduction_on_fastest_striding_dimension) { + // Case 2: "vectorize along output" + config.output_vec_size = get_output_vec_size(iter); + dim0 /= config.output_vec_size; + } + } + + // Adjust block_width and block_height + config.set_block_dimension(dim0, dim1); + + int block_width = config.block_width; + int block_height = config.block_height; + + if (iter.ndim() == 0 || reduction_on_fastest_striding_dimension) { + // Split the input across lanes if the input is contiguous in the reduced + // dimension. This will require reduction between threads using warp + // shuffle instructions and shared memory (if block_width > warpSize). + config.input_mult[0] = config.split_input(block_width); + } else { + // Otherwise split the output across lanes in a warp. + config.output_mult[0] = config.split_output(block_width); + } + + constexpr int min_values_per_thread = 16; + constexpr int max_values_per_thread = 256; + + const int warp_split_threshold = + std::min(block_height * 16, max_values_per_thread); + bool split_across_warps = config.values_per_thread() >= warp_split_threshold; + const int num_mp = + at::cuda::getCurrentDeviceProperties()->multiProcessorCount; +#ifdef USE_ROCM + bool force_splitting_output = iter.ndim() == 2 && + reduction_on_fastest_striding_dimension && + config.values_per_thread() < 1024 && num_mp < 100; + split_across_warps = !force_splitting_output && split_across_warps; +#endif + + if (split_across_warps) { + // Divide the input across warps in a thread-block, if that leaves at least + // 16 elements to be summed by each thread. This will require inter-warp + // reduction using shared memory. + config.input_mult[1] = config.split_input(block_height); + } else { + // Otherwise, each warp handles a separate output. 
+ config.output_mult[1] = config.split_output(block_height); + } + + int max_threads_per_mp = + at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor; +#ifdef USE_ROCM + // If the grid consists of a single threadblock, do not change the max threads per + // MP value. This will increase the parallelism across the y dimension of the grid. + bool uses_a_single_block = config.grid().x == config.grid().y == config.grid().z == 1; + + if (!uses_a_single_block) { + // Control the number of threadblocks by adjusting the maximum number of + // threads per multi-processor. These numbers better reflect the maximum + // theoretical achievable threads per MP for the reduction operation. + if (iter.ndim() == 1 || iter.ndim() == 3) + max_threads_per_mp = 512; + else if (iter.ndim() == 2) + max_threads_per_mp = 256; + } +#endif + const int blocks_per_sm = max_threads_per_mp / config.num_threads; + const int target_grid_size = num_mp * blocks_per_sm; + int grid = config.grid().x; + if (config.input_mult[1] != 0 && config.values_per_thread() >= max_values_per_thread && grid <= target_grid_size) { + // Divide the input across thread-blocks if the amount of work per-thread + // is large enough and the size of the output is small enough. This will + // require a reduction using global memory. + // If we decide to split input across blocks, as long as we can get enough + // number of blocks (`target_grid_size`) to balance SM, we should still + // make the number of values per thread large for best performance. + int ctas_per_output1 = div_up(target_grid_size, grid); + int ctas_per_output2 = div_up(config.values_per_thread(), min_values_per_thread); + int ctas_per_output3 = div_up(config.values_per_thread(), max_values_per_thread); + // We want the minimum of ctas_per_output1 and ctas_per_output2, so that each thread can have + // a large number of values to deal with. 
But we don't want values_per_thread to be larger than + // max_values_per_thread + config.ctas_per_output = std::max(std::min(ctas_per_output1, ctas_per_output2), ctas_per_output3); +#ifdef USE_ROCM + // In cases where a number of threadblocks along the y direction of the grid + // is needed then make sure they are reduced to the number of MPs. For + // smaller sizes, use half the number of MPs. For smaller sizes than half + // the number of MPs use the original value unless the value is less than 16 + // blocks in which case it is more profitable to use just 1 block. + if (config.ctas_per_output > num_mp) + if (num_mp < 128) + config.ctas_per_output = + num_mp * (config.ctas_per_output > 512 ? 4 : 2); + else + config.ctas_per_output = num_mp; + else if (config.ctas_per_output > div_up(num_mp, 2)) + config.ctas_per_output = div_up(num_mp, 2); + else if (config.ctas_per_output < 16) + config.ctas_per_output = 1; + bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast); + if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) + config.ctas_per_output = 4; +#endif + if (config.ctas_per_output > 1) { + config.input_mult[2] = config.split_input(config.ctas_per_output); + } + } + return config; +}; + +template +inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0, + AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) { + AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1); + + using traits = function_traits; + using arg_t = typename traits::template arg<0>::type; + // at::Half/at::ComplexHalf overflows easily as it's range is very small. + // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we + // set can_accumulate_in_output to False. 
+ static constexpr bool is_inp_out_type_half_or_chalf = + (std::is_same_v && + std::is_same_v) || + (std::is_same_v, scalar_t> && + std::is_same_v, out_scalar_t>); + // at::BFloat16 has lower precision and can lead to rounding errors. + // So when scalar_t and out_scalar_t are at::BFloat16, we + // set can_accumulate_in_output to False. + static constexpr bool is_inp_out_type_bfloat16 = + (std::is_same_v && + std::is_same_v); + static constexpr bool can_accumulate_in_output = + std::is_convertible_v && + !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16); + + bool can_use_32bit_indexing = iter.can_use_32bit_indexing(); + std::unique_ptr owned_buf_ptr; + // The acc_buf_ptr is a shared pointer. It is create at the first entrance and + // reused by all recursive function calls. + if (acc_buf_ptr == NULL) { + // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter + // when accumulation in output is not possible. + if (!can_accumulate_in_output && !can_use_32bit_indexing) { + int64_t output_memory_size = iter.element_size(0); + for (int dim = 0; dim < iter.ndim(); dim++) { + output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]); + } + output_memory_size /= iter.element_size(0); //iter.strides is in bytes + owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t), + sizeof(out_scalar_t), + (char*) iter.data_ptr(0), + output_memory_size * sizeof(arg_t))); + } else { + owned_buf_ptr.reset(new AccumulationBuffer()); + } + acc_buf_ptr = owned_buf_ptr.get(); + } + + if (!can_use_32bit_indexing) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + int64_t sub_iter_base_idx = sub_iter.view_offsets()[0]; + + gpu_reduce_kernel(sub_iter, ops, ident, + acc_buf_ptr, sub_iter_base_idx); + } + return; + } + + const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1); + char* out_data = (char*)iter.data_ptr(0); + const auto noutputs = iter.noutputs(); + std::optional out_data_extra; + if (noutputs > 1) { + 
out_data_extra = (char*)iter.data_ptr(1); + } else { + out_data_extra = std::nullopt; + } + char* acc_data = acc_buf_ptr->get_acc_slice(out_data); + + ReduceConfig config = setReduceConfig(iter); + at::DataPtr buffer; + at::DataPtr semaphores; + if (config.should_global_reduce()) { + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + buffer = allocator.allocate(config.global_memory_size()); + semaphores = allocator.allocate(config.semaphore_size()); + + auto stream = at::cuda::getCurrentCUDAStream(); + AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream)); + } + + AT_ASSERT(can_use_32bit_indexing); + auto output_calc = make_output_calculator(iter); + auto input_calc = make_input_calculator(iter); + auto reduce = ReduceOp( + ops, + config, + input_calc, + output_calc, + in_data, + out_data, + out_data_extra, + acc_data, + buffer.get(), + (int*)semaphores.get(), + ident, + noutputs, + base_idx); + reduce.accumulate = iter.should_accumulate(); + reduce.final_output = iter.is_final_output(); + + launch_reduce_kernel::MAX_NUM_THREADS>(config, reduce); +} + +//TODO this is 100 lines of almost-copy-paste, because we have to have different template args for this function +//try unifying with gpu_reduce_kernel +template +inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& func, ident_t ident=0, + AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) { + AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1); + + //TODO - this will be different for more complicated reductions, but for now reductions using + //func_wrapper all have arg_t = opmath + using arg_t = at::opmath_type; + // at::Half/at::ComplexHalf overflows easily as it's range is very small. + // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we + // set can_accumulate_in_output to False. 
+ static constexpr bool is_inp_out_type_half_or_chalf = + (std::is_same_v && + std::is_same_v ) || + (std::is_same_v, scalar_t> && + std::is_same_v, out_scalar_t>); + // at::BFloat16 has lower precision and can lead to rounding errors. + // So when scalar_t and out_scalar_t are at::BFloat16, we + // set can_accumulate_in_output to False. + static constexpr bool is_inp_out_type_bfloat16 = + (std::is_same_v && + std::is_same_v); + static constexpr bool can_accumulate_in_output = + std::is_convertible_v && + !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16); + + bool can_use_32bit_indexing = iter.can_use_32bit_indexing(); + std::unique_ptr owned_buf_ptr; + + // The acc_buf_ptr is a shared pointer. It is create at the first entrance and + // reused by all recursive function calls. + if (acc_buf_ptr == NULL) { + // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter + // when accumulation in output is not possible. + if (!can_accumulate_in_output && !can_use_32bit_indexing) { + int64_t output_memory_size = iter.element_size(0); + for (int dim = 0; dim < iter.ndim(); dim++) { + output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]); + } + output_memory_size /= iter.element_size(0); //iter.strides is in bytes + owned_buf_ptr.reset(new AccumulationBuffer(sizeof(out_scalar_t), //TODO + sizeof(out_scalar_t), + (char*) iter.data_ptr(0), + output_memory_size * sizeof(out_scalar_t))); //TODO + } else { + owned_buf_ptr.reset(new AccumulationBuffer()); + } + acc_buf_ptr = owned_buf_ptr.get(); + } + + if (!can_use_32bit_indexing) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + int64_t sub_iter_base_idx = sub_iter.view_offsets()[0]; + + jitted_gpu_reduce_kernel(sub_iter, func, ident, + acc_buf_ptr, sub_iter_base_idx); + } + return; + } + + //TODO - for now we support a single input, we may be able to relax this constraint + const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1); + char* out_data = 
(char*)iter.data_ptr(0); + const auto noutputs = iter.noutputs(); + std::optional out_data_extra; + if (noutputs > 1) { + out_data_extra = (char*)iter.data_ptr(1); + } else { + out_data_extra = std::nullopt; + } + char* acc_data = acc_buf_ptr->get_acc_slice(out_data); + + ReduceConfig config = setReduceConfig(iter); + + at::DataPtr buffer; + at::DataPtr semaphores; + if (config.should_global_reduce()) { + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + buffer = allocator.allocate(config.global_memory_size()); + semaphores = allocator.allocate(config.semaphore_size()); + + auto stream = at::cuda::getCurrentCUDAStream(); + AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream)); + } + + AT_ASSERT(can_use_32bit_indexing); + auto output_calc = make_output_calculator(iter); + auto input_calc = make_input_calculator(iter); + auto reduce = ReduceJitOp( + config, + input_calc, + output_calc, + in_data, + out_data, + out_data_extra, + acc_data, + buffer.get(), + (int*)semaphores.get(), + ident, + noutputs, + base_idx); + reduce.accumulate = iter.should_accumulate(); + reduce.final_output = iter.is_final_output(); + + constexpr int nInputs = 1; + constexpr int nOutputs = 1; + static auto desc = at::cuda::jit::make_kernel_descriptor< + out_scalar_t, scalar_t>(name, func, nInputs, nOutputs); + + static std::mutex jiterator_mutex; + static std::vector> fn_cache(c10::cuda::device_count()); + auto &cache = fn_cache[iter.device().index()]; + + launch_jitted_reduce_kernel( + jiterator_mutex, cache, desc, vt0, config, &reduce); +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ReduceOps.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ReduceOps.h new file mode 100644 index 0000000000000000000000000000000000000000..99c66812982e8e1b39c6807051d053a617aad170 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ReduceOps.h @@ -0,0 
+1,20 @@ + +namespace at { +struct TensorIterator; +} + +namespace c10 { +class Scalar; +} + +namespace at::native { + +void norm_launch_kernel(TensorIterator &iter, double val); +void min_launch_kernel(TensorIterator &iter); +void max_launch_kernel(TensorIterator &iter); +void aminmax_launch_kernel(TensorIterator &iter); +void min_all_launch_kernel(TensorIterator &iter); +void max_all_launch_kernel(TensorIterator &iter); +void aminmax_allreduce_launch_kernel(TensorIterator &iter); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Resize.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Resize.h new file mode 100644 index 0000000000000000000000000000000000000000..b2c3efe5a719257a3303d8f07bb1a3b612c93dbf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Resize.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include + +namespace at::native { + +TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes); + +static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) { + // It does not make sense to try to resize a storage + // to hold 0 elements, and this can break + // if storage_offset is positive but + // new_size is 0, so just bail in that case + // (same comment is in Resize.h) + if (self->numel() == 0) { + return; + } + + const Storage &storage = self->unsafe_storage(); + TORCH_CHECK(storage, "Tensor: invalid null storage"); + if (new_size_bytes > storage.nbytes()) { + resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes); + } +} + +inline TensorImpl* resize_impl_cuda_( + TensorImpl* self, + IntArrayRef size, + at::OptionalIntArrayRef stride) { + if (self->sizes() == size && (!stride || self->strides() == stride)) { + return self; + } + const auto itemsize = self->dtype().itemsize(); + const auto storage_offset = self->storage_offset(); + size_t storage_size = 1; + if (stride) { 
+ self->set_sizes_and_strides(size, *stride); + storage_size = at::detail::computeStorageNbytes( + size, *stride, itemsize, storage_offset); + } else { + self->set_sizes_contiguous(size); + storage_size = at::detail::computeStorageNbytesContiguous( + size, itemsize, storage_offset); + } + maybe_resize_storage_cuda(self, storage_size); + + return self; +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h new file mode 100644 index 0000000000000000000000000000000000000000..533a702f301e89f41953f57e3ba2e4766ce52dd9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include + +namespace at::cuda::detail { +TORCH_API void f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + std::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out); +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ScaledGroupMM.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ScaledGroupMM.h new file mode 100644 index 0000000000000000000000000000000000000000..380851df538b41760a7ce3d7fd72fbe314d8269f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ScaledGroupMM.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include + +namespace at::cuda::detail { +TORCH_API void f8f8bf16_grouped_mm( + at::Tensor mat_a, // FP8 + at::Tensor mat_b, // FP8 + at::Tensor scale_a, // FP32 + at::Tensor scale_b, // FP32 + std::optional offs, + std::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out); +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ScanKernels.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ScanKernels.h new file mode 100644 index 0000000000000000000000000000000000000000..28e65372511bc7b50390a134e01554b6fa9ee171 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ScanKernels.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +class TensorBase; + +namespace native { + +// NOTE: these functions require output tensors to be contiguous +void launch_cummax_cuda_kernel(const TensorBase& self, const TensorBase& values, + const TensorBase& indices, int64_t dim); +void launch_cummin_cuda_kernel(const TensorBase& self, const TensorBase& values, + const TensorBase& indices, int64_t dim); +void launch_logcumsumexp_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim); +void launch_cumsum_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim); +void launch_cumprod_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim); + +}} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c4d86acb43e7bb2652374431995762064a9658a2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh @@ -0,0 +1,475 @@ +#pragma once +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native { + +template +constexpr inline integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) { + integer log_num_threads_x = 0; + integer log_num_threads_y = 0; + while (((integer)1 << log_num_threads_x) < row_size) { + ++log_num_threads_x; + } + while (((integer)1 << log_num_threads_y) < num_rows) { + ++log_num_threads_y; + } + // 
we want to keep the ratio between the x-threads and y-threads about the same as + // the ratio between the row_size and num_rows, but the total number of threads in + // a block should be about 512 + integer diff = log_num_threads_x - log_num_threads_y; + // 9 is from log2(512) + log_num_threads_x = ((integer)9 + diff) / (integer)2; + // I found that in having larger log_num_threads_x can give significant speed up in some cases, + // but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it + // similar to the previous implementation + // Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block. + log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9); + return log_num_threads_x; +} + +template +__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) { + if(!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) { + rhs = lhs; + rhs_idx = lhs_idx; + } +} +/* Perform an inclusive scan along the innermost dimension of a tensor. + * + * - num_rows is the size of the flattened outer dimensions; + * - row_size is the size of the innermost dimension; + * + * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is + * considered as having 'num_rows' rows of size 'row_size'. + * Each thread block processes one or more sets of contiguous rows (processing multiple rows + * per thread block is quicker than processing a single row, especially for short rows). 
+ */ +template +__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_, + int num_rows, int row_size, + const uint32_t num_threads, const uint32_t log_num_threads_x, + scalar_t init, BinaryFunction binary_op) { + // dynamic memory allocation for vbuf and ibuf + alignas(sizeof(double)) extern __shared__ char buf[]; + scalar_t* vbuf = reinterpret_cast(buf); // the size is num_threads * 2 + int64_t* ibuf = reinterpret_cast(vbuf + num_threads * 2); + const uint32_t num_threads_x = 1 << log_num_threads_x; + scalar_t* row_buf = vbuf + 2 * num_threads_x * threadIdx.y; + int64_t* row_idx_buf = ibuf + 2 * num_threads_x * threadIdx.y; + + for (int block_row = blockIdx.x * blockDim.y; + block_row < num_rows; + block_row += blockDim.y * gridDim.x) { + int row = block_row + threadIdx.y; + const scalar_t *row_self = self_ + row * row_size; + scalar_t *row_values = values_ + row * row_size; + int64_t *row_indices = indices_ + row * row_size; + scalar_t block_total = init; + int64_t block_idx_final = 0; + const bool row_exists = row < num_rows; + // Perform scan on one block at a time, keeping track of the total value of + // all blocks processed so far. + for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). 
+ int col1 = block_col + threadIdx.x; + int col2 = block_col + num_threads_x + threadIdx.x; + if (row_exists) { + if (col1 < row_size) { + row_buf[threadIdx.x] = c10::load(&row_self[col1]); + row_idx_buf[threadIdx.x] = col1; + } else { + row_buf[threadIdx.x] = init; + // No need to set the index here as the value in init will never be selected + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]); + row_idx_buf[num_threads_x + threadIdx.x] = col2; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + // No need to set the index here as the value in init will never be selected + } + + // Add the total value of all previous blocks to the first value of this block. + if (threadIdx.x == 0) { + binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op); + } + } + __syncthreads(); + + // Parallel reduction with Sklansky method. The diagram can be seen on this paper: + // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back + for (uint32_t s = 1; s <= num_threads_x; s <<= 1) { + if (row_exists) { + uint32_t a = (threadIdx.x / s) * (2 * s) + s; + uint32_t ti = a + (threadIdx.x % s); + uint32_t si = a - 1; + binary_op_update(row_buf[si], row_buf[ti], row_idx_buf[si], row_idx_buf[ti], binary_op); + } + __syncthreads(); + } + + // Write back to output. + if (row_exists) { + if (col1 < row_size){ + row_values[col1] = row_buf[threadIdx.x]; + row_indices[col1] = row_idx_buf[threadIdx.x]; + } + if (col2 < row_size) { + row_values[col2] = row_buf[num_threads_x + threadIdx.x]; + row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x]; + } + } + block_total = row_buf[2 * num_threads_x - 1]; + block_idx_final = row_idx_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +/* Perform an inclusive scan along an outer dimension of a tensor. 
+ * + * - num_orows is the size of the flattened outer dimensions; + * - num_irows is the size of the flattened inner dimensions; + * - row_size is the size of the dimension along which to compute the variance; + * + * The dimensions to the outside and inside of the specified dimension are considered as flattened. + * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened + * outer dimensions, which contains several "inner rows"). + * Each thread processes a single inner row at a time. + */ +template +__global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_, + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const scalar_t *self = self_ + orow * row_size * num_irows + irow; + scalar_t *values = values_ + orow * row_size * num_irows + irow; + int64_t *indices = indices_ + orow * row_size * num_irows + irow; + scalar_t out = init; + int64_t out_idx = 0; + + for (auto col = decltype(row_size){0}; col < row_size; ++col) { + const auto val = c10::load(self); + if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) { + out = val; + out_idx = col; + } + *values = out; + *indices = out_idx; + self += num_irows; + values += num_irows; + indices += num_irows; + } + } + } +} + +inline void check_fits_in_unsigned(int64_t val, const char* name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK( + val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value"); +} + + +template +__host__ void scan_outer_dim_with_indices( + const TensorBase& self, const TensorBase& values, const TensorBase& indices, + int dim, scalar_t init, BinaryFunction binary_op) { + int64_t row_size = self.size(dim); 
+ auto sizes = self.sizes(); + + // Treat all outer dimensions (i.e. dim_ < dim) as one. + const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim); + + // Treat all inner dimensions (i.e. dim > dimension) as one. + const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end()); + //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row, + //make sure that input is not bigger than supported by uint32_t + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + + + dim3 threads(std::min(512, int(num_irows))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); + tensor_kernel_scan_outer_dim_with_indices<<>>( + self.const_data_ptr(), values.mutable_data_ptr(), indices.mutable_data_ptr(), + num_orows, num_irows, row_size, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__host__ void scan_innermost_dim_with_indices( + const TensorBase& self, const TensorBase& values, const TensorBase& indices, + scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + // Treat all outer dimensions as a single dimension. 
+ int row_size = self.size(ndim - 1); + int num_rows = self.numel() / row_size; + + // assuming max_num_threads per block is 512 + const uint32_t num_threads = 512; + const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size); + const uint32_t num_threads_x = (1 << log_num_threads_x); + const uint32_t num_threads_y = num_threads / num_threads_x; + dim3 threads(num_threads_x, num_threads_y); + dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y)))); + + const uint32_t mem_size = 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t)); + tensor_kernel_scan_innermost_dim_with_indices<<>>( + self.const_data_ptr(), values.mutable_data_ptr(), indices.mutable_data_ptr(), + num_rows, row_size, num_threads, log_num_threads_x, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices, //int64_t dim) { + int64_t dim, scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + auto self_ = self.expect_contiguous(); + TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous()); + if (dim == ndim - 1) { + scan_innermost_dim_with_indices(*self_, values, indices, init, binary_op); + } else { + scan_outer_dim_with_indices(*self_, values, indices, dim, init, binary_op); + } +} + +// TODO: The implementation of `tensor_kernel_scan_outer_dim` and +// `tensor_kernel_scan_innermost_dim` is similar to +// `tensor_kernel_scan_outer_dim_with_indices` +// `tensor_kernel_scan_outer_dim_with_indices` and should be refactored to +// remove the duplication. + +/* Perform an inclusive scan along an outer dimension of a tensor. 
+ * + * - num_orows is the size of the flattened outer dimensions; + * - num_irows is the size of the flattened inner dimensions; + * - row_size is the size of the dimension along which to scan; + * + * The dimensions to the outside and inside of the specified dimension are considered as flattened. + * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened + * outer dimensions, which contains several "inner rows"). + * Each thread processes a single inner row at a time. + */ +template +__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_, + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, + const scalar_t init, BinaryOp binary_op) +{ + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const scalar_t *src = src_ + orow * row_size * num_irows + irow; + scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow; + scalar_t acc = init; + + for (uint32_t col = 0; col < row_size; ++col) { + acc = binary_op(acc, c10::load(src)); + *tgt = acc; + + src += num_irows; + tgt += num_irows; + } + } + } +} + +/* Perform an inclusive scan along the innermost dimension of a tensor. + * + * - num_rows is the size of the flattened outer dimensions; + * - row_size is the size of the innermost dimension; + * + * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is + * considered as having 'num_rows' rows of size 'row_size'. + * Each thread block processes one or more sets of contiguous rows (processing multiple rows + * per thread block is quicker than processing a single row, especially for short rows). 
+ */ +template +__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, const T *src_, + const uint32_t num_rows, const uint32_t row_size, + const uint32_t log_num_threads_x, + T init, BinaryFunction binary_op){ + const index_t num_threads_x = 1 << log_num_threads_x; + for (index_t block_row = blockIdx.x * (index_t) blockDim.y; + block_row < num_rows; + block_row += blockDim.y * gridDim.x) { + index_t row = block_row + (index_t) threadIdx.y; + T block_total = init; + + const T *row_src = src_ + row * row_size; + T *row_tgt = tgt_ + row * row_size; + const bool row_exists = row < num_rows; + + // Perform scan on one block at a time, keeping track of the total value of + // all blocks processed so far. + for (index_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). + index_t col1 = block_col + (index_t) threadIdx.x; + index_t col2 = block_col + num_threads_x + (index_t) threadIdx.x; + if (row_exists) { + if (col1 < row_size) { + row_buf[threadIdx.x] = row_src[col1]; + } else { + row_buf[threadIdx.x] = init; + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = row_src[col2]; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + } + + // Add the total value of all previous blocks to the first value of this block. + if (threadIdx.x == 0) { + row_buf[0] = binary_op(row_buf[0], block_total); + } + } + __syncthreads(); + + // Parallel reduction with Sklansky method. 
The diagram can be seen on this paper: + // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back + for (int m = 0; m <= log_num_threads_x; ++m) { + if (row_exists) { + index_t s = 1 << m; // s = 2 ^ m + auto a = static_cast((threadIdx.x >> m) << (m + 1)) | s; // a = (threadIdx.x / s) * (2 * s) + s + index_t ti = a + (threadIdx.x % s); + index_t si = a - 1; + row_buf[ti] = binary_op(row_buf[ti], row_buf[si]); + } + __syncthreads(); + } + + // Write back to output. + if (row_exists) { + if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x]; + if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x]; + } + block_total = row_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +template < + typename T, + class BinaryFunction> +__global__ void tensor_kernel_scan_innermost_dim( + T* tgt_, + const T* src_, + const uint32_t num_rows, + const uint32_t row_size, + const uint32_t log_num_threads_x, + T init, + BinaryFunction binary_op) { + alignas(sizeof(double)) extern __shared__ char sbuf[]; + T* sbuf2 = reinterpret_cast(sbuf); + const uint32_t num_threads_x = 1 << log_num_threads_x; + T* row_buf = reinterpret_cast(sbuf2 + num_threads_x * 2 * threadIdx.y); + if (num_rows * (size_t) row_size <= UINT_MAX) { + tensor_kernel_scan_innermost_dim_impl( + row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op); + } else { + tensor_kernel_scan_innermost_dim_impl( + row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op); + } +} + + +template +__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result, + int dim, scalar_t init, BinaryFunction binary_op) { + const int64_t row_size = self.size(dim); + auto sizes = self.sizes(); + + // Treat all outer dimensions (i.e. dim_ < dim) as one. + const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim); + + // Treat all inner dimensions (i.e. dim > dimension) as one. 
+ const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end()); + + dim3 threads(std::min(512, int(num_irows))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); + + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + + tensor_kernel_scan_outer_dim<<>>( + result.mutable_data_ptr(), self.const_data_ptr(), + num_orows, num_irows, row_size, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_innermost_dim(const TensorBase& self, const TensorBase& result, + scalar_t init, BinaryFunction binary_op) { + int64_t ndim = self.dim(); + // Treat all outer dimensions as a single dimension. + int64_t row_size = self.size(ndim - 1); + int64_t num_rows = self.numel() / row_size; + + // assuming max_num_threads per block is 512 + const uint32_t num_threads = 512; + const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size); + const uint32_t num_threads_x = (1 << log_num_threads_x); + const uint32_t num_threads_y = num_threads / num_threads_x; + dim3 threads(num_threads_x, num_threads_y); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0]; + dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y}))); + + check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))"); + check_fits_in_unsigned(row_size, "row_size"); + + tensor_kernel_scan_innermost_dim<<>>( + result.mutable_data_ptr(), self.const_data_ptr(), + num_rows, row_size, log_num_threads_x, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_dim(const TensorBase& self, const TensorBase& result, + int64_t dim, scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + auto self_ = self.expect_contiguous(); + 
TORCH_INTERNAL_ASSERT(result.is_contiguous()); + + if (self.numel() == self.size(dim)) { + if constexpr (std::is_same_v>) { + if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms()) && (self.is_floating_point() || self.is_complex())) { +#if defined(CUDA_VERSION) || defined(USE_ROCM) + cuda::cub::inclusive_deterministic_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); +#else + globalContext().alertNotDeterministic("cumsum_cuda_kernel"); + cuda::cub::inclusive_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); +#endif + } else { + cuda::cub::inclusive_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); + } + } else { + cuda::cub::inclusive_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); + } + } else if (dim == ndim - 1) { + scan_innermost_dim(*self_, result, init, binary_op); + } else { + scan_outer_dim(*self_, result, dim, init, binary_op); + } +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Sort.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Sort.h new file mode 100644 index 0000000000000000000000000000000000000000..4ad3e26b819b9a6b9eb240d1fdb0f6c0bca264b2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Sort.h @@ -0,0 +1,17 @@ +#pragma once +#include +#include +#include + + +namespace at::native { + +inline bool should_use_small_sort(const TensorBase &self, int64_t dim) { + return self.size(dim) <= 4096; +} + +void sortKeyValueInplace( + const TensorBase &key, const TensorBase &value, int64_t dim, + bool descending, bool stable=false); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortStable.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortStable.h new file mode 100644 index 
0000000000000000000000000000000000000000..d186345b3a93070e0b3d2b354dab12d9f71eae78 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortStable.h @@ -0,0 +1,17 @@ +#pragma once +#include +#include + +namespace at::native { + +// Stable-sort self into values, and set indices to the +// inverse-permutation from values back to self. +// Output tensors must be pre-allocated and contiguous. +void launch_stable_sort_kernel( + const TensorBase& self, + int64_t dim, + bool descending, + const TensorBase& values, + const TensorBase& indices); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1d0df3741bb223df83110ff6c62d81d9681352f3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh @@ -0,0 +1,343 @@ +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include + +#define HAS_WARP_MERGE_SORT() (CUDA_VERSION >= 110600) + + +namespace at::native { + +template +__device__ inline void swapVars(T& t1, T& t2) { + T tmp = t1; + t1 = t2; + t2 = tmp; +} + +template +__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA, + K& kB, V& vB, bool& validB, + bool dir, + const Comparator& comp) { + // Invalid entries always sort to the end + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(vA, vB); + swapVars(validA, validB); + } +}; + +template +__device__ inline void bitonicSort(K *keys, + V *values, + bool *valid, + const Comparator& comp) { +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((threadIdx.x & (size / 2)) != 0); + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = size / 2; 
stride > 0; stride /= 2) { + + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], valid[pos], + keys[pos + stride], values[pos + stride], valid[pos + stride], + flag, comp); + } + } + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) { + + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], valid[pos], + keys[pos + stride], values[pos + stride], valid[pos + stride], + false, comp); + } + + __syncthreads(); + +} + +// at::cuda::detail::TensorInfo version +// Sorts (key, value) pairs (in different tensors) in-place; i.e., +// modifies the input `keys` and `values` +template +C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y) +__global__ void +bitonicSortKVInPlace(at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + Comparator comp) { + // Find the slice of the tensor that we are sorting + // NOTE: blockDim.y may be less max_block_dim_y + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + // If the entire block is out of bounds exit early + if (blockIndex * blockDim.y >= keySlices) { + return; + } + // It's also possible for some rows of a block to be out of bounds + // but all thread need to run for __syncthreads to work. 
+ const bool row_valid = linearIndex < keySlices; + + constexpr int items_per_thread = 2; + constexpr int Power2SortSize = block_dim_x * items_per_thread; + + // Storage for max_block_dim_y sorts performed in parallel + __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize]; + __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize]; + __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize]; + + auto sharedKeys = blockSharedKeys[threadIdx.y]; + auto sharedValues = blockSharedValues[threadIdx.y]; + auto sharedValid = blockSharedValid[threadIdx.y]; + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + // Load 2 values per thread into the shared workspace + #pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + bool valid = row_valid && idx < keySliceSize; + + sharedKeys[idx] = valid ? + keys.data[idx * keySliceStride + keyStartOffset] : K{}; + sharedValues[idx] = valid ? + values.data[idx * valueSliceStride + valueStartOffset] : V{}; + sharedValid[idx] = valid; + } + + // Sort! 
+ bitonicSort( + sharedKeys, sharedValues, sharedValid, comp); + + if (!row_valid) { + return; + } + + // Store outputs + #pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + if (idx < keySliceSize) { + keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx]; + values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx]; + } + } +} + +#if HAS_WARP_MERGE_SORT() + +template +C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y) +__global__ void +warpMergeSortKVInPlace( + at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + Comparator comp, + K invalid_key) { + // Find the slice of the tensor that we are sorting + // NOTE: blockDim.y may be less max_block_dim_y + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + // If this row is out of bounds exit early + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); + + CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE); + CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y); + constexpr int items_per_thread = sort_size / C10_WARP_SIZE; + static_assert( + items_per_thread * C10_WARP_SIZE == sort_size, + "sort_size must be a multiple of C10_WARP_SIZE"); + + + using LoadKeys = cub::WarpLoad; + using LoadValues = cub::WarpLoad; + using Sort = cub::WarpMergeSort; + using StoreKeys = 
cub::WarpStore; + using StoreValues = cub::WarpStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage[max_block_dim_y]; + + auto& warp_storage = tmp_storage[threadIdx.y]; + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + const auto invalid_value = V{}; + LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + WARP_SYNC(); + LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + WARP_SYNC(); + + // Sort! We use stable sort to ensure that invalid values are never + // sorted before valid values. In testing it performed the same as + // .Sort, so there is no down-side. + Sort(warp_storage.sort).StableSort( + local_keys, local_values, comp, keySliceSize, invalid_key); + WARP_SYNC(); + + // Store outputs + StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + WARP_SYNC(); + StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +#endif // HAS_WARP_MERGE_SORT() + +template +C10_LAUNCH_BOUNDS_1(block_size) +__global__ void +radixSortKVInPlace(at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + bool descending) { + static_assert(block_size > 0, ""); + + // Find the slice of the tensor that we are sorting + const IndexType linearIndex = getLinearBlockId(); + // Tiling the slices could have us be out of bounds, if there are a + // lot of slices to sort + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + 
at::cuda::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); + + using key_t = typename at::cuda::cub::detail::cuda_type::type; + using LoadKeys = cub::BlockLoad; + using LoadValues = cub::BlockLoad; + using Sort = cub::BlockRadixSort; + using StoreKeys = cub::BlockStore; + using StoreValues = cub::BlockStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage; + + // cub's Block operations operate on a fixed number of items, but the + // actual slice we are sorting might be smaller. So, we need to make + // up the difference with keys that will always sort higher. + const K invalid_key = [descending] { + using radix_t = typename cub::Traits::UnsignedBits; + union { + K key; + radix_t radix; + } tmp; + tmp.radix = descending ? + cub::Traits::LOWEST_KEY : + cub::Traits::MAX_KEY; + return tmp.key; + }(); + const V invalid_value = static_cast(0); + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + __syncthreads(); + LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + __syncthreads(); + + // Sort! 
+ if (descending) { + Sort(tmp_storage.sort).SortDescending( + reinterpret_cast(local_keys), + local_values); + } else { + Sort(tmp_storage.sort).Sort( + reinterpret_cast(local_keys), + local_values); + } + __syncthreads(); + + // Store outputs + StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + __syncthreads(); + StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Sorting.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Sorting.h new file mode 100644 index 0000000000000000000000000000000000000000..667ea2c2629218f9e59b8f059d2ab9a76fbabb8c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Sorting.h @@ -0,0 +1,17 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +void launch_kthvalue_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t k); +void launch_median_kernel( + const TensorBase &vals, const TensorBase &inds, + const TensorBase &in, int64_t dim, bool ignore_nan); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ba2f6f7c38e527526ad8d9762310f4112b647891 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh @@ -0,0 +1,193 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// Is this questionable namespace pollution? 
+#if defined(USE_ROCM) +constexpr int MAX_BLOCK_SIZE = 256; + +#else +constexpr int MAX_BLOCK_SIZE = 1024; +#endif + +// Maximum size per grid dimension that we assume (compute capability >= 2.0) +constexpr int64_t MAX_GRID_SIZE = 65535LL; + +inline bool getGridFromTiles(int64_t gridTiles, dim3& grid) { + if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { + return false; + } + + int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + int64_t gridY = 1; + int64_t gridZ = 1; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + } + } + + grid = dim3(gridX, gridY, gridZ); + return true; +} + +template +struct GTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) || + (static_cast(lhs) > static_cast(rhs)); + } +}; + +template +struct LTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) || + (static_cast(lhs) < static_cast(rhs)); + } +}; + +template +__device__ __forceinline__ index_t getLinearBlockId() { + return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + + blockIdx.x; +} + +// For slice sorting in Thrust; extracts a slice index from a linear +// index and uses that for comparison +struct SliceComp { + SliceComp(int64_t size) : sliceSize(size) {} + + __device__ bool operator()(const int64_t& a, const int64_t& b) const { + // Since the slices are guaranteed to be innermost, + // the segment is just via int64_t division + int64_t segA = a / sliceSize; + int64_t segB = b / sliceSize; + return segA < segB; + } + + const int64_t sliceSize; +}; + +// For sorting in Thurst; extracts a within-slice 
index from a linear index +struct GlobalIndexToPerSliceIndex { + GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {} + + __device__ inline void operator()(int64_t& v) const { + v = v % sliceSize; + } + + const int64_t sliceSize; +}; + +// Returns 2^(ceil(lg(n)) from Stanford bit twiddling hacks +inline uint64_t nextHighestPowerOf2(uint64_t n) { + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; +#ifndef _MSC_VER + n |= n >> 32; +#endif + n++; + + return n; +} + + +// WARNING: This function assumes input tensors are contiguous +template +void run_launcher( + const TensorBase &values, + const TensorBase &indices, + const TensorBase &self, + int64_t dim, + Launcher l) { + auto self_info = cuda::detail::getTensorInfo(self); + auto values_info = cuda::detail::getTensorInfo(values); + auto indices_info = cuda::detail::getTensorInfo(indices); + + int64_t slice_size = self.size(dim); + /* We use these structures solely to find the offset to */ + /* each slice we are operating on */ + self_info.reduceDim(dim); + values_info.reduceDim(dim); + indices_info.reduceDim(dim); + + /* Collapse all other dims */ + int collapse_self_dim = self_info.collapseDims(dim); + int collapse_values_dim = values_info.collapseDims(dim); + int collapse_indices_dim = indices_info.collapseDims(dim); + + int64_t num_slices = 1; + for (int i = 0; i < self_info.dims; ++i) { + num_slices *= self_info.sizes[i]; + } + + /* This is used as a template parameter to calculate indices. 
*/ + /* We only specialize it if all collapsed dim sizes are the */ + /* same; otherwise, we use -1 which is the specialization */ + /* parameter for arbitrary dimensions */ + int all_dims = self_info.dims; + if (values_info.dims != all_dims || indices_info.dims != all_dims) { + all_dims = -1; + } + + if (all_dims == 1) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 2) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 3) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e1496d4828a02e73257e945cea6de824d87ae354 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh @@ -0,0 +1,427 @@ +#include +#include +#include +#include +#include + +namespace at::native { + +template +struct TopKTypeConfig {}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + // Converts a float to an integer representation with the same + // sorting; i.e., for floats f1, f2: + // if f1 < f2 then convert(f1) < convert(f2) + // We use this to enable radix selection of floating-point values. 
+ // This also gives a relative order for NaNs, but that's ok, as they + // will all be adjacent + // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff.. + // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00.. + // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0 + // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(uint8_t v) { + return v; + } + + static inline __device__ uint8_t deconvert(RadixType v) { + return v; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int8_t v) { + return 128u + v; + } + + static inline __device__ int8_t deconvert(RadixType v) { + return v - 128; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int16_t v) { + static_assert(sizeof(short) == 2, ""); + return 32768u + v; + } + + static inline __device__ int16_t deconvert(RadixType v) { + return v - 32768; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int32_t v) { + static_assert(sizeof(int) == 4, ""); + return 2147483648u + v; + } + + static inline __device__ int32_t deconvert(RadixType v) { + return v - 2147483648u; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(int64_t v) { + static_assert(sizeof(int64_t) == 8, ""); + return 9223372036854775808ull + v; + } + + static inline __device__ int64_t deconvert(RadixType v) { + return v - 9223372036854775808ull; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(double v) { + RadixType x = __double_as_longlong(v); + RadixType mask = -((x >> 63)) | 0x8000000000000000; + return (v == v) ? 
(x ^ mask) : 0xffffffffffffffff; + } + + static inline __device__ double deconvert(RadixType v) { + RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; + return __longlong_as_double(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::Half v) { +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + RadixType x = __half_as_ushort(v); + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; +#else + CUDA_KERNEL_ASSERT(false); + return 0u; +#endif + } + + static inline __device__ at::Half deconvert(RadixType v) { +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + return __ushort_as_half(v ^ mask); +#else + CUDA_KERNEL_ASSERT(false); + return static_cast(0); +#endif + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::BFloat16 v) { + RadixType x = v.x; + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; + } + + static inline __device__ at::BFloat16 deconvert(RadixType v) { + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + at::BFloat16 r; + r.x = (v ^ mask); + return r; + } +}; + +// This function counts the distribution of all input values in a +// slice we are selecting by radix digit at `radixDigitPos`, but only +// those that pass the filter `((v & desiredMask) == desired)`. +// This produces and broadcasts the seen counts for a single block only. +// `smem` must have at least `RadixSize` elements. 
+template < + typename scalar_t, + typename bitwise_t, + typename index_t, + typename CountType, + int RadixSize, + int RadixBits> +__device__ void countRadixUsingMask( + CountType counts[RadixSize], + CountType* smem, + bitwise_t desired, + bitwise_t desiredMask, + int radixDigitPos, + index_t sliceSize, + index_t withinSliceStride, + const scalar_t* data) { + // Clear out per-thread counts from a previous round +#pragma unroll + for (int i = 0; i < RadixSize; ++i) { + counts[i] = 0; + } + + if (threadIdx.x < RadixSize) { + smem[threadIdx.x] = 0; + } + __syncthreads(); + + // Scan over all the data. Upon a read, the warp will accumulate + // counts per each digit in the radix using warp voting. +#if !defined(USE_ROCM) + // Must be called outside of loop to ensure all threads participate + unsigned mask = WARP_BALLOT(threadIdx.x < sliceSize); +#endif + for (index_t i = threadIdx.x; i < sliceSize;) { + bitwise_t val = + TopKTypeConfig::convert(doLdg(&data[i * withinSliceStride])); + + bool hasVal = ((val & desiredMask) == desired); + bitwise_t digitInRadix = at::cuda::Bitfield::getBitfield( + val, radixDigitPos, RadixBits); + +#pragma unroll + for (uint32_t j = 0; j < RadixSize; ++j) { + bool vote = hasVal && (digitInRadix == j); +#if defined(USE_ROCM) + counts[j] += __popcll(WARP_BALLOT(vote)); +#else + counts[j] += __popc(WARP_BALLOT(vote, mask)); +#endif + } + i += blockDim.x; +#if !defined(USE_ROCM) + mask = WARP_BALLOT(i < sliceSize, mask); +#endif + } + + // Now, for each warp, sum values + if (at::cuda::getLaneId() == 0) { +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + gpuAtomicAddNoReturn(&smem[i], counts[i]); + } + } + + __syncthreads(); + + // For each thread, read in the total counts +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + counts[i] = smem[i]; + } + + __syncthreads(); +} + +// Over what radix we are selecting values +constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS) +constexpr int RADIX_SIZE = 4; // 2 
^ RADIX_BITS +constexpr int RADIX_MASK = (RADIX_SIZE - 1); + +// This finds the unique value `v` that matches the pattern +// ((v & desired) == desiredMask) in our sorted int format +template +__device__ scalar_t findPattern( + scalar_t* smem, + const scalar_t* data, + index_t sliceSize, + index_t withinSliceStride, + bitwise_t desired, + bitwise_t desiredMask) { + if (threadIdx.x < 2) { + smem[threadIdx.x] = static_cast(0); + } + __syncthreads(); + + // All threads participate in the loop, in order to sync on the flag + index_t numIterations = + round_up(sliceSize, static_cast(blockDim.x)); + for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < sliceSize); + scalar_t v = inRange ? doLdg(&data[i * withinSliceStride]) + : static_cast(0); + + if (inRange && + ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { + // There should not be conflicts if we are using findPattern, + // since the result is unique + smem[0] = static_cast(1); + smem[1] = v; // can't use val as the flag, since it could be 0 + } + + __syncthreads(); + + scalar_t found = smem[0]; + scalar_t val = smem[1]; + + __syncthreads(); + + // Check to see if a thread found the value + if (found != static_cast(0)) { + // all threads return this value + return val; + } + } + + // should not get here + CUDA_KERNEL_ASSERT(false); + return static_cast(0); +} + +// Returns the top-Kth element found in the data using radix selection +template +__device__ void radixSelect( + const scalar_t* data, + index_t k, + bool largest, + index_t sliceSize, + index_t withinSliceStride, + int* smem, + scalar_t* topK) { + // Per-thread buckets into which we accumulate digit counts in our + // radix + int counts[RADIX_SIZE]; + + // We only consider elements x such that (x & desiredMask) == desired + // Initially, we consider all elements of the array, so the above + // statement is true regardless of input. 
+ bitwise_t desired = 0; + bitwise_t desiredMask = 0; + + // We are looking for the top kToFind-th element when iterating over + // digits; this count gets reduced by elimination when counting + // successive digits + int kToFind = k; + + // We start at the most significant digit in our radix, scanning + // through to the least significant digit + for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0; + digitPos -= RADIX_BITS) { + // Count radix distribution for the current position and reduce + // across all threads + countRadixUsingMask< + scalar_t, + bitwise_t, + index_t, + int, + RADIX_SIZE, + RADIX_BITS>( + counts, + smem, + desired, + desiredMask, + digitPos, + sliceSize, + withinSliceStride, + data); + + auto found_unique = [&](int i, int count) -> bool { + /* All threads have the same value in counts here, so all */ + /* threads will return from the function. */ + if (count == 1 && kToFind == 1) { + /* There is a unique answer. */ + desired = at::cuda::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::cuda::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The answer is now the unique element v such that: */ + /* (v & desiredMask) == desired */ + /* However, we do not yet know what the actual element is. We */ + /* need to perform a search through the data to find the */ + /* element that matches this pattern. 
*/ + *topK = findPattern( + (scalar_t*)smem, + data, + sliceSize, + withinSliceStride, + desired, + desiredMask); + return true; + } + return false; + }; + auto found_non_unique = [&](int i, int count) -> bool { + if (count >= kToFind) { + desired = + at::cuda::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::cuda::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The top-Kth element v must now be one such that: */ + /* (v & desiredMask == desired) */ + /* but we haven't narrowed it down; we must check the next */ + /* least-significant digit */ + return true; + } + kToFind -= count; + return false; // continue the loop + }; + + // All threads participate in the comparisons below to know the + // final result + if (largest) { + // Process in descending order +#pragma unroll + for (int i = RADIX_SIZE - 1; i >= 0; --i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } else { + // Process in ascending order +#pragma unroll + for (int i = 0; i < RADIX_SIZE; ++i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } + } // end digitPos for + + // There is no unique result, but there is a non-unique result + // matching `desired` exactly + *topK = TopKTypeConfig::deconvert(desired); +} +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e9c1bfd6d004cf81888f1a03dc1aa3083342133c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh @@ -0,0 +1,431 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +// Used for a segmented reduction +struct 
ModeUnsignedBoolPair { + unsigned int val; + bool flag; +}; + +// In the kernel below, we have a common pattern of reducing (unsigned int, +// unsigned int) pairs of data +struct ModeUnsignedPair { + unsigned int val; + unsigned int index; +}; + +// Inclusive Scan via an upsweep/downsweep mechanism. Assumes: +// +// 1. Power2ScanSize is a power of 2. This code still works for collections that +// do not exactly contain a power of 2 number of elements, simply round up to +// the nearest power of 2 and then call. +// +// 2. That there are two-elements per thread, i.e. the size of the smem storage +// is 2 * blockDim.x * sizeof(T). +// +// Consider a (+)-Scan on the following elements: +// +// Upsweep: +// +// 0 1 2 3 4 5 6 7 +// 1 5 9 13 +// 6 22 +// 28 +// +// Downsweep: +// 15 +// 3 10 21 +template +__device__ void inclusivePrefixScan(T* smem, BinaryOp binop) { + // Reduce step ("upsweep") +#pragma unroll + for (int stride = 1; stride < Power2ScanSize; stride <<= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if (index < Power2ScanSize) { + smem[index] = binop(smem[index], smem[index - stride]); + } + __syncthreads(); + } + + // Post-reduce step ("downsweep") +#pragma unroll + for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if ((index + stride) < Power2ScanSize) { + smem[index + stride] = binop(smem[index + stride], smem[index]); + } + __syncthreads(); + } +} + +// Block-wide reduction where each thread locally reduces N +// values before letting a single warp take over - assumes +// threadVals is in registers, not shared memory +// +// If smem is not used again, there is no need to __syncthreads before this +// call. However, if smem will be used, e.g., this function is called in a loop, +// then __syncthreads is needed either before or afterwards to prevent non-0 +// threads overriding smem in the next loop before num-0 thread reads from it. 
+template +__device__ T reduceBlockWithNThreadLocalReductions( + T* smem, + T threadVals[N], + const unsigned int numVals, + ReduceOp reduceOp, + T init) { + int offset = threadIdx.x * N; + T local = offset < numVals ? threadVals[0] : init; + +#pragma unroll + for (int i = 1; i < N; ++i) { + ++offset; + T next = offset < numVals ? threadVals[i] : init; + local = reduceOp.combine(local, next); + } + + return cuda_utils::BlockReduce(local, reduceOp, init, smem); +} + +template +__device__ inline void swapVars(T& t1, T& t2) { + T tmp = t1; + t1 = t2; + t2 = tmp; +} + +template +__device__ inline void bitonicSwap( + K& kA, + V& vA, + bool& validA, + K& kB, + V& vB, + bool& validB, + bool dir, + const Comparator& comp) { + // Invalid entries always sort to the end + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(vA, vB); + swapVars(validA, validB); + } +}; + +template +__device__ inline void bitonicSwapKeys( + K& kA, + bool& validA, + K& kB, + bool& validB, + bool dir, + const Comparator& comp) { + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(validA, validB); + } +} + +template < + typename K, + typename IndexType, + int Power2SortSize, + typename Comparator> +__device__ inline void bitonicSortKeys( + K keys[Power2SortSize], + bool valid[Power2SortSize], + const Comparator& comp) { +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((threadIdx.x & (size / 2)) != 0); + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = size / 2; stride > 0; stride /= 2) { + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwapKeys( + keys[pos], + valid[pos], + keys[pos + stride], + valid[pos + stride], + flag, + comp); + } + } + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = Power2SortSize / 2; stride 
> 0; stride /= 2) { + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwapKeys( + keys[pos], + valid[pos], + keys[pos + stride], + valid[pos + stride], + false, + comp); + } + + __syncthreads(); +} + +// The mode kernel has the following characteristics: It uses internal shared +// memory buffers of Power2Size, which must be greater than the number of +// elements. Additionally, there is one block for every slice to calculate the +// mode for, and in each block there is one thread for every two elements. +// +// Both sorted and positions are assumed to be contiguous Tensors with the mode +// dimension as the innermost dim, such that we can get the particular slice for +// a Tensor via its linear block dimension * the slice size. +template +__launch_bounds__(1024, 1) +__global__ void compute_mode( + const T* input, + at::cuda::detail::TensorInfo values, + at::cuda::detail::TensorInfo indices, + int64_t sliceSize, + int64_t slices) { + int tidx = threadIdx.x; + int stidx = blockDim.x + threadIdx.x; // Second index this thread responsible for + + // First, we need to calculate the offset into the sorted Tensor that + // represents the start of the slice for this block to calculate the mode for. + // This offset is a combination of the gridIndices, and the number of elements + // in the slice. + unsigned int blockId = getLinearBlockId(); + unsigned int linearOffset = blockId * sliceSize; + + if (blockId >= slices) { + return; + } + + // shmem is a dynamically sized buffer we will use throughout the kernel to + // handle computation efficiently. 
The size of this shmem must be + // sizeof(T) * Power2Size + (2 * sizeof(unsigned int) * Power2Size) + // + // Initially, the buffer will be organized as follows: + // + // [smem (slice elements) | bmem (valid indices) | ] + extern __shared__ char shmem[]; + + // smem represents a proportion of the shared memory buffer that is used to + // store the elements from the slice: + T* smem = reinterpret_cast(shmem); + + // Each thread loads up to two elements from the Tensor into shared memory + if (tidx < sliceSize) { + smem[tidx] = c10::load(&input[linearOffset + tidx]); + } + if (stidx < sliceSize) { + smem[stidx] = c10::load(&input[linearOffset + stidx]); + } + + // Next, we initialize a boolean region of the buffer, offset by the loaded + // element smem region + bool* bmem = reinterpret_cast(&smem[Power2Size]); + + // The first use of this region stores bmem[i] = i < sliceSize to mark the + // valid components in the smem buffer + bmem[tidx] = tidx < sliceSize; + bmem[stidx] = stidx < sliceSize; + __syncthreads(); // barrier for smem, bmem initialization + + // First, sort the input slice in ascending order. smem contains the input + // elements, and bmem marks the valid indices + bitonicSortKeys( + smem, bmem, [&] GPU_LAMBDA(const auto& a, const auto& b) { + return a < b; + }); + __syncthreads(); // make no assumptions that the sort syncs at end + + // The next step of our algorithm is performing a block-wide comparison of + // neighboring elements. In particular, given an sorted input slice A, we + // produce an output slice B, such that B[i] = 1 if A[i-i] != A[i], otherwise + // 0. + // + // Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8] + // B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1] + // + // In particular, we can think of B[i] true indicating the start of a sequence + // of equal values in the sorted list. Similarly, we will also store the + // negation of B, which we'll call C. 
In particular, we can think of C[i] = + // true iff A[i-1] == A[i] in our original sorted slice. + // + // C = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0] + + // We overwrite bmem, and treat the rest of shared memory as a buffer of + // (index, flag) pairs where the index represents values from C, and the flag + // represents values from B. + // + // [smem (sorted slice) | ubpmem (index, flag pairs)] + + struct ModeUnsignedBoolPair* ubpmem = + reinterpret_cast(&smem[Power2Size]); + + if (tidx == 0) { + ubpmem[0].flag = true; + ubpmem[0].val = 0; + } + + // Compares elements (0, 1), (2, 3), ... and sets 1, 3, ... + ubpmem[tidx * 2 + 1].flag = + smem[tidx * 2] != smem[tidx * 2 + 1]; // (0, 1), (1, 2), etc. + ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag; + + // Compares elements (1, 2), (3, 4), ... and sets 2, 4, ... + if (((tidx + 1) * 2) < Power2Size) { + ubpmem[(tidx + 1) * 2].flag = + smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2]; + ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag; + } + __syncthreads(); // barrier for ubpmem initialization + + // Next, we perform a segmented prefix sum on the neighboring elements, where + // the presence of a one indicates the start of a segment. In this case B acts + // as the segment start flags, and C is the buffer to be summed: + // + // Input (C) = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0] + // Flag (B) = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1] + // Output (C) = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0] + // + // Afterwards, the (index) components of the ubpmem buffer contain the lengths + // of the segments (minus 1), i.e. the counts of each element in the original + // input. + inclusivePrefixScan( + ubpmem, [=] GPU_LAMBDA(const auto& a, const auto& b) { + ModeUnsignedBoolPair c; + c.val = a.flag ? a.val : a.val + b.val; + c.flag = a.flag | b.flag; + return c; + }); + // assumes scan syncs at the end + + // Next, we reinterpret the ubpmem buffer as pairs of unsigned integers (i.e. 
+ // we treat the boolean flag regions as integers). We initialize these to + // represent indices, and we'll call this buffer I + struct ModeUnsignedPair* uupmem = + reinterpret_cast(ubpmem); + + // At this point, we need to find the maximum element in lengths buffer C. + // This element will represent the count (-1) of the mode. Because of the + // way we have set up the problem, the index where this mode occurs will + // also be the location of the mode value in the sorted array, e.g. + // + // smem = [0, 0, 1, 1, 1, 2] + // C = [0, 1, 0, 1, 2, 0] + // I = [0, 1, 2, 3, 4, 5] + // ^ + // maximum value, also aligned with mode = 1 + // + // We perform a block wide max-reduction of the C buffer, but we also need the + // indices to come along with it, so we utilize the uupmem construction. + // + // At the end we need to return the ModeUnsignedPair containing index = 4, val + // = 2, which represents the max + + // In practice, we will make each thread locally reduce 2 values in its + // registers prior to the global block-wide reduction. Note that instead of + // tidx/stidx, we utilize tidx * 2, tidx * 2 + 1, so each thread deals with + // adjacent elements. This is because the reduce code below relies on thread + // elements to be adjacent. + struct ModeUnsignedPair uup[2]; + uup[0].index = tidx * 2; + uup[0].val = ubpmem[tidx * 2].val; + uup[1].index = tidx * 2 + 1; + uup[1].val = ubpmem[tidx * 2 + 1].val; + __syncthreads(); + + struct ModeUnsignedPair max = {0, 0}; + + struct MaxOp { + inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const { + return b.val > a.val ? 
b : a; + } + + inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const { + ModeUnsignedPair ret; + ret.index = WARP_SHFL_DOWN(acc.index, offset); + ret.val = WARP_SHFL_DOWN(acc.val, offset); + return ret; + } + } max_op; + + max = reduceBlockWithNThreadLocalReductions<2>( + uupmem, + uup, + sliceSize, + max_op, + max); + + // Store the mode in shared memory for use in finding the mode in the input + // slice + __shared__ T mode; + + // Given the above constraints, the mode is the value at the reduced index in + // the original sorted element buffer + if (tidx == 0) { + mode = smem[max.index]; + } + __syncthreads(); // broadcast mode + + // Finally, we need to find "an" index of the mode in the input + // Tensor. The API does not constrain which index we pick, but here + // we always pick the largest index. We store the index if the value + // is the mode, or 0 otherwise. Then find the maximum value. + // + // Again we reduce 2 elements in the thread's registers prior to the + // block-wide reduction + unsigned mode_index[2] = {0u, 0u}; + if (tidx * 2 < sliceSize) { + const unsigned idx = tidx * 2; + mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u; + } + if (tidx * 2 + 1 < sliceSize) { + const unsigned idx = tidx * 2 + 1; + mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u; + } + + struct MaxIndexOp { + inline __device__ unsigned combine(unsigned a, unsigned b) const { + return b > a ? b : a; + } + + inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } + } max_index_op; + + int64_t index = reduceBlockWithNThreadLocalReductions<2>( + reinterpret_cast(&shmem[0]), + mode_index, + sliceSize, + max_index_op, + 0u); + + // Finally, we have the mode, and an index where it occurs. 
We use a single + // thread to place this in the appropriate output position + if (tidx == 0) { + unsigned int outputOffset = + at::cuda::detail::IndexToOffset::get( + blockId, values); + values.data[outputOffset] = mode; + indices.data[outputOffset] = index; + } +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..edf505fc4b8619d5c615e5b1cc36c648a0588ea2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +void launch_fused_mode_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t slice_size, int64_t slices); + +void launch_apply_mode_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t ndim); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorTopK.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorTopK.h new file mode 100644 index 0000000000000000000000000000000000000000..74b615d34d00fa97a9bfe3405ca59a015aa19ca7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorTopK.h @@ -0,0 +1,13 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { +void launch_gather_topk_kernel( + const TensorBase& self, + int64_t k, int64_t dim, bool largest, + const TensorBase& values, const TensorBase& indices); +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh new file mode 100644 index 
0000000000000000000000000000000000000000..e19ba94b3fb8e6e36e5fa0205e7cdcd824f8feec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh @@ -0,0 +1,12 @@ +#include + +namespace at::native::internal { + +template +std::tuple unique_cuda_template( + const Tensor& self, + const bool consecutive, + const bool return_inverse, + const bool return_counts); + +} // namespace at::native::internal diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/UpSample.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/UpSample.cuh new file mode 100644 index 0000000000000000000000000000000000000000..50428b377da85ddac718d05053b59617a39bdb27 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/UpSample.cuh @@ -0,0 +1,368 @@ +#pragma once +#include +#include + +#include +#include +#include + +#include +#include + +namespace at::native { + +namespace upsample { +// TODO: Remove duplicate declaration. +TORCH_API c10::SmallVector compute_output_size( + c10::IntArrayRef input_size, // Full input tensor size. + at::OptionalIntArrayRef output_size, + std::optional> scale_factors); +} // namespace upsample + +namespace upsample_cuda { + +// TODO: Remove duplication with Upsample.h (CPU). +inline std::optional get_scale_value(std::optional> scales, int idx) { + if (!scales) { + return std::nullopt; + } + return scales->at(idx); +} + +} // namespace upsample_cuda + + +/* TODO: move this to a common place */ +template +__device__ inline scalar_t min(scalar_t a, scalar_t b) { + return a < b ? a : b; +} + +template +__device__ inline scalar_t max(scalar_t a, scalar_t b) { + return a > b ? a : b; +} + +// NOTE [ Nearest neighbor upsampling kernel implementation ] +// +// The nearest neighbor upsampling kernel implementation is symmetrical as +// expected. 
We launch kernels with threads mapping to destination tensors where +// kernels write data to, each thread reads data from the source tensor, this +// means: +// 1. In the forward kernel, +// src_xxx refers to properties of input tensors; +// dst_xxx refers to properties of output tensors; +// scale_factor is the ratio of src_size to dst_size; +// 2. In the backward kernel, +// src_xxx refers to properties of grad_output tensors; +// dst_xxx refers to properties of grad_input tensors; +// scale_factor is the ratio of src_size to dst_size; +// +// Because of this, we need to take the reciprocal of the scale defined by +// upsample layer during forward path. The motivation is to avoid slow +// division in the kernel code, so we can use faster multiplication instead. +// This is not necessary during backward path, since the scale_factor is already +// the reciprocal of corresponding scale_factor used in the forward path due to +// the swap of source and destination tensor. +// +// Similarly, since the mapping from grad_input to grad_output during backward +// is the reverse of the mapping of output to input, we need to have opposite +// mapping functions to compute the source index. + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +template +__host__ __forceinline__ accscalar_t compute_scales_value( + const std::optional scale, + int64_t src_size, + int64_t dst_size) { + // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults. + return (scale.has_value() && scale.value() > 0.) ? (accscalar_t)(1.0 / scale.value()) + : (accscalar_t)src_size / dst_size; +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +template +__host__ __forceinline__ accscalar_t compute_scales_value_backwards( + const std::optional scale, + int64_t src_size, + int64_t dst_size) { + // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults. + return (scale.has_value() && scale.value() > 0.) ? 
(accscalar_t)scale.value() + : (accscalar_t)src_size / dst_size; +} + +template +__host__ __forceinline__ accscalar_t area_pixel_compute_scale( + int input_size, + int output_size, + bool align_corners, + const std::optional scale) { + if(align_corners) { + if(output_size > 1) { + return (accscalar_t)(input_size - 1) / (output_size - 1); + } + else { + return static_cast(0); + } + } + else{ + return compute_scales_value(scale, input_size, output_size); + } +} + +template +__device__ __forceinline__ accscalar_t area_pixel_compute_source_index( + accscalar_t scale, + int dst_index, + bool align_corners, + bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + accscalar_t src_idx = scale * (dst_index + static_cast(0.5)) - + static_cast(0.5); + // See Note[Follow Opencv resize logic] + return (!cubic && src_idx < static_cast(0)) + ? static_cast(0) + : src_idx; + } +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +__device__ __forceinline__ int nearest_neighbor_compute_source_index( + const float scale, + int dst_index, + int input_size) { + // index_f32 = (output_index) * scale + // input_index = round(index_f32) + // Same as a buggy OpenCV INTER_NEAREST + // We keep this method for BC and consider as deprecated. 
+ // See nearest_neighbor_exact_compute_source_index as replacement + const int src_index = + min(static_cast(floorf((dst_index) * scale)), input_size - 1); + return src_index; +} + +__device__ __forceinline__ int nearest_neighbor_exact_compute_source_index( + const float scale, + int dst_index, + int input_size) { + // index_f32 = (output_index + 0.5) * scale - 0.5 + // input_index = round(index_f32) + // Same as Pillow and Scikit-Image/Scipy ndi.zoom + const int src_index = + min(static_cast(floorf((dst_index + static_cast(0.5)) * scale)), input_size - 1); + return src_index; +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +__device__ __forceinline__ int nearest_neighbor_bw_compute_source_index( + const float scale, + int dst_index, + int output_size) { + // Equivalent to buggy OpenCV INTER_NEAREST + // We keep this method for BC and consider as deprecated. + // See nearest_neighbor_exact_bw_compute_source_index as replacement + const int src_index = + min(static_cast(ceilf(dst_index * scale)), output_size); + return src_index; +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +__device__ __forceinline__ int nearest_neighbor_exact_bw_compute_source_index( + const float scale, + int dst_index, + int output_size) { + // Equivalent to Pillow and Scikit-Image/Scipy ndi.zoom + const int src_index = + min(static_cast(ceilf(dst_index * scale - static_cast(0.5))), output_size); + return src_index; +} + +/* Used by UpSampleBicubic2d.cu */ +template +__device__ __forceinline__ scalar_t upsample_get_value_bounded( + const PackedTensorAccessor64& data, + int batch, + int channel, + int height, + int width, + int y, + int x) { + int access_y = max(min(y, height - 1), 0); + int access_x = max(min(x, width - 1), 0); + return data[batch][channel][access_y][access_x]; +} + +/* Used by UpSampleBicubic2d.cu */ +template +__device__ __forceinline__ void upsample_increment_value_bounded( + PackedTensorAccessor64& data, + int batch, + int 
channel, + int height, + int width, + int y, + int x, + accscalar_t value) { + int access_y = max(min(y, height - 1), 0); + int access_x = max(min(x, width - 1), 0); + /* TODO: result here is truncated to scalar_t, + check: https://github.com/pytorch/pytorch/pull/19630#discussion_r281426912 + */ + gpuAtomicAddNoReturn( + &data[batch][channel][access_y][access_x], static_cast(value)); +} + +// Based on +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +template +__device__ __forceinline__ accscalar_t cubic_convolution1( + accscalar_t x, + accscalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +__device__ __forceinline__ accscalar_t cubic_convolution2( + accscalar_t x, + accscalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +__device__ __forceinline__ void get_cubic_upsampling_coefficients( + accscalar_t coeffs[4], + accscalar_t t) { + accscalar_t A = -0.75; + + accscalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + // opposite coefficients + accscalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); +} + +template +__device__ __forceinline__ accscalar_t cubic_interp1d( + scalar_t x0, + scalar_t x1, + scalar_t x2, + scalar_t x3, + accscalar_t t) { + accscalar_t coeffs[4]; + get_cubic_upsampling_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +namespace upsample_antialias { + +// taken from +// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/ +// src/libImaging/Resample.c#L20-L29 +struct BilinearFilterFunctor { + + template + __device__ accscalar_t operator()(accscalar_t x) const { + if (x < 0) { + x = -x; + } + if (x < 1) { + return 1 - x; + } + return 0; + } + + static const int size = 2; +}; + +// taken from +// 
https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/ +// src/libImaging/Resample.c#L46-L62 +struct BicubicFilterFunctor { + + template + __device__ accscalar_t operator()(accscalar_t x) const { + // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + const accscalar_t a = -0.5; + if (x < 0) { + x = -x; + } + if (x < 1) { + return ((a + 2) * x - (a + 3)) * x * x + 1; + } + if (x < 2) { + return (((x - 5) * x + 8) * x - 4) * a; + } + return 0; + } + + static const int size = 4; +}; + +template +__device__ __forceinline__ void _compute_weights_span( + const int i, + const int input_size, + const accscalar_t scale, + const accscalar_t support, + int& xmin, + int& xsize, + accscalar_t& center) { + center = scale * (i + static_cast(0.5)); + xmin = max(static_cast(center - support + static_cast(0.5)), static_cast(0)); + xsize = min(static_cast(center + support + static_cast(0.5)), input_size) - xmin; +} + +template +__device__ __forceinline__ void _compute_weights( + scalar_t* wt_ptr, + const accscalar_t scale, + int interp_size, + const interp_filter_t& interp_filter, + accscalar_t xmin_m_center, + int xsize) { + + accscalar_t invscale = (scale >= 1.0) ? 
1.0 / scale : 1.0; + accscalar_t total_w = 0.0; + int j = 0; + for (j = 0; j < xsize; j++) { + accscalar_t w = interp_filter((j + xmin_m_center + static_cast(0.5)) * invscale); + wt_ptr[j] = static_cast(w); + total_w += w; + } + for (j = 0; j < xsize; j++) { + if (total_w != 0.0) { + wt_ptr[j] /= total_w; + } + } + for (; j < interp_size; j++) { + wt_ptr[j] = static_cast(0.0); + } +} + +template +__device__ __forceinline__ accscalar_t interpolate_aa_single_dim( + const scalar_t* src, + const scalar_t* weights, + int size) { + scalar_t t = static_cast(*src); + scalar_t wts = static_cast(weights[0]); + accscalar_t output = t * wts; + + int j = 1; + for (; j < size; j++) { + wts = static_cast(weights[j]); + t = static_cast(*(src + j)); + output += t * wts; + } + return output; +} + +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2a272d22c0c60edf450363f762c0cf47523bd1ed --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh @@ -0,0 +1,139 @@ +#pragma once + +#include + +#include +#include + +namespace at::native::cuda_utils { + +constexpr int kCUDABlockReduceNumThreads = 512; +// Algorithmic limitation: BlockReduce does two WarpReduce calls, each +// of which reduces C10_WARP_SIZE elements. So, at most +// C10_WARP_SIZE**2 elements can be reduced at a time. +// NOTE: This is >= the max block size on current hardware anyway (1024). +constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; + +// Sums `val` across all threads in a warp. 
+// +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +template +__inline__ __device__ T WarpReduceSum(T val) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN(val, offset); + } + return val; +} + +// Picks the maximum `val` across all threads in a warp. +// +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +template +__inline__ __device__ T WarpReduceMax(T val) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset)); + } + return val; +} + +struct Block1D { + static __forceinline__ __device__ int Tid() { return threadIdx.x; } + + static __forceinline__ __device__ int Warps() { + return blockDim.x / C10_WARP_SIZE; + } +}; + +struct Block2D { + static __forceinline__ __device__ int Tid() { + return threadIdx.x + threadIdx.y * blockDim.x; + } + + static __forceinline__ __device__ int Warps() { + return blockDim.x * blockDim.y / C10_WARP_SIZE; + } +}; + +// Sums `val` across all threads in a block. +// +// Warning: the return value is only valid for thread 0. +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +// - `shared` should be a pointer to shared memory with size of, at least, +// `sizeof(T) * number_of_warps` +template +__inline__ __device__ T BlockReduceSum(T val, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduceSum(val); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? shared[lid] : T(0); + if (wid == 0) { + val = WarpReduceSum(val); + } + return val; +} + +// Picks out the maximum `val` across all threads in a block. +// +// Warning: the return value is only valid for thread 0. 
+// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +// - `shared` should be a pointer to shared memory with size of, at least, +// `sizeof(T) * number_of_warps` +template +__inline__ __device__ T BlockReduceMax(T val, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduceMax(val); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits::lowest()); + if (wid == 0) { + val = WarpReduceMax(val); + } + return val; +} + +template +__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val = op.combine(val, op.warp_shfl_down(val, offset)); + } + return val; +} + +template +__inline__ __device__ T +BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduce(val, op); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? 
shared[lid] : identity_element; + if (wid == 0) { + val = WarpReduce(val, op); + } + return val; +} + +} // namespace at::native::cuda_utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/cutlass_common.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/cutlass_common.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0bf4da7a7be86c1697c86a2792fa3770c4caabf4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/cutlass_common.cuh @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +namespace at::cuda::detail { + +template +struct enable_2x_kernel_for_sm89 : Kernel { + template + CUTLASS_DEVICE static void invoke(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 890 + Kernel::invoke(std::forward(args)...); +#endif + } +}; + +template +struct enable_3x_kernel_for_sm9x : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_3x_kernel_for_sm10_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... 
args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..43ce2999862787e6d0bb78c307078cfd62a35a61 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh @@ -0,0 +1,38 @@ +#pragma once +#include + +namespace at::native { + +void _fused_adam_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..676569a762cb5c602296f60e0b4205f96687cd00 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh @@ -0,0 +1,36 @@ +#pragma once +#include + +namespace at::native { + +void 
_fused_adam_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f6949e69f2ca80a39542201fc47a27b4a8fe9a99 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh @@ -0,0 +1,200 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace at::native { + +enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; + +namespace { + +constexpr uint8_t kParamIdx = 0; +constexpr uint8_t kGradIdx = 1; +constexpr uint8_t kExpAvgIdx = 2; +constexpr uint8_t kExpAvgSqIdx = 3; +constexpr uint8_t kMaxExpAvgSqIdx = 4; + +template < + typename scalar_type, + typename opmath_t, + int depth, + ADAM_MODE adam_mode, + bool amsgrad> +C10_DEVICE inline void adam_math( + scalar_type r_args[depth][kILP], + const double& lr, + const double& beta1, + const double& beta2, + const double& weight_decay, + const double& eps, + const bool& maximize, + const float* grad_scale_ptr, + const float* found_inf_ptr, + const opmath_t& bias_correction1, + const opmath_t& 
bias_correction2_sqrt) { + static_assert(depth == 4 || depth == 5); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + // Load values. + opmath_t param = static_cast(r_args[kParamIdx][ii]); + opmath_t grad = static_cast(r_args[kGradIdx][ii]); + if (grad_scale_ptr) { + grad /= (static_cast(*grad_scale_ptr)); + } + const opmath_t grad_to_store = grad; + if (maximize) { + grad = -grad; + } + opmath_t exp_avg = static_cast(r_args[kExpAvgIdx][ii]); + opmath_t exp_avg_sq = static_cast(r_args[kExpAvgSqIdx][ii]); + opmath_t max_exp_avg_sq; + if (amsgrad) { + max_exp_avg_sq = static_cast(r_args[kMaxExpAvgSqIdx][ii]); + } + // Update param, grad, 1st and 2nd order momentum. + if (weight_decay != 0) { + if constexpr (adam_mode == ADAM_MODE::ORIGINAL) { + grad += param * weight_decay; + } else if constexpr (adam_mode == ADAM_MODE::ADAMW) { + param -= lr * weight_decay * param; + } + } + // todo(crcrpar): use lerp + // ref: https://developer.nvidia.com/blog/lerp-faster-cuda/ + exp_avg = beta1 * exp_avg + (1 - beta1) * grad; + exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad; + const opmath_t step_size = lr / bias_correction1; + opmath_t denom; + if (amsgrad) { + max_exp_avg_sq = std::max(max_exp_avg_sq, exp_avg_sq); + denom = (std::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps; + } else { + denom = (std::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps; + } + param -= step_size * exp_avg / denom; + + // Store results. + r_args[kParamIdx][ii] = param; + if (grad_scale_ptr) { + r_args[kGradIdx][ii] = grad_to_store; + } + r_args[kExpAvgIdx][ii] = exp_avg; + r_args[kExpAvgSqIdx][ii] = exp_avg_sq; + if (amsgrad) { + r_args[kMaxExpAvgSqIdx][ii] = max_exp_avg_sq; + } + } +} + +// [note: Conditional Gradient Store when `optimizer.step` is called by +// GradScaler] When a user is training their model(s) with an FP16 AMP recipe, +// parameter updates are done via `grad_scaler.step(optimizer)` instead of +// `optimizer.step()`. 
For most optimizers, GradScaler unscales gradients on +// behalf of those optimizers. Also, before `.step`, it makes sure that all the +// gradients involved are finite, which incurs a device sync. On the other hand, +// fused optimizers set their member variable of `_step_supports_amp_scaling` to +// `True` in order to remove the device sync above. This means that fused +// optimizers have to have their CUDA kernels (a) unscale gradients and (b) skip +// parameter updates accordingly. To be functionally on par with `torch.optim` +// optimizers and `_multi_tensor` ones, the kernel below writes out gradients +// only when `grad_scale_ptr != nullptr. +template +struct FusedAdamMathFunctor { + static_assert( + depth == 4 || depth == 5, + "depth of 4 for Adam, depth of 5 for Adam with AMSGrad."); + using opmath_t = at::opmath_type; + C10_DEVICE __forceinline__ void operator()( + int chunk_size, + FusedOptimizerTensorListMetadata& tl, + const float* lr_ptr, + const double& lr, + const double& beta1, + const double& beta2, + const double& weight_decay, + const double& eps, + const bool& maximize, + const float* grad_scale_ptr, + const float* found_inf_ptr) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + const double lr_double = lr_ptr ? 
*lr_ptr : lr; + + if (found_inf_ptr && *found_inf_ptr == 1) { + return; + } + const auto [bias_correction1, bias_correction2_sqrt] = + [&]() -> std::pair { + auto* step_count = + reinterpret_cast(tl.state_steps_addresses[tensor_loc]); + const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count); + const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count); + const auto bias_correction2_sqrt = std::sqrt(bias_correction2); + return {bias_correction1, bias_correction2_sqrt}; + }(); + + scalar_type* args[depth]; + scalar_type r_args[depth][kILP]; + const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size; + + const bool all_aligned{ + init_args(args, tl, chunk_idx, chunk_size, tensor_loc)}; + if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { +#pragma unroll + for (int i = 0; i < depth; i++) { + load_store(r_args[i], args[i], 0, i_start); + } + adam_math( + r_args, + lr_double, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr, + bias_correction1, + bias_correction2_sqrt); +#pragma unroll + for (int i = 0; i < depth; i++) { + if (i != kGradIdx || grad_scale_ptr) { + load_store(args[i], r_args[i], i_start, 0); + } + } + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); + adam_math( + r_args, + lr_double, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr, + bias_correction1, + bias_correction2_sqrt); +#pragma unroll + for (int i = 0; i < depth; i++) { + if (i != kGradIdx || grad_scale_ptr) { + store_args(args[i], r_args[i], i_start, chunk_size, n); + } + } + } + } + } +}; +} // namespace + +} // namespace at::native diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..18d9baa6200ff10d2b9df491269f5e779fe93969 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh @@ -0,0 +1,38 @@ +#pragma once +#include + +namespace at::native { + +void _fused_adamw_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cae11356dd3c1df9c4924a20f9a984981bbcff2e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh @@ -0,0 +1,36 @@ +#pragma once +#include + +namespace at::native { + +void _fused_adamw_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + 
const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/im2col.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/im2col.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6b694cee3fbdf1a3bf9befb258e2cca8664f5eb3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/im2col.cuh @@ -0,0 +1,336 @@ +#pragma once + +#include +#include +#include + +#include + +namespace at::native { + +using namespace at::cuda::detail; + +// Kernel for fast unfold+copy +// (borrowed from Caffe: +// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu) +// CUDA_NUM_THREADS = 1024 + +template +C10_LAUNCH_BOUNDS_1(1024) +__global__ void im2col_kernel( + const int64_t n, + const dt* data_im, + const int64_t height, + const int64_t width, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_col) { + CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) { + int64_t w_out = index % width_col; + + int64_t idx = index / width_col; + + int64_t h_out = idx % height_col; + int64_t channel_in = idx / height_col; + int64_t channel_out = channel_in * 
kernel_height * kernel_width; + int64_t h_in = h_out * stride_height - pad_height; + int64_t w_in = w_out * stride_width - pad_width; + + dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out; + const dt* im = data_im + (channel_in * height + h_in) * width + w_in; + + for (int64_t i = 0; i < kernel_height; ++i) { + for (int64_t j = 0; j < kernel_width; ++j) { + int64_t h = h_in + i * dilation_height; + int64_t w = w_in + j * dilation_width; + *col = (h >= 0 && w >= 0 && h < height && w < width) + ? im[i * dilation_height * width + j * dilation_width] + : static_cast
(0); + col += height_col * width_col; + } + } + } +} + +template +void im2col( + cudaStream_t stream, + const dt* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_col) { + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + int64_t num_kernels = channels * height_col * width_col; + // Launch CUDA_NUM_THREADS = 1024 + im2col_kernel<<>>( + num_kernels, + data_im, + height, + width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_col); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__forceinline__ __device__ void col2im_device( + const int64_t index, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + accT val = static_cast(0); + const int64_t w_im = index % width + pad_width; + const int64_t h_im = (index / width) % height + pad_height; + const int64_t c_im = index / (width * height); + int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1; + int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1; + // compute the start and end of the output + const int64_t w_col_start = (w_im < kernel_extent_w) + ? 
0 + : (w_im - kernel_extent_w) / stride_width + 1; + const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col); + const int64_t h_col_start = (h_im < kernel_extent_h) + ? 0 + : (h_im - kernel_extent_h) / stride_height + 1; + const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col); + + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) { + int64_t h_k = (h_im - h_col * stride_height); + int64_t w_k = (w_im - w_col * stride_width); + if (h_k % dilation_height == 0 && w_k % dilation_width == 0) { + h_k /= dilation_height; + w_k /= dilation_width; + int64_t data_col_index = + (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col + + h_col) * + width_col + + w_col; + val += data_col[data_col_index]; + } + } + } + data_im[index] = static_cast
(val); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_kernel( + const int64_t n, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + CUDA_KERNEL_LOOP(index, n) { + col2im_device( + index, + data_col, + height, + width, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + } +} + +template +void col2im( + cudaStream_t stream, + const dt* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im) { + int64_t num_kernels = channels * height * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. 
+ // CUDA_NUM_THREADS = 1024 + col2im_kernel + <<>>( + num_kernels, + data_col, + height, + width, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_batched_kernel( + const int64_t n, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im, + const int64_t im_batch_stride) { + using accT = at::acc_type; + const auto im_numel = n * nbatch; + + CUDA_KERNEL_LOOP_TYPE(index, im_numel, int64_t) { + const auto ibatch = index / n; + const auto slice_index = index % n; + + col2im_device( + slice_index, + data_col + ibatch * col_batch_stride, + height, + width, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im + ibatch * im_batch_stride); + } +} + +template +void col2im_batched( + cudaStream_t stream, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im, + const int64_t im_batch_stride) { + const int64_t num_kernels = channels * height * width; + const int64_t output_numel = nbatch * 
num_kernels; + if (output_numel == 0) { + return; // No work to do + } + + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_batched_kernel<<>>( + num_kernels, + data_col, + col_batch_stride, + nbatch, + height, + width, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im, + im_batch_stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/jit_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/jit_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7961af63cced26b93a1fcadba1a6ade25f1adb0c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/jit_utils.h @@ -0,0 +1,249 @@ +#pragma once + +#include + +#include +#include +#include + +namespace at::cuda::jit { + +enum class BinaryFuncVariant {NoScalar, RhsScalar, LhsScalar}; + +struct NvrtcFunction { + CUmodule module = CUmodule(); + CUfunction function = nullptr; +}; + +struct KernelDescriptor { + std::string name; + std::string f; + c10::ScalarType f_inputs_type; + c10::ScalarType result_type; + c10::SmallVector extra_args_types; + int nInputs, nOutputs; +}; + +// Helper function to return a vector +// corresponding to the type of the arguments in parameter pack. +template +c10::SmallVector get_extra_args_types() { + return {c10::CppTypeToScalarType::value ...}; +} + +template < + typename result_type, + typename f_inputs_type, + typename... 
ExtraArgs> +KernelDescriptor make_kernel_descriptor( + std::string name, + std::string f, + int nInputs, + int nOutputs) { + KernelDescriptor ret; + ret.name = std::move(name); + ret.f = std::move(f); + ret.f_inputs_type = c10::CppTypeToScalarType::value; + ret.result_type = c10::CppTypeToScalarType::value; + ret.extra_args_types = get_extra_args_types(); + ret.nInputs = nInputs; + ret.nOutputs = nOutputs; + return ret; +} + +inline int can_vectorize_up_to(size_t default_alignment, void *pointer) { + auto ip = reinterpret_cast(pointer); +#ifdef USE_ROCM + if ((default_alignment == 1) && (ip % (16 * default_alignment) == 0)) { + return 16; + } + if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) { + return 8; + } +#else + if (ip % (8 * default_alignment) == 0) { + return 8; + } +#endif + if (ip % (4 * default_alignment) == 0) { + return 4; + } + if (ip % (2 * default_alignment) == 0) { + return 2; + } + return 1; +} + +inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef pointers) { + TORCH_INTERNAL_ASSERT(desc.nOutputs == 1); + TORCH_INTERNAL_ASSERT(static_cast(pointers.size()) == 1 + desc.nInputs); + + // Deals with output + auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize(); + auto result = can_vectorize_up_to(result_size, pointers[0]); + + // Incorporates input(s) + auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); + for (auto i : c10::irange(1, pointers.size())) { + result = std::min(result, can_vectorize_up_to(input_size, pointers[i])); + } + + return result; +} + +//FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh +#ifdef USE_ROCM +#define JIT_THREAD_WORK_SIZE 4 +#else +#define JIT_THREAD_WORK_SIZE 8 +#endif + +int calc_io_size( + const int nInputs, + const int nOutputs, + const c10::ScalarType& inputs_type, + const c10::ScalarType& result_type); + +int 
calc_thread_work_size( + const int nInputs, + const int nOutputs, + const c10::ScalarType& inputs_type, + const c10::ScalarType& result_type); + +std::string generate_code( + int nInputs, + int nOutputs, + const std::string& func, + const std::string& name, + const std::string& f_inputs_type, + const std::string& compute_type, + const std::string& result_type, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + c10::SmallVector& extra_args_typenames, + int thread_work_size=JIT_THREAD_WORK_SIZE, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string generate_code( + const KernelDescriptor &desc, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + int thread_work_size=JIT_THREAD_WORK_SIZE, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string generate_reduction_code( + int nOutputs, + const std::string& func, + const std::string& name, + const int vt0, + const std::string& f_inputs_type, + const std::string& reduction_accum_type, + const std::string& result_type, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +std::string generate_reduction_code( + const KernelDescriptor &desc, + const int vt0, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +NvrtcFunction jit_pwise_function( + const std::string& code, + const std::string& kernel_name); + +void launch_jitted_pwise_function( + NvrtcFunction function, + const void* args[], + const dim3 nBlocks, + const dim3 kBlockSize, + const int smem=0); + +template +struct delayed_false : std::false_type { +}; + +// Defines type names +// NOTE: General case is instantiated only for invalid types. +// All the valid types have specialization using the TYPE_NAME_FN +// macro below. 
+template +inline std::string typeName() { + // we can't use static_assert(false) directly as the + // program will be not compiled even if the template is not + // instantiated, so we use `delayed_false` + // to make sure compiler doesn't eagerly raise + // fail this assertion. + static_assert(delayed_false::value, "invalid type for jiterator"); + return "void"; +} + +#define TYPE_NAME_FN(ctype, name) \ +template <> inline std::string typeName(){ \ + return std::string(#ctype); \ +} + +AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN) +#undef TYPE_NAME_FN +// JIT uses std::complex directly, because nvRTC compile programs +// with -default-device, so there is no such issue like: +// "std::sin(complex) is __host__ only" +template <> inline std::string typeName(){ + return "bool"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName(){ + return "at::Half"; +} +template <> inline std::string typeName(){ + return "at::BFloat16"; +} +template <> inline std::string typeName(){ + return "at::Float8_e5m2"; +} +template <> inline std::string typeName(){ + return "at::Float8_e4m3fn"; +} +template <> inline std::string typeName() { + return "at::Float8_e5m2fnuz"; +} +template <> inline std::string typeName() { + return "at::Float8_e4m3fnuz"; +} +template <> inline std::string typeName() { + // TODO(#146647): Can the code here be made generic for any scalartype? 
+ return "at::Float8_e8m0fnu"; +} + +#define TYPE_NAME_CASE(ctype, scalartype) \ + case ScalarType::scalartype: return typeName(); +inline std::string typeName(ScalarType t) { + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE) + default: + TORCH_CHECK(false, "invalid type for jiterator"); + } +} +#undef TYPE_NAME_CASE + +TORCH_CUDA_CPP_API void initializeCudaContext(); + +} // namespace at::cuda::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh new file mode 100644 index 0000000000000000000000000000000000000000..98c4639682477f497f30e64b4b184e6211a20c1c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh @@ -0,0 +1,680 @@ +namespace at::cuda { +//windows doesn't like large string literals, so split in two +const std::string reduction_template_0 = R"ESCAPE( + #define C10_HOST_DEVICE __host__ __device__ + #define C10_DEVICE __device__ + #if defined(__clang__) && defined(__HIP__) + #ifndef __forceinline__ + #define __forceinline__ inline __attribute__((always_inline)) + #endif + // until ROCm support for kernel asserts is restored + #define assert(expr) (static_cast(0)) + #endif + + template + __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + #if defined(__clang__) && defined(__HIP__) + return __shfl_down(value, delta, width); + #else + return __shfl_down_sync(mask, value, delta, width); + #endif + } + + + #if ${complex} + template + __device__ __forceinline__ std::complex WARP_SHFL_DOWN(std::complex value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + return std::complex( + #if defined(__clang__) && defined(__HIP__) + __shfl_down(value.real(), delta, width), + __shfl_down(value.imag(), delta, width)); + #else + __shfl_down_sync(mask, 
value.real(), delta, width), + __shfl_down_sync(mask, value.imag(), delta, width)); + #endif + } + #endif + + // aligned vector generates vectorized load/store on CUDA + template + struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; + }; + + + C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) { + // get GCD of num and denom using Euclid's algorithm. + // Can replace this with std::gcd if we ever support c++17. + size_t a = denominator; + size_t b = numerator; + while (b != 0) { + a %= b; + // swap(a,b) + size_t tmp = a; + a = b; + b = tmp; + } + + // a is now the GCD + numerator /= a; + denominator /= a; + } + + + + + struct ReduceConfig { + //has to match host-side ReduceConfig in the eager code + static constexpr int BLOCK_X = 0; + static constexpr int BLOCK_Y = 1; + static constexpr int CTA = 2; + + static constexpr int input_vec_size = 4; + int element_size_bytes; + int num_inputs; + int num_outputs; + int step_input = 1; + int step_output = 1; + int ctas_per_output = 1; + int input_mult[3] = {0, 0, 0}; + int output_mult[2] = {0, 0}; + + int block_width; + int block_height; + int num_threads; + + bool vectorize_input = false; + int output_vec_size = 1; + + C10_HOST_DEVICE bool should_block_x_reduce() const { + return input_mult[BLOCK_X] != 0; + } + + C10_HOST_DEVICE bool should_block_y_reduce() const { + return input_mult[BLOCK_Y] != 0; + } + + C10_HOST_DEVICE bool should_global_reduce() const { + return input_mult[CTA] != 0; + } + + C10_DEVICE bool should_store(int output_idx) const { + return output_idx < num_outputs && + (!should_block_x_reduce() || threadIdx.x == 0) && + (!should_block_y_reduce() || threadIdx.y == 0); + } + + C10_DEVICE bool should_reduce_tail() const { + return (!should_block_y_reduce() || threadIdx.y == 0) && + (!should_global_reduce() || blockIdx.y == 0); + } + + C10_HOST_DEVICE int input_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta2 = 
blockIdx.y; + return (lane * input_mult[BLOCK_X] + + warp * input_mult[BLOCK_Y] + + cta2 * input_mult[CTA]); + } + + template + C10_HOST_DEVICE int output_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta1 = blockIdx.x; + return (lane * output_mult[BLOCK_X] + + warp * output_mult[BLOCK_Y] + + cta1 * step_output) * output_vec_size; + } + + C10_DEVICE int shared_memory_offset(int offset) const { + return threadIdx.x + (threadIdx.y + offset) * blockDim.x; + } + + C10_DEVICE int staging_memory_offset(int cta2) const { + int offset = cta2 + blockIdx.x * gridDim.y; + if (!should_block_x_reduce()) { + offset = threadIdx.x + offset * blockDim.x; + } + return offset; + } + + + }; + + +//TODO this will need to be different for more generic reduction functions +namespace reducer { + + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + + inline __device__ ${functor} + + inline __device__ out_scalar_t project(arg_t arg) { + return (out_scalar_t) arg; + } + + inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) { + return WARP_SHFL_DOWN(arg, offset); + } + + inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) { + return acc; + } + + // wrap a normal reduction that ignores the index + inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) { + return combine(acc, val); + } +} + + +struct ReduceJitOp { + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + using InputCalculator = OffsetCalculator<1>; + using OutputCalculator = OffsetCalculator<2>; + +// static constexpr bool can_accumulate_in_output = +// std::is_convertible_v +// && std::is_convertible_v; + + static constexpr int input_vec_size = ReduceConfig::input_vec_size; + + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two 
destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + + C10_DEVICE void run() const { + extern __shared__ char shared_memory[]; + uint32_t output_idx = config.output_idx<${output_vec_size}>(); + uint32_t input_idx = config.input_idx(); + auto base_offsets1 = output_calc.get(output_idx)[1]; + + using arg_vec_t = Array; + arg_vec_t value; + + if (output_idx < config.num_outputs && input_idx < config.num_inputs) { + const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1); + + value = thread_reduce<${output_vec_size}>(input_slice); + } + + if (config.should_block_y_reduce()) { + value = block_y_reduce<${output_vec_size}>(value, shared_memory); + } + if (config.should_block_x_reduce()) { + value = block_x_reduce<${output_vec_size}>(value, shared_memory); + } + + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + arg_vec_t* acc = nullptr; + if (acc_buf != nullptr) { + size_t numerator = sizeof(arg_t); + size_t denominator = sizeof(out_scalar_t); + reduce_fraction(numerator, denominator); + acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator)); + } + + if (config.should_global_reduce()) { + value = global_reduce<${output_vec_size}>(value, acc, shared_memory); + } else if (config.should_store(output_idx)) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { 
+ value = accumulate_in_output<${output_vec_size}>(out, value); + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + template + C10_DEVICE Array thread_reduce(const scalar_t* data) const { + if (config.vectorize_input) { + assert(output_vec_size == 1); + // reduce at the header of input_slice where memory is not aligned, + // so that thread_reduce will have an aligned memory to work on. + return {input_vectorized_thread_reduce_impl(data)}; + } else { + uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t); + bool is_contiguous = (input_calc.dims == 1 && element_stride == 1); + if (is_contiguous) { + return thread_reduce_impl(data, [](uint32_t idx) { return idx; }); + } else if (input_calc.dims == 1) { + return thread_reduce_impl(data, [&](uint32_t idx) { return idx * element_stride; }); + } else { + return thread_reduce_impl(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); }); + } + } + } + + C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const { + uint32_t end = config.num_inputs; + + // Handle the head of input slice where data is not aligned + arg_t value = ident; + constexpr int align_bytes = alignof(aligned_vector); + constexpr int align_elements = align_bytes / sizeof(scalar_t); + int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t); + if (shift > 0) { + data -= shift; + end += shift; + if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){ + value = reducer::reduce(value, 
data[threadIdx.x], threadIdx.x - shift); + } + end -= align_elements; + data += align_elements; + shift = align_elements - shift; + } + + // Do the vectorized reduction + using load_t = aligned_vector; + + uint32_t idx = config.input_idx(); + const uint32_t stride = config.step_input; + + // Multiple accumulators to remove dependency between unrolled loops. + arg_t value_list[input_vec_size]; + value_list[0] = value; + + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[i] = ident; + } + + scalar_t values[input_vec_size]; + + load_t *values_vector = reinterpret_cast(&values[0]); + + while (idx * input_vec_size + input_vec_size - 1 < end) { + *values_vector = reinterpret_cast(data)[idx]; + #pragma unroll + for (uint32_t i = 0; i < input_vec_size; i++) { + value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i); + } + idx += stride; + } + + // tail + uint32_t tail_start = end - end % input_vec_size; + if (config.should_reduce_tail()) { + int idx = tail_start + threadIdx.x; + if (idx < end) { + value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift); + } + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[0] = reducer::combine(value_list[0], value_list[i]); + } + return value_list[0]; + } + + template + C10_DEVICE Array thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const { + uint32_t idx = config.input_idx(); + const uint32_t end = config.num_inputs; + const uint32_t stride = config.step_input; + const int vt0=${vt0}; + + using arg_vec_t = Array; + using load_t = aligned_vector; + const load_t* data = reinterpret_cast(data_); + + // Multiple accumulators to remove dependency between unrolled loops. 
+ arg_vec_t value_list[vt0]; + + #pragma unroll + for (int i = 0; i < vt0; i++) { + #pragma unroll + for (int j = 0; j < output_vec_size; j++) { + value_list[i][j] = ident; + } + } + + load_t values[vt0]; + + while (idx + (vt0 - 1) * stride < end) { + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + values[i] = data[calc(idx + i * stride) / output_vec_size]; + } + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride); + } + } + idx += stride * vt0; + } + + // tail + int idx_ = idx; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + values[i] = data[calc(idx) / output_vec_size]; + idx += stride; + } + idx = idx_; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx); + } + idx += stride; + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]); + } + } + return value_list[0]; + } + template + C10_DEVICE Array block_x_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + int dim_x = blockDim.x; + args_vec_t* shared = (args_vec_t*)shared_memory; + if (dim_x > warpSize) { + int address_base = threadIdx.x + threadIdx.y*blockDim.x; + shared[address_base] = value; + for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) { + __syncthreads(); + if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) { + args_vec_t other = shared[address_base + offset]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[address_base] = 
value; + } + } + dim_x = warpSize; + } + + __syncthreads(); + + for (int offset = 1; offset < dim_x; offset <<= 1) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + arg_t other = reducer::warp_shfl_down(value[i], offset); + value[i] = reducer::combine(value[i], other); + } + } + return value; + } + + template + C10_DEVICE Array block_y_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + args_vec_t* shared = (args_vec_t*)shared_memory; + shared[config.shared_memory_offset(0)] = value; + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + args_vec_t other = shared[config.shared_memory_offset(offset)]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[config.shared_memory_offset(0)] = value; + } + } + return value; + } + )ESCAPE"; + + const std::string reduction_template_1 = R"ESCAPE( + + C10_DEVICE bool mark_block_finished() const { + __shared__ bool is_last_block_done_shared; + + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0) { + int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1); + } + + __syncthreads(); + + return is_last_block_done_shared; + } + + template + C10_DEVICE Array accumulate_in_output( + Array out, + Array value + ) const { + Array ret; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + ret[i] = reducer::combine(*(out[i]), value[i]); + } + return ret; + } + + + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value + ) const { + assert(!final_output); + return (out_scalar_t)value; + } + + template + C10_DEVICE void set_results(const T x, const uint32_t base_offset) const { + assert(noutputs == 1); + auto res = (out_scalar_t*)((char*)dst[0] + base_offset); + *res = x; + } + +//TODO - multi-output 
reduction - we won't be able to use thrust::pair +//just explicitly specify typed output reads/writes +//Currently implemented for max of two outputs +// template +// C10_DEVICE void set_results(const thrust::pair x, const index_t base_offset) const { +// if (noutputs >= 1) { +// auto res0 = (T1*)((char*)dst[0] + base_offset); +// *res0 = x.first; +// } +// if (noutputs >= 2) { +// // base offset is computed assuming element size being sizeof(T1), so we need to make a +// // correction to obtain the correct base offset +// auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2)); +// *res1 = x.second; +// } +// } + + template + C10_DEVICE void set_results_to_output(Array value, Array base_offset) const { + assert(final_output); + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + set_results(reducer::project(value[i]), base_offset[i]); + } + } + + template + C10_DEVICE Array global_reduce(Array value, Array *acc, char* shared_memory) const { + using arg_vec_t = Array; + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + + arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf; + uint32_t output_idx = config.output_idx(); + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + bool should_store = config.should_store(output_idx); + if (should_store) { + uint32_t offset = config.staging_memory_offset(blockIdx.y); + reduce_buffer[offset] = value; + } + + __threadfence(); // make sure writes are globally visible + __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done + bool is_last_block_done = mark_block_finished(); + + if (is_last_block_done) { + __threadfence(); //complete acquire pattern + value = ident; + if (config.should_block_x_reduce()) { + uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; + 
uint32_t step = blockDim.x * blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } else { + uint32_t input_offset = threadIdx.y; + uint32_t step = blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } + value = block_y_reduce(value, shared_memory); + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + if (should_store) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + return value; + } +}; + +extern "C" +__launch_bounds__(${max_threads_lb}, 4) +__global__ void reduction_${name}_kernel(ReduceJitOp r){ + r.run(); +} +)ESCAPE"; + +const std::string reduction_template = reduction_template_0 + reduction_template_1; + + +const std::string &get_reduction_template() { + return reduction_template; +} + +} // namespace at::cuda diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/thread_constants.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/thread_constants.h new file mode 100644 index 0000000000000000000000000000000000000000..bcc797a26e1ce53032922d5654a4b771209ceca7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/thread_constants.h @@ -0,0 +1,25 @@ +#pragma once +#include + +// Marks a lambda as executable on both the host and device. The __host__ +// attribute is important so that we can access static type information from +// the host, even if the function is typically only executed on the device. +#ifndef GPU_LAMBDA +#define GPU_LAMBDA __host__ __device__ +#endif + +#if defined(USE_ROCM) +constexpr int num_threads() { + return 256; +} + +constexpr int thread_work_size() { return 4; } +#else +constexpr uint32_t num_threads() { + return C10_WARP_SIZE * 4; +} + +constexpr int thread_work_size() { return 8; } +#endif + +constexpr int block_work_size() { return thread_work_size() * num_threads(); } diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/vol2col.cuh b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/vol2col.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e69463fa42224c31abd4f2446f62bb16d81e4fb0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/vol2col.cuh @@ -0,0 +1,262 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace at::native { + +using namespace at::cuda::detail; + +// Kernel for fast unfold+copy on volumes +template +C10_LAUNCH_BOUNDS_1(1024) +__global__ void vol2col_kernel( + const int64_t n, + const T* data_vol, + const int depth, + const int height, + const int width, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + 
const int dilation_t, + const int dilation_h, + const int dilation_w, + const int depth_col, + const int height_col, + const int width_col, + T* data_col) { + CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) { + auto w_out = index % width_col; + index /= width_col; + auto h_out = index % height_col; + index /= height_col; + auto t_out = index % depth_col; + auto channel_in = index / depth_col; + auto channel_out = channel_in * ksize_t * ksize_h * ksize_w; + auto t_in = t_out * stride_t - pad_t; + auto h_in = h_out * stride_h - pad_h; + auto w_in = w_out * stride_w - pad_w; + data_col += + ((channel_out * depth_col + t_out) * height_col + h_out) * width_col + + w_out; + data_vol += ((channel_in * depth + t_in) * height + h_in) * width + w_in; + for (int i = 0; i < ksize_t; ++i) { + for (int j = 0; j < ksize_h; ++j) { + for (int k = 0; k < ksize_w; ++k) { + auto t = t_in + i * dilation_t; + auto h = h_in + j * dilation_h; + auto w = w_in + k * dilation_w; + *data_col = (t >= 0 && h >= 0 && w >= 0 && t < depth && h < height && + w < width) + ? data_vol + [i * dilation_t * height * width + j * dilation_h * width + + k * dilation_w] + : static_cast(0); + data_col += depth_col * height_col * width_col; + } + } + } + } +} + +template +void vol2col( + cudaStream_t stream, + const T* data_vol, + const int channels, + const int depth, + const int height, + const int width, + const int depth_col, + const int height_col, + const int width_col, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + T* data_col) { + // We are going to launch channels * depth_col * height_col * width_col + // kernels, each kernel responsible for copying a single-channel grid. 
+ // We cast an operand to int64 so that the product will not overflow + const auto num_kernels = static_cast(channels) * depth_col * height_col * width_col; + // Launch + vol2col_kernel<<>>( + num_kernels, + data_vol, + depth, + height, + width, + ksize_t, + ksize_h, + ksize_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + depth_col, + height_col, + width_col, + data_col); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__global__ void vol2im_kernel( + const int64_t n, + const T* data_col, + const unsigned depth, + const unsigned height, + const unsigned width, + const unsigned channels, + const unsigned kernel_t, + const unsigned kernel_h, + const unsigned kernel_w, + const unsigned pad_t, + const unsigned pad_h, + const unsigned pad_w, + const unsigned stride_t, + const unsigned stride_h, + const unsigned stride_w, + const unsigned dilation_t, + const unsigned dilation_h, + const unsigned dilation_w, + const unsigned depth_col, + const unsigned height_col, + const unsigned width_col, + T* data_vol) { + CUDA_KERNEL_LOOP(index, n) { + accT val = static_cast(0); + const auto w_im = index % width + pad_w; + const auto h_im = (index / width) % height + pad_h; + const auto t_im = (index / width / height) % depth + pad_t; + const auto c_im = index / (width * height * depth); + auto kernel_extent_w = (kernel_w - 1) * dilation_w + 1; + auto kernel_extent_h = (kernel_h - 1) * dilation_h + 1; + auto kernel_extent_t = (kernel_t - 1) * dilation_t + 1; + // compute the start and end of the output + const auto w_col_start = + (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1; + const auto w_col_end = std::min(w_im / stride_w + 1, width_col); + const auto h_col_start = + (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1; + const auto h_col_end = std::min(h_im / stride_h + 1, height_col); + const auto t_col_start = + (t_im < kernel_extent_t) ? 
0 : (t_im - kernel_extent_t) / stride_t + 1; + const auto t_col_end = std::min(t_im / stride_t + 1, depth_col); + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (unsigned t_col = t_col_start; t_col < t_col_end; t_col += 1) { + for (unsigned h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (unsigned w_col = w_col_start; w_col < w_col_end; w_col += 1) { + uint64_t t_k = (t_im - t_col * stride_t); + uint64_t h_k = (h_im - h_col * stride_h); + uint64_t w_k = (w_im - w_col * stride_w); + if (t_k % dilation_t == 0 && h_k % dilation_h == 0 && + w_k % dilation_w == 0) { + t_k /= dilation_t; + h_k /= dilation_h; + w_k /= dilation_w; + const int64_t idx_k = + ((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + w_k; + const int64_t data_col_index = + ((idx_k * depth_col + t_col) * + height_col + h_col) * + width_col + w_col; + val += data_col[data_col_index]; + } + } + } + } + data_vol[index] = static_cast(val); + } +} + +template +void col2vol( + cudaStream_t stream, + const T* data_col, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t output_depth, + const int64_t output_height, + const int64_t output_width, + const int64_t patch_t, + const int64_t patch_h, + const int64_t patch_w, + const int64_t pad_t, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_t, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_t, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_vol) { + const auto num_kernels = channels * depth * height * width; + + auto check_fits_in_unsigned = + [](int64_t val, const char * name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK(val >= 0 && val <= umax, + name, " must fit in a 32-bit unsigned value"); + }; + check_fits_in_unsigned(num_kernels, "input size"); + check_fits_in_unsigned( + channels * patch_t * patch_h * patch_w, "channels x kernel size"); + + // To avoid involving 
atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + vol2im_kernel + <<>>( + num_kernels, + data_col, + depth, + height, + width, + channels, + patch_t, + patch_h, + patch_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + output_depth, + output_height, + output_width, + data_vol); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_collection.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_collection.h new file mode 100644 index 0000000000000000000000000000000000000000..b7fa3dcc72716f207137443856f7cc1f3f840823 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_collection.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +namespace at::native { +void bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v4(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_16x16x1_16x16x1_1x16x1x16_4_Intrawave_v4(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_4x64x1_4x64x1_1x16x1x16_4_Intrawave_v3(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x256x32_32x32_4x4_4x64x1_4x64x1_1x16x1x16_4_Intrawave_v5(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_224x256x64_16x16_7x8_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x224x64_16x16_8x7_8x32x1_8x32x1_1x32x1x8_4_Intrawave_v3(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v3(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void 
bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v5(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_128x128x64_32x32_2x2_8x32x1_8x32x1_1x16x1x16_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_32x16x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4_Intrawave_v1(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x16x64_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x16x64_16x16_4x1_16x16x1_16x8x1_1x32x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_256x16x64_16x16_4x1_32x8x1_32x4x1_1x32x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_128x16x64_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_64x16x64_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_32x16x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_64_16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x32x64_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void 
bgemm_kernel_bf16bf16bf16_128_16x64x64_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_128_16x128x64_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +void bgemm_kernel_bf16bf16bf16_256_16x256x64_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4_Intrawave_v2(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); + +}; // namespace at::native \ No newline at end of file diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h new file mode 100644 index 0000000000000000000000000000000000000000..7cf35e13349ffccbb3ce81e458ddafc14b381daa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h @@ -0,0 +1,164 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ + + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// Define commonly used types. 
+template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using AccDataType = F32; +using DsDataType = ck::Tuple<>; +using CDataType = BF16; +using CShuffleDataType = BF16; +using DsLayout = ck::Tuple<>; +using CLayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template < + typename A_DATA_TYPE, + typename B_DATA_TYPE, + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int AK1, + int BK1, + int WAVE_TILE_M, + int WAVE_TILE_N, + int WAVE_MAP_M, + int WAVE_MAP_N, + typename ABLOCK_TRANSFER, + int ABLOCK_TRANSFER_SSPV, + int ABLOCK_TRANSFER_DSPV_K1, + typename BBLOCK_TRANSFER, + int BBLOCK_TRANSFER_SSPV, + int BBLOCK_TRANSFER_SSPV_K1, + int CSHUFFLE_MXDL_PWPS, + int CSHUFFLE_NXDL_PWPS, + typename CSHUFFLEBLOCK_TRANSFER, + typename CDESHUFFLEBLOCK_TRANSFER, + ck::BlockGemmPipelineScheduler LOOP_SCHED, + ck::BlockGemmPipelineVersion PIPELINE_VERSION, + ck::tensor_operation::device::GemmSpecialization GEMM_SPEC = + ck::tensor_operation::device::GemmSpecialization::MNPadding, + bool TRANSA = false, + bool TRANSB = false> +void bgemm_kernel_impl(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { + + using ADataType = typename CkMathType::dtype; + using BDataType = typename CkMathType::dtype; + + using ALayout = typename CkTensorLayout::a_layout; + using BLayout = typename CkTensorLayout::b_layout; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< + ALayout, // ALayout + BLayout, // BLayout + DsLayout, // DsLayout + CLayout, // CLayout + ADataType, // ADataType + BDataType, // BDataType + DsDataType, // DsDataType + CDataType, // CDataType + AccDataType, // 
AccDataType + CShuffleDataType, // CshuffleType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CDEElementOp, // CElementwiseOperation + GEMM_SPEC, // GEMMSpecialization + BLOCK_SIZE, // BlockSize + MBLOCK, // MPerBlock + NBLOCK, // NPerBlock + KBLOCK, // KPerBlock + AK1, // AK1 + BK1, // BK1 + WAVE_TILE_M, // MPerXDL + WAVE_TILE_N, // NPerXDL + WAVE_MAP_M, // MXdlPerWave + WAVE_MAP_N, // NXdlPerWave + ABLOCK_TRANSFER, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + ABLOCK_TRANSFER_SSPV, // ABlockTransferSrcScalarPerVector + ABLOCK_TRANSFER_DSPV_K1, // ABlockTransferDstScalarPerVector_AK1 + 0, // ABlockLdsExtraM + BBLOCK_TRANSFER, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + BBLOCK_TRANSFER_SSPV, // BBlockTransferSrcScalarPerVector + BBLOCK_TRANSFER_SSPV_K1, // BBlockTransferDstScalarPerVector_BK1 + 0, // BBlockLdsAddExtraN + CSHUFFLE_MXDL_PWPS, // CShuffleMXdlPerWavePerShuffle + CSHUFFLE_NXDL_PWPS, // CShuffleNXdlPerWavePerShuffle + CSHUFFLEBLOCK_TRANSFER, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + CDESHUFFLEBLOCK_TRANSFER, // CDEShuffleBlockTransferScalarPerVectors + LOOP_SCHED, // BlockGemmPipelineScheduler + PIPELINE_VERSION // BlockGemmPipelineVersion + >{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( + b, // A and B are swapped for CK + a, + {}, + c, + n, + m, + k, + num_batches, + ldb, + lda, + {}, + ldc, + n * k, // batch_stride_a + m * k, // batch_stride_b + {}, + m * n, // batch_stride_c + a_element_op, + b_element_op, + cde_element_op + ); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + auto stream = at::cuda::getCurrentHIPStream().stream(); + invoker.Run(argument, StreamConfig{stream, false}); +} + +}; // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/ck_bgemm.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/ck_bgemm.h new file mode 100644 index 0000000000000000000000000000000000000000..1e71d3079c906214d490c4d81d126603090c1301 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/ck_bgemm.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace at::native { + +template +inline void bgemm_internal_ck(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas_bgemm_internal_ck: not implemented"); +} + +template <> +void bgemm_internal_ck(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/ck_gemm.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/ck_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..176cbabd5e01c40189b2d6115136d27a7e5865b6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/ck_gemm.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +namespace at::native { + + +template +inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented"); +} + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(float)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); + + + +} // namespace at::native diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/ck_gemm_template.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/ck_gemm_template.h new file mode 100644 index 0000000000000000000000000000000000000000..b34a8b132674abbe126b46e480a7d7ed7c489a22 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/hip/ck_gemm_template.h @@ -0,0 +1,416 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#undef __HIP_NO_HALF_CONVERSIONS__ +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// Define commonly used types. +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +namespace at::native { + +// Elementwise Operators +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(C& c, const AB& ab) const; + + template<> + __host__ __device__ constexpr void operator() + (float& c, const float& ab) const + { + c = alpha_ * ab; + }; + + template<> + __host__ __device__ constexpr void operator() + (ck::bhalf_t& c, const ck::bhalf_t& ab) const + { + c = alpha_ * ab; + }; + + template<> + __host__ __device__ constexpr void operator() + (ck::half_t& c, const ck::half_t& ab) const + { + c = alpha_ * ab; + }; + + float alpha_; + // TODO: Leaving for now, will use later + float beta_; +}; + +template < + typename Dtype, + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int AK1, + int BK1, + int MPER_XDL, + int NPER_XDL, + int 
MPER_WAVE, + int NPER_WAVE, + typename ABLOCK_CLUSTER_LENS, + typename ABLOCK_CLUSTER_ORDER, + typename ABLOCK_SRC_ORDER, + int ABLOCK_VECTOR_DIM, + int ABLOCK_SCALAR_VEC, + int ABLOCK_SCALAR_VEC_AK1, + bool ABLOCK_LDS_EXTRAM, + typename BBLOCK_CLUSTER_LENS, + typename BBLOCK_CLUSTER_ORDER, + typename BBLOCK_SRC_ORDER, + int BBLOCK_VECTOR_DIM, + int BBLOCK_SCALAR_VEC, + int BBLOCK_SCALAR_VEC_AK1, + bool BBLOCK_LDS_EXTRAN, + int CMPER_WAVE, + int CNPER_WAVE, + typename BLOCK_CLUSTER_LENS, + typename CDE_SCALAR_VEC, + bool PADDING = false, + bool TRANSA = false, + bool TRANSB = false> +void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // Get input information. + int M = m; + int N = n; + int K = k; + + int StrideA = lda; + int StrideB = ldb; + int StrideC = ldc; + + int KBatch = 1; + + float falpha = alpha; + float fbeta = beta; + + using ADataType = typename CkMathType::dtype; + using BDataType = typename CkMathType::dtype; + using CDataType = typename CkMathType::dtype; + using DDataType = typename CkMathType::dtype; + + using AccDataType = float; + using CShuffleDataType = typename CkMathType::dtype; + + using ALayout = typename CkTensorLayout::a_layout; + using BLayout = typename CkTensorLayout::b_layout; + + using DLayout = Row; + using CLayout = Row; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = AlphaBetaAdd; + + + static constexpr auto GemmDefault = + ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + static constexpr auto GemmSpec = PADDING ? 
GemmMNKPadding : GemmDefault; + + + using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmSpec, + BLOCK_SIZE, + MBLOCK, + NBLOCK, + KBLOCK, + AK1, + BK1, + MPER_XDL, + NPER_XDL, + MPER_WAVE, + NPER_WAVE, + ABLOCK_CLUSTER_LENS, + ABLOCK_CLUSTER_ORDER, + ABLOCK_SRC_ORDER, + ABLOCK_VECTOR_DIM, + ABLOCK_SCALAR_VEC, + ABLOCK_SCALAR_VEC_AK1, + ABLOCK_LDS_EXTRAM, + BBLOCK_CLUSTER_LENS, + BBLOCK_CLUSTER_ORDER, + BBLOCK_SRC_ORDER, + BBLOCK_VECTOR_DIM, + BBLOCK_SCALAR_VEC, + BBLOCK_SCALAR_VEC_AK1, + BBLOCK_LDS_EXTRAN, + CMPER_WAVE, + CNPER_WAVE, + BLOCK_CLUSTER_LENS, + CDE_SCALAR_VEC>; + + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{alpha, beta}; + + + using DDataArrayType = std::array; + DDataArrayType DDataArray; + + // We swap A and B inputs here as a temporary workaround + auto argument = gemm.MakeArgument( + reinterpret_cast(b), + reinterpret_cast(a), + DDataArray, + reinterpret_cast(c), + N, + M, + K, + StrideB, + StrideA, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + + auto stream = at::cuda::getCurrentHIPStream().stream(); + invoker.Run(argument, StreamConfig{stream, false}); +} + + +template < + typename Dtype, + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int K1, + int MPER_WMMA, + int NPER_WMMA, + int MPER_WAVE, + int NPER_WAVE, + typename ABLOCK_CLUSTER_LENS, + typename ABLOCK_CLUSTER_ORDER, + typename ABLOCK_SRC_ORDER, + int ABLOCK_VECTOR_DIM, + int ABLOCK_SCALAR_VEC, + int ABLOCK_SCALAR_VEC_K1, + bool ABLOCK_LDS_EXTRAM, + typename BBLOCK_CLUSTER_LENS, + typename BBLOCK_CLUSTER_ORDER, + typename BBLOCK_SRC_ORDER, + int BBLOCK_VECTOR_DIM, + int BBLOCK_SCALAR_VEC, + int BBLOCK_SCALAR_VEC_AK1, + bool BBLOCK_LDS_EXTRAN, + int CMPER_WAVE, + int CNPER_WAVE, + typename CBLOCK_CLUSTER_LENS, + int CNPER_BLOCK, + bool PADDING = false, + bool TRANSA = false, + bool TRANSB = false> +void gemm_impl_wmma(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // Get input information. + int M = m; + int N = n; + int K = k; + + int StrideA = lda; + int StrideB = ldb; + int StrideC = ldc; + + int KBatch = 1; + + float falpha = alpha; + float fbeta = beta; + + using ADataType = typename CkMathType::dtype; + using BDataType = typename CkMathType::dtype; + using CDataType = typename CkMathType::dtype; + using DDataType = typename CkMathType::dtype; + + using AccDataType = float; + using CShuffleDataType = typename CkMathType::dtype; + + using ALayout = typename CkTensorLayout::a_layout; + using BLayout = typename CkTensorLayout::b_layout; + + using DLayout = Row; + using CLayout = Row; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = PassThrough; + + + static constexpr auto GemmDefault = + ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + static constexpr auto GemmSpec = PADDING ? 
GemmMNKPadding : GemmDefault; + + + using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemmWmma_CShuffle; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + + using DDataArrayType = std::array; + DDataArrayType DDataArray; + + // We swap A and B inputs here as a temporary workaround + auto argument = gemm.MakeArgument( + reinterpret_cast(b), + reinterpret_cast(a), + reinterpret_cast(c), + N, + M, + K, + StrideB, + StrideA, + StrideC, + b_element_op, + a_element_op, + c_element_op); + + + if(!gemm.IsSupportedArgument(argument)) + { + printf("error shape = %ld %ld %ld TRANSA=%d TRANSB=%d \n", + n, m, k,TRANSA, TRANSB); + + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + + auto stream = at::cuda::getCurrentHIPStream().stream(); +#if 1 + invoker.Run(argument, StreamConfig{stream, false}); +#else + float ave_time = invoker.Run(argument, StreamConfig{stream, true}); + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << N <<" " < +#include +#include +#include + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace at::native { + +template +struct CkMathType { + using dtype = T; +}; + +template <> +struct CkMathType { + using dtype = ck::bhalf_t; +}; + +template <> +struct CkMathType { + using dtype = ck::half_t; +}; + +template +struct CkTensorLayout { + // default goes to row-wise for now + using a_layout = Row; + using b_layout = Row; +}; + 
+// True denotes transpose is necessary. Default is Col, so return Row +template <> +struct CkTensorLayout { + using a_layout = Col; + using b_layout = Col; +}; + +template <> +struct CkTensorLayout { + using a_layout = Row; + using b_layout = Col; +}; + +template <> +struct CkTensorLayout { + using a_layout = Col; + using b_layout = Row; +}; + +template <> +struct CkTensorLayout { + using a_layout = Row; + using b_layout = Row; +}; + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/kleidiai/kai_kernels.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/kleidiai/kai_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..9b522d7f7705ac122d38855f3cc12f270c637093 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/kleidiai/kai_kernels.h @@ -0,0 +1,42 @@ +#pragma once +#include +#include +#if AT_KLEIDIAI_ENABLED() + +namespace at::native::kleidiai { + +/** + * @brief Rearranges the quantized weight to support kleidiai inference + * @param bl Groupsize for quantization should be multiple of 32 + */ +void kai_pack_int4_rhs( + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k, + const int64_t bl); + +/** + * @brief Outputs the buffer size for the packed weights + * @param bl Groupsize for quantization. 
32 for groupwise , 0 for channelwise + */ +size_t kai_pack_rhs_int4_size( + const int64_t n, + const int64_t k, + const int64_t bl); + +/** + * @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul ) + */ +void kai_quant_pack_lhs_int4_mm( + const Tensor& output, + const Tensor& input, + const Tensor& weight, + const int64_t m, + const int64_t n, + const int64_t k, + const int64_t bl); +} // namespace at::native::kleidiai +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/kleidiai/kai_pack.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/kleidiai/kai_pack.h new file mode 100644 index 0000000000000000000000000000000000000000..4ff3371ab5e2ac68a0f6a982f9d2f4f235449adb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/kleidiai/kai_pack.h @@ -0,0 +1,106 @@ +#pragma once +#include +#include +#include +#include +#if AT_KLEIDIAI_ENABLED() + +namespace at::native::kleidiai { + +template +void kai_pack_rhs_groupwise_int4( + T& kernel, + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k, + const int64_t bl, + const int64_t rhs_stride, + const int64_t scale_stride) { + const auto& ukernel = kernel.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + auto weight_packed_data = + reinterpret_cast(weight_packed.data_ptr()); + const auto weight_data = weight.data_ptr(); + auto scales_data = scales.const_data_ptr(); + + if (weight_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); + } + + if (scales_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); + } + + float* bias_ptr = bias.has_value() ? 
bias.value().data_ptr() : NULL; + auto& params = kernel.rhs_pack_params; + + kernel.kai_run_rhs_pack( + /*num_groups=*/1, + n, + k, + nr, + kr, + sr, + bl, + (const uint8_t*)(weight_data), + rhs_stride, + bias_ptr, + scales_data, + scale_stride, + weight_packed_data, + 0, + ¶ms); +} + +template +void kai_pack_rhs_channelwise_int4( + T& kernel, + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k) { + const auto& ukernel = kernel.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + auto weight_packed_data = + reinterpret_cast(weight_packed.data_ptr()); + const auto weight_data = weight.data_ptr(); + const auto scales_data = scales.data_ptr(); + + if (weight_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); + } + + if (scales_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); + } + + float* bias_ptr = bias.has_value() ? 
bias.value().data_ptr() : NULL; + auto& params = kernel.rhs_pack_params; + + kernel.kai_run_rhs_pack( + /*num_groups=*/1, + n, + k, + nr, + kr, + sr, + (const uint8_t*)(weight_data), + (const float*)(bias_ptr), + (const float*)(scales_data), + weight_packed_data, + 0, + ¶ms); +} + +} // namespace at::native::kleidiai + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/kleidiai/kai_ukernel_interface.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/kleidiai/kai_ukernel_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..8480469cdea8638ce140561d03344ef989dd1de2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/kleidiai/kai_ukernel_interface.h @@ -0,0 +1,144 @@ +#pragma once +#include +#if AT_KLEIDIAI_ENABLED() +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native::kleidiai { + +enum class kai_kernel_id { + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod = + 0, // Groupwise 4 bit GEMV + matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm = + 1, // Groupwise 4 bit GEMM + matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod = + 2, // Channelwise 4 bit GEMV + matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm = + 3 // Channelwise 4 bit GEMM +}; + +// Channelwise Kernel mapping +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { + struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; + struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params; + size_t (*kai_get_lhs_packed_size)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr); + size_t (*kai_get_rhs_packed_size)( + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr); + void (*kai_run_lhs_quant_pack)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr, + size_t m_idx_start, + const float* lhs, + size_t lhs_stride, + void* lhs_packed); + void (*kai_run_rhs_pack)( + 
size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* rhs, + const float* bias, + const float* scale, + void* rhs_packed, + size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params); + + kai_matmul_ukernel_f32_qa8dxp_qs4cxp( + const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel) + : ukernel(kernel), + kai_get_lhs_packed_size( + &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32), + kai_get_rhs_packed_size( + &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), + kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {} +}; + +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp +kai_select_channelwise_matmul_ukernel(const kai_kernel_id id); + +// Groupwise Kernel mapping +struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { + struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; + struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params rhs_pack_params; + size_t (*kai_get_lhs_packed_size)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr); + size_t (*kai_get_rhs_packed_size)( + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + size_t bl, + enum kai_datatype scale_dt); + void (*kai_run_lhs_quant_pack)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr, + size_t m_idx_start, + const float* lhs, + size_t lhs_stride, + void* lhs_packed); + void (*kai_run_rhs_pack)( + size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + size_t bl, + const uint8_t* rhs, + size_t rhs_stride, + const float* bias, + const void* scale, + size_t scale_stride, + void* rhs_packed, + size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); + + kai_matmul_ukernel_f32_qa8dxp_qs4c32p( + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel) + : ukernel(kernel), + kai_get_lhs_packed_size( + &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32), + 
kai_get_rhs_packed_size( + &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0), + kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {} +}; + +struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel( + const kai_kernel_id id); + +} // namespace at::native::kleidiai +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/Conv.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/Conv.h new file mode 100644 index 0000000000000000000000000000000000000000..31ddf4e89aa51c677e0949899d79c3de70e89a36 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/Conv.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include + +#if AT_MKLDNN_ENABLED() + +namespace at::native::xpu { +C10_API Tensor convolution_pointwise( + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::string_view attr, + torch::List> scalars, + std::optional algorithm); + +C10_API Tensor convolution_pointwise_binary( + const Tensor& input_t, + const Tensor& other_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +C10_API Tensor& convolution_pointwise_binary_( + Tensor& other_t, + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +} // namespace at::native::xpu + +#endif // 
AT_MKLDNN_ENABLED() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/FusionUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/FusionUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..b8b4e1cccf0ac67e507c69e205bcc36012e064bb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/FusionUtils.h @@ -0,0 +1,53 @@ +#pragma once +#include + +// +// This header file provides utility functions for constructing and managing +// oneDNN attributes used in fusion operations on XPU devices. These utilities +// include functions for creating unary and binary post-operations attributes, +// as well as mapping string representations of operations to oneDNN attributes. +// + +namespace at::native::xpu { +at::native::onednn::Attr& unary_attr_with_arg( + onednn::Attr& attr, + std::string_view unary, + torch::List> scalars, + std::optional algorithm); + +at::native::onednn::Attr& string_to_unary_attr( + onednn::Attr& attr, + std::string_view unary); + +at::native::onednn::Attr& construct_unary_attr( + onednn::Attr& attr, + std::string_view unary, + torch::List> scalars, + std::optional algorithm); + +template +onednn::Attr& construct_binary_attr( + onednn::Attr& attr, + std::string_view binary, + const Tensor& other) { + if (binary == "mul") { + attr.append_post_binary(attr.kind_with_binary_mul, other); + } else if (binary == "sub") { + attr.append_post_binary(attr.kind_with_binary_sub, other); + } else if (binary == "div") { + attr.append_post_binary(attr.kind_with_binary_div, other); + } else if (binary == "add") { + attr.append_post_binary(attr.kind_with_binary_add, other); + } else if (binary == "sum") { + attr.append_post_sum(1.f, 1.f, 0); + } else { + TORCH_CHECK( + binary == "none", + "Binary attr ", + binary, + "is not supported for conv/linear post binary fusion"); + } + return attr; +} + +} // namespace at::native::xpu diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Attr.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Attr.h new file mode 100644 index 0000000000000000000000000000000000000000..eb09d37c4b7559d53faa1948654016426fb7d925 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Attr.h @@ -0,0 +1,463 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at::native::onednn { +/* oneDNN quantization usage: + https://oneapi-src.github.io/oneDNN/dev_guide_attributes_quantization.html# + + src_fp32 = scale_src * (src_int8 - zero_point) + wei_fp32 = scale_wei * (wei_int8 - zero_point) + dst_fp32 = scale_dst * (dst_int8 - zero_point) + fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) + Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32; + + Considering zero-point (asymmetric): + dst_fp32 = (src_int8 - src_zp) * src_sc * wei_int8 * wei_sc + dst_sc * (dst_int8 - dst_zp) = (src_int8 - src_zp) * wei_int8 * src_sc * + wei_sc + dst_int8 = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc / dst_sc + + dst_zp + + considering bias: + fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + bias + Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) + + bias Int8 Convolution: dst_fp32 = (src_int8 * wei_int8 + bias/(scale_src * + scale_wei)) * (scale_src * scale_wei) Int8 Convolution: dst_int8 = 1 / + scale_dst * dst_fp32; +*/ + +/* + oneDNN postops usage: + Currently, oneDNN supports 5 kinds of post ops. More details can be refered +to oneDNN doc. + https://oneapi-src.github.io/oneDNN/dev_guide_attributes_post_ops.html#doxid-dev-guide-attributes-post-ops-1dev-guide-attributes-post-ops-eltwise + +0. 
without post ops + dst = Conv(src, wei) + bias; + dst_int8 = 1/q_scale * dst; q_scale is the op output quantization scale + fp32 API: Attr attr; + int8 API: Attr attr(q_scale); + +1. append eltwise post op + dst = elt_scale * Eltwise{conv_scale * [Conv(src, wei) + bias], alpha, beta} + dst_int8 = 1/q_scale * dst; + fp32 API: + Attr attr; + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm) + int8 API: + Attr attr(q_scale); + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm) + +2. append sum post op + dst = conv_scale * Conv(src, wei) + sum_scale * (dst - zp) + dst_int8 = 1/q_scale * dst; + fp32 API: + Attr attr; + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_sum(sum_scale) + int8 API: + Attr attr(q_scale); + attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear) + attr.append_post_sum(sum_scale) + +3. 
append binary post op + dst = Binary[Conv(src, wei)] + +*/ +using kind_t = dnnl::primitive::kind; +struct PostOpParam { + // eltwise post op constructor + PostOpParam( + float scale, + float alpha, + float beta, + dnnl::algorithm algo, + kind_t kind) + : scale_(scale), alpha_(alpha), beta_(beta), algo_(algo), kind_(kind) {} + // sum post op constructor + PostOpParam(float scale, kind_t kind) : scale_(scale), kind_(kind) {} + // sum post op with zp + PostOpParam(float scale, int64_t zero_point, kind_t kind) + : scale_(scale), zero_point_(zero_point), kind_(kind) {} + // binary post op constructor + PostOpParam( + at::Tensor& binary, + dnnl::memory::desc& binary_md, + dnnl::memory::desc& expected_md, + dnnl::algorithm algo, + kind_t kind) + : binary_(binary), + meta_(binary_md), + expected_meta_(expected_md), + algo_(algo), + kind_(kind) {} + // prelu post op constructor + PostOpParam(int mask, kind_t kind) : mask_(mask), kind_(kind) {} + + // post sum or binary with scale post op constructor + PostOpParam( + at::Tensor& binary, + float scale, + dnnl::algorithm algo, + kind_t kind) + : scale_(scale), binary_(binary), algo_(algo), kind_(kind) {} + + // for int8 sum/eltwise + float scale_ = 1.0; + int64_t zero_point_ = 0; + // for eltwise + float alpha_ = 0.0; + float beta_ = 0.0; + // for binary + at::Tensor binary_ = at::Tensor(); + at::Tensor expected_binary_ = at::Tensor(); + void* binary_ptr_ = nullptr; + dnnl::memory::desc meta_ = dnnl::memory::desc(); + dnnl::memory::desc expected_meta_ = dnnl::memory::desc(); + // for prelu + int mask_ = 0; + // common + dnnl::algorithm algo_ = dnnl::algorithm::eltwise_relu; + kind_t kind_ = kind_t::eltwise; +}; + +class Attr { + public: + Attr() : q_scale_(1.f) {} + Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {} + + /***** eltwise *****/ + dnnl::algorithm kind_with_relu = dnnl::algorithm::eltwise_relu; + dnnl::algorithm kind_with_sigmoid = dnnl::algorithm::eltwise_logistic; + dnnl::algorithm 
kind_with_gelu_tanh = dnnl::algorithm::eltwise_gelu_tanh; + dnnl::algorithm kind_with_gelu_erf = dnnl::algorithm::eltwise_gelu_erf; + dnnl::algorithm kind_with_mish = dnnl::algorithm::eltwise_mish; + dnnl::algorithm kind_with_linear = dnnl::algorithm::eltwise_linear; + dnnl::algorithm kind_with_swish = dnnl::algorithm::eltwise_swish; + dnnl::algorithm kind_with_sqrt = dnnl::algorithm::eltwise_sqrt; + dnnl::algorithm kind_with_tanh = dnnl::algorithm::eltwise_tanh; + dnnl::algorithm kind_with_square = dnnl::algorithm::eltwise_square; + dnnl::algorithm kind_with_abs = dnnl::algorithm::eltwise_abs; + dnnl::algorithm kind_with_exp = dnnl::algorithm::eltwise_exp; + dnnl::algorithm kind_with_log = dnnl::algorithm::eltwise_log; + dnnl::algorithm kind_with_round = dnnl::algorithm::eltwise_round; + dnnl::algorithm kind_with_hardswish = dnnl::algorithm::eltwise_hardswish; + dnnl::algorithm kind_with_soft_relu = dnnl::algorithm::eltwise_soft_relu; + dnnl::algorithm kind_with_elu = dnnl::algorithm::eltwise_elu; + dnnl::algorithm kind_with_pow = dnnl::algorithm::eltwise_pow; + dnnl::algorithm kind_with_clip = dnnl::algorithm::eltwise_clip; + // note: hardsigmoid seems oneDNN still not support + dnnl::algorithm kind_with_hardsigmoid = dnnl::algorithm::eltwise_hardsigmoid; + + /***** binary *****/ + dnnl::algorithm kind_with_binary_mul = dnnl::algorithm::binary_mul; + dnnl::algorithm kind_with_binary_add = dnnl::algorithm::binary_add; + dnnl::algorithm kind_with_binary_sub = dnnl::algorithm::binary_sub; + dnnl::algorithm kind_with_binary_div = dnnl::algorithm::binary_div; + dnnl::algorithm kind_with_binary_eq = dnnl::algorithm::binary_eq; + dnnl::algorithm kind_with_binary_ne = dnnl::algorithm::binary_ne; + dnnl::algorithm kind_with_binary_ge = dnnl::algorithm::binary_ge; + dnnl::algorithm kind_with_binary_gt = dnnl::algorithm::binary_gt; + dnnl::algorithm kind_with_binary_le = dnnl::algorithm::binary_le; + dnnl::algorithm kind_with_binary_lt = dnnl::algorithm::binary_lt; + 
dnnl::algorithm kind_with_binary_max = dnnl::algorithm::binary_max; + dnnl::algorithm kind_with_binary_min = dnnl::algorithm::binary_min; + + // append sum post op + Attr& append_post_sum( + float sum_scale, + float sum_q_scale = 1.f, + int64_t zp = 0) { + ops_params_.push_back( + PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, zp, kind_t::sum)); + return *this; + } + + // append eltwise post op + Attr& append_post_eltwise( + float scale, + float alpha, + float beta, + dnnl::algorithm algo) { + ops_params_.push_back( + PostOpParam(scale, alpha, beta, algo, kind_t::eltwise)); + return *this; + } + + // append binary post op + template + Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) { + auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary; + bool binary_is_channels_last = + (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast || + binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d); + + if constexpr (!is_matmul) { + binary_ = binary_is_channels_last ? 
binary_ : binary_.contiguous(); + } + dnnl::memory::desc md = get_onednn_md(binary_); + auto expected_md = dnnl::memory::desc( + md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::any); + if constexpr (is_matmul) { + ops_params_.push_back(PostOpParam(binary_, md, md, algo, kind_t::binary)); + } else { + ops_params_.push_back( + PostOpParam(binary_, md, expected_md, algo, kind_t::binary)); + } + + return *this; + } + + Attr& append_scale_binary( + dnnl::algorithm algo, + at::Tensor binary, + float scale, + float sum_q_scale = 1.f, + int64_t zp = 0) { + ops_params_.push_back(PostOpParam( + binary, /*scale_sum*/ scale * sum_q_scale, algo, kind_t::binary)); + return *this; + } + + // append bias with binary_add method (only used for QConv now) + Attr& append_bias(const at::Tensor& binary, const int ndimension) { + // In PyTorch, bias are in shape of [OC], + // we expand its shape according to Conv dimension + // Conv1d [OC, 1, 1], Conv2d [1, OC, 1, ,1], Conv3d [1, OC, 1, 1, 1] + at::Tensor binary_ = binary.contiguous(); + dnnl::memory::desc binary_md; + switch (ndimension) { + case 1: + binary_md = dnnl::memory::desc( + {binary.size(0), 1, 1}, + dnnl::memory::data_type::f32, + dnnl::memory::format_tag::abc); + break; + case 2: + binary_md = dnnl::memory::desc( + {1, binary.size(0), 1, 1}, + dnnl::memory::data_type::f32, + dnnl::memory::format_tag::abcd); + break; + case 3: + binary_md = dnnl::memory::desc( + {1, binary.size(0), 1, 1, 1}, + dnnl::memory::data_type::f32, + dnnl::memory::format_tag::abcde); + break; + default: + TORCH_INTERNAL_ASSERT( + 0, "XPU only supports append_bias for Conv1d, Conv2d and Conv3d."); + } + // In this case, expected_md = binary_md + ops_params_.push_back(PostOpParam( + binary_, binary_md, binary_md, kind_with_binary_add, kind_t::binary)); + return *this; + } + + // append prelu post op + Attr& append_post_prelu(int mask) { + ops_params_.push_back(PostOpParam(mask, kind_t::prelu)); + return *this; + } + + dnnl::post_ops 
extract_post_ops(const at::Tensor& dst) { + // this function is used to extract post ops params from the ops_params_ + // and put them into onednn post ops + for (size_t i = 0; i < ops_params_.size(); ++i) { + kind_t kind = ops_params_[i].kind_; + switch (kind) { + case kind_t::eltwise: { + dnnl::algorithm algo = ops_params_[i].algo_; + float alpha = ops_params_[i].alpha_; + float beta = ops_params_[i].beta_; + dnnl_post_ops_.append_eltwise(algo, alpha, beta); + break; + } + case kind_t::sum: { + float scale = ops_params_[i].scale_; + int64_t zero_point = ops_params_[i].zero_point_; + // TODO [Asymmetric]: + // Post-sum zp for gpu is not supported currently + dnnl_post_ops_.append_sum(scale, zero_point); + break; + } + case kind_t::binary: { + dnnl::algorithm algo = ops_params_[i].algo_; + auto expected_md = ops_params_[i].expected_meta_; + // In this case user may create src1 memory descriptor with + // format_tag::any or set a specific tag. However, in later case if + // tags mismatch with dst, it would result in suboptimal performance. + // So here we use format_tag::any to make sure the fast can be + // selected. + // Thus we use expected_md (with format_any) here to create pd instead + // of original md + dnnl_post_ops_.append_binary(algo, expected_md); + break; + } + default: + break; + } + } + + return dnnl_post_ops_; + } + + bool with_sum() { + for (size_t i = 0; i < ops_params_.size(); ++i) { + if (ops_params_[i].kind_ == kind_t::sum) { + return true; + } + } + return false; + } + + bool with_binary() { + for (size_t i = 0; i < ops_params_.size(); ++i) { + if (ops_params_[i].kind_ == kind_t::binary) { + return true; + } + } + return false; + } + + void construct_post_binary( + dnnl::primitive_desc& pd, + std::unordered_map& args) { + // This function is used to construct binary memory desc in binary post ops. 
+ // According to oneDNN doc, the binary tensor can be in shape of + // [1, 1, 1, 1], tensor broadcast + // [1, C, 1, 1], channel broadcast + // [dst.shape], no broadcast and eltwise-wise binary operations on dst + + auto& engine = GpuEngineManager::Instance().get_engine(); + for (size_t i = 0; i < ops_params_.size(); ++i) { + kind_t kind = ops_params_[i].kind_; + if (kind == kind_t::binary) { + dnnl::memory binary_m; + auto binary = ops_params_[i].binary_; + auto md = ops_params_[i].meta_; + // qeury expected_md to achieve peak performance + auto expected_md = pd.query_md( + dnnl::query::exec_arg_md, + DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1); + + binary_m = at::native::onednn::make_onednn_memory( + md, engine, binary.data_ptr()); + + args.insert( + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1, binary_m}); + } + } + } + + float q_scale_ = 1.0; // the scale used to quantize the fused result from fp32 + // to int8, only works for int8 case + int64_t q_zero_point_ = 0; + std::vector ops_params_; // series of post ops + dnnl::post_ops dnnl_post_ops_; +}; + +static inline void construct_attr_for_unary( + const std::string_view& unary_post_op, + const torch::List>& unary_post_op_args, + const std::string_view& unary_post_op_algorithm, + at::native::onednn::Attr& attr) { + if (unary_post_op == "relu") { + attr = attr.append_post_eltwise( + /* eltwise_scale */ 1.f, + /* alpha */ 0.f, + /* beta */ 0.f, + attr.kind_with_relu); + } else if (unary_post_op == "leaky_relu") { + auto alpha = unary_post_op_args[0].value().to(); + attr = attr.append_post_eltwise(1.0, alpha, 0.f, attr.kind_with_relu); + } else if (unary_post_op == "tanh") { + attr = attr.append_post_eltwise(1.0f, 0.0f, 0.0f, attr.kind_with_tanh); + } else if (unary_post_op == "gelu") { + auto post_algorithm = unary_post_op_algorithm == "none" + ? 
attr.kind_with_gelu_erf + : attr.kind_with_gelu_tanh; + attr = attr.append_post_eltwise(1.0f, 0.0f, 0.0f, post_algorithm); + } else if (unary_post_op == "hardtanh") { + auto alpha = unary_post_op_args[0].value().to(); + auto beta = unary_post_op_args[1].value().to(); + attr = attr.append_post_eltwise(1.0, alpha, beta, attr.kind_with_clip); + } else if (unary_post_op == "hardswish") { + attr = attr.append_post_eltwise( + 1.0f, 1.f / 6.f, 1.f / 2.f, attr.kind_with_hardswish); + } else if (unary_post_op == "swish") { + attr = attr.append_post_eltwise(1.0f, 1.0f, 0.0f, attr.kind_with_swish); + } else { + TORCH_CHECK( + unary_post_op == "none", + "onednn qlinear: unspported unary post op", + unary_post_op); + } +} + +static inline void construct_attr_by_post_op( + const std::string_view& binary_post_op, + double binary_alpha, + double input1_scale, + int64_t input1_zero_point, + std::optional accum, + const std::string_view& unary_post_op, + const torch::List>& unary_post_op_args, + const std::string_view& unary_post_op_algorithm, + at::native::onednn::Attr& attr) { + bool is_none_post_op = + (binary_post_op == "none" && unary_post_op == "none"); // not post-ops + bool is_unary_post_op_only = + (binary_post_op == "none" && unary_post_op != "none"); // ex., conv + relu + bool is_valid_binary_combination = + (binary_post_op == "add" || binary_post_op == "sum") && + (unary_post_op == "none" || unary_post_op == "relu"); + TORCH_INTERNAL_ASSERT( + is_unary_post_op_only || is_none_post_op || is_valid_binary_combination, + "Please provide valid combination of unary post operators and binary post operators"); + + if (binary_post_op == "none") { + construct_attr_for_unary( + unary_post_op, unary_post_op_args, unary_post_op_algorithm, attr); + } else if (binary_post_op == "sum") { + if (unary_post_op == "none") { + if (input1_zero_point != 0) + attr = attr.append_post_eltwise( + /*scale*/ 1.f, + /*alpha*/ 1.f, + -input1_zero_point * input1_scale, + attr.kind_with_linear); + attr 
= attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0); + } else if (unary_post_op == "relu") { + if (input1_zero_point != 0) + attr = attr.append_post_eltwise( + /*scale*/ 1.f, + /*alpha*/ 1.f, + -input1_zero_point * input1_scale, + attr.kind_with_linear); + attr = attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0); + attr = attr.append_post_eltwise( + /* scale */ 1.f, + /* alpha */ 0.f, + /* beta */ 0.f, + attr.kind_with_relu); + } + } else if (binary_post_op == "add") { + TORCH_CHECK(accum.has_value()); + attr = attr.append_post_binary(attr.kind_with_binary_add, accum.value()); + if (unary_post_op == "relu") { + attr = attr.append_post_eltwise(1.f, 0.f, 0.f, attr.kind_with_relu); + } + } +} + +} // namespace at::native::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/DnnlExt.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/DnnlExt.h new file mode 100644 index 0000000000000000000000000000000000000000..d4fc77ed516ade2c1431b2dcbf6168d4055bf895 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/DnnlExt.h @@ -0,0 +1,594 @@ +#pragma once + +#include + +#include +#include +#include + +#include +#include + +namespace std { + +template <> +struct hash { + size_t operator()(dnnl::memory::dims const& vec) const { + size_t seed = vec.size(); + for (auto& i : vec) { + seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; + } +}; + +} // namespace std + +using namespace dnnl; + +namespace at::native::onednn { + +class primitive_ext : public primitive { + static constexpr int max_args = 12; + + public: + primitive_ext(const primitive& base) : primitive(base) {} + primitive_ext(primitive&& base) : primitive(std::move(base)) {} + + /// Returns a memory descriptor. 
+ /// + /// @note + /// There are also convenience methods + /// #dnnl::primitive_desc_base::src_desc(), + /// #dnnl::primitive_desc_base::dst_desc(), and others. + /// + /// @param what The kind of parameter to query; can be + /// #dnnl::query::src_md, #dnnl::query::dst_md, etc. + /// @param idx Index of the parameter. For example, convolution bias can + /// be queried with what = #dnnl::query::weights_md and idx = 1. + /// @returns The requested memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// parameter of the specified kind or index. + const_dnnl_memory_desc_t query_md(query what, int idx = 0) const { + std::vector valid_q{ + query::src_md, + query::diff_src_md, + query::weights_md, + query::diff_weights_md, + query::dst_md, + query::diff_dst_md, + query::workspace_md, + query::scratchpad_md, + query::exec_arg_md}; + if (!std::any_of(valid_q.cbegin(), valid_q.cend(), [=](query q) { + return what == q; + })) + DNNL_THROW_ERROR( + dnnl_invalid_arguments, "memory descriptor query is invalid"); + + const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md( + this->get_primitive_desc(), dnnl::convert_to_c(what), idx); + + return cdesc ? cdesc : nullptr; + } + + /// Returns a source memory descriptor. + /// @param idx Source index. + /// @returns Source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// source parameter with index @p idx. + const_dnnl_memory_desc_t src_desc(int idx) const { + return query_md(query::src_md, idx); + } + + /// Returns a destination memory descriptor. + /// @param idx Destination index. + /// @returns Destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// destination parameter with index @p idx. + const_dnnl_memory_desc_t dst_desc(int idx) const { + return query_md(query::dst_md, idx); + } + + /// Returns a weights memory descriptor. + /// @param idx Weights index. 
+ /// @returns Weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// weights parameter with index @p idx. + const_dnnl_memory_desc_t weights_desc(int idx) const { + return query_md(query::weights_md, idx); + } + + /// Returns a diff source memory descriptor. + /// @param idx Diff source index. + /// @returns Diff source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff source parameter with index @p idx. + const_dnnl_memory_desc_t diff_src_desc(int idx) const { + return query_md(query::diff_src_md, idx); + } + + /// Returns a diff destination memory descriptor. + /// @param idx Diff destination index. + /// @returns Diff destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff destination parameter with index @p idx. + const_dnnl_memory_desc_t diff_dst_desc(int idx) const { + return query_md(query::diff_dst_md, idx); + } + + /// Returns a diff weights memory descriptor. + /// @param idx Diff weights index. + /// @returns Diff weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff weights parameter with index @p idx. + const_dnnl_memory_desc_t diff_weights_desc(int idx) const { + return query_md(query::diff_weights_md, idx); + } + + const_dnnl_memory_desc_t exec_arg_desc(int idx) const { + return query_md(query::exec_arg_md, idx); + } + + // Separate versions without the index argument for documentation + // purposes. + + /// Returns a source memory descriptor. + /// @returns Source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// source parameter. + const_dnnl_memory_desc_t src_desc() const { + return src_desc(0); + } + + /// Returns a destination memory descriptor. + /// @returns Destination memory descriptor. 
+ /// @returns A zero memory descriptor if the primitive does not have a + /// destination parameter. + const_dnnl_memory_desc_t dst_desc() const { + return dst_desc(0); + } + + /// Returns a weights memory descriptor. + /// @returns Weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// weights parameter. + const_dnnl_memory_desc_t weights_desc() const { + return weights_desc(0); + } + + /// Returns a diff source memory descriptor. + /// @returns Diff source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff source memory with. + const_dnnl_memory_desc_t diff_src_desc() const { + return diff_src_desc(0); + } + + /// Returns a diff destination memory descriptor. + /// @returns Diff destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff destination parameter. + const_dnnl_memory_desc_t diff_dst_desc() const { + return diff_dst_desc(0); + } + + /// Returns a diff weights memory descriptor. + /// @returns Diff weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff weights parameter. + const_dnnl_memory_desc_t diff_weights_desc() const { + return diff_weights_desc(0); + } + + /// Returns the workspace memory descriptor. + /// @returns Workspace memory descriptor. + /// @returns A zero memory descriptor if the primitive does not require + /// workspace parameter. + const_dnnl_memory_desc_t workspace_desc() const { + return query_md(query::workspace_md, 0); + } + + /// Returns the scratchpad memory descriptor. + /// @returns scratchpad memory descriptor. + /// @returns A zero memory descriptor if the primitive does not require + /// scratchpad parameter. 
+ /// @sa @ref dev_guide_attributes_scratchpad + const_dnnl_memory_desc_t scratchpad_desc() const { + return query_md(query::scratchpad_md, 0); + } + + inline memory make_memory( + const_dnnl_memory_desc_t md_t, + const engine& aengine, + void* handle = DNNL_MEMORY_ALLOCATE) const { + sycl_interop::memory_kind kind = dnnl::sycl_interop::memory_kind::usm; + dnnl_memory_t c_memory; + error::wrap_c_api( + dnnl_sycl_interop_memory_create( + &c_memory, md_t, aengine.get(), convert_to_c(kind), handle), + "could not create a memory"); + return memory(c_memory); + } + + memory make_src(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(src_desc(), aengine, handle); + } + + memory make_weight(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(weights_desc(), aengine, handle); + } + + memory make_bias(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(weights_desc(1), aengine, handle); + } + + memory make_dst(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) + const { + return make_memory(dst_desc(), aengine, handle); + } + + memory make_scratchpad( + const engine& aengine, + void* handle = DNNL_MEMORY_ALLOCATE) const { + return make_memory(scratchpad_desc(), aengine, handle); + } + + size_t get_scratchpad_size() const { + return dnnl_memory_desc_get_size(scratchpad_desc()); + } + + memory make_args(int arg_class, const engine& aengine, void* handle) const { + switch (arg_class) { + case DNNL_ARG_SRC: + return make_src(aengine, handle); + case DNNL_ARG_WEIGHTS: + return make_weight(aengine, handle); + case DNNL_ARG_SCRATCHPAD: + return make_scratchpad(aengine, handle); + case DNNL_ARG_DST: + return make_dst(aengine, handle); + case DNNL_ARG_BIAS: + return make_bias(aengine, handle); + default: + TORCH_INTERNAL_ASSERT( + false, "unsupported argument class for primitive_ext"); + } + } + + template + void set_attribute(int slot, int arg_class, void* 
handle, M constructor) { + if (mem_arg_cache[slot]) + mem_arg_cache[slot].set_data_handle(handle); + else { + mem_arg_cache[slot] = constructor(); + c_args[slot].arg = arg_class; + c_args[slot].memory = mem_arg_cache[slot].get(); + } + } + + sycl::event execute( + const stream& astream, + const engine& aengine, + std::vector>&& handles, + int slot_off = 2) { + auto off = slot_off; + for (const auto& p : handles) { + auto& m_arg = mem_arg_cache[off]; + if (m_arg) + m_arg.set_data_handle(p.second); + else { + m_arg = make_args(p.first, aengine, p.second); + c_args[off].arg = p.first; + c_args[off].memory = m_arg.get(); + } + ++off; + } + + sycl::event return_event; + std::vector deps{}; + error::wrap_c_api( + dnnl_sycl_interop_primitive_execute( + this->get(), astream.get(), off, c_args, &deps, &return_event), + "could not execute a primitive"); + return return_event; + } + + private: + memory mem_arg_cache[max_args]; + dnnl_exec_arg_t c_args[max_args]; +}; + +// Specifies the combined data types of input and weight tensors. +// For example, f32 means both input and weight are FP32, +// bf16_int4 means input is BF16 and weight is INT4. +enum class joint_dtypes_t { f32 = 0, f16, bf16, int8, f16_int4, bf16_int4 }; + +// Specifies the transposition state of input and weight tensors. +// Convention: first letter = input, second letter = weight. +// 'n' = not transposed, 't' = transposed. +// For example, 'nt' means input is not transposed, weight is transposed. +enum class trans_type_t { nn = 0, nt, tn, tt }; + +// Specifies the type and placement of bias in the computation. +// 'none' = no bias, +// 'scalar' = a single scalar bias applied to all elements, +// 'm' = per-row bias (typically matched to input rows), +// 'n' = per-column bias (typically matched to output channels), +// 'mn' = full bias matrix matching the output dimensions. 
+enum class bias_type_t { none = 0, scalar, m, n, mn }; + +template +T concat(const T& t1, at::ScalarType d) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.push_back((int64_t)d); + + return t; +} + +template +T concat(const T& t1, bool b) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.push_back(b); + + return t; +} + +template +T concat(const T& t1, int b) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.push_back(b); + + return t; +} + +template +T concat(const T& t1, const T& t2) { + T t; + t.insert(t.end(), t1.begin(), t1.end()); + t.insert(t.end(), t2.begin(), t2.end()); + + return t; +} + +template +T1 concat(const T1& t1, const T2& t2, const Ts&... ts) { + return concat(concat(t1, t2), ts...); +} + +template +struct onednn_types_mapper; + +template <> +struct onednn_types_mapper { + static inline std::tuple + get() { + return std::make_tuple( + dnnl::memory::data_type::f16, dnnl::memory::data_type::u4); + } +}; + +template <> +struct onednn_types_mapper { + static inline std::tuple + get() { + return std::make_tuple( + dnnl::memory::data_type::bf16, dnnl::memory::data_type::u4); + } +}; + +// TODO: bias types maybe not right +static inline dnnl::memory::dims get_bias_type( + bias_type_t b_dims, + const int m, + const int n) { + switch (b_dims) { + case bias_type_t::none: + return {0}; + case bias_type_t::scalar: + return {1, 1}; + case bias_type_t::m: + return {m, 1}; + case bias_type_t::n: + return {1, n}; + case bias_type_t::mn: + return {m, n}; + default: + TORCH_INTERNAL_ASSERT(false, "unsupported bias type ..."); + } +} + +// TODO: use template specialization on struct +template +inline void get_strides( + memory::dims& src_strides, + memory::dims& wei_strides, + memory::dims& dst_strides, + const int64_t lda, + const int64_t ldb, + const int64_t ldc) {} + +template <> +inline void get_strides( + memory::dims& src_strides, + memory::dims& wei_strides, + memory::dims& dst_strides, + const int64_t lda, + const int64_t ldb, + const 
int64_t ldc) { + src_strides = {lda, 1}; + wei_strides = {1, ldb}; + dst_strides = {ldc, 1}; +} + +using primitive_cache = + at::native::onednn::lru_cache; + +template +struct matmul_primitive_cache_t { + static inline primitive_ext& get( + const int m, + const int n, + const int k, + const int64_t lda, + const int64_t ldb, + const int64_t ldc, + const bias_type_t + b_dims, // for shapeless bias, not put it into template parameter + const int device_id, + F f_attr, + const int64_t scale_group_size, + const int64_t zp_group_size) { + auto& cached = get_cache(device_id); + memory::dims src_strides, wei_strides, dst_strides; + get_strides(src_strides, wei_strides, dst_strides, lda, ldb, ldc); + auto pri_key = at::native::onednn::concat( + src_strides, + wei_strides, + m, + n, + k, + int(b_dims), + int(scale_group_size), + int(zp_group_size)); + auto iter = cached.find(pri_key); + if (iter == cached.end()) { + auto [src_dt, wei_dt] = onednn_types_mapper::get(); + auto bias_dims = get_bias_type(b_dims, m, n); + + auto src_md = memory::desc({m, k}, src_dt, src_strides); + auto wei_md = memory::desc({k, n}, wei_dt, wei_strides); + auto dst_md = memory::desc({m, n}, src_dt, dst_strides); + auto bias_format = b_dims == bias_type_t::none + ? 
dnnl::memory::format_tag::undef + : dnnl::memory::format_tag::ab; + auto bias_md = + memory::desc(bias_dims, src_dt, bias_format); // {m, n} or {1, n} + + primitive_attr pattr; + f_attr(pattr); + + dnnl::matmul::primitive_desc matmul_pd; + auto aengine = + at::native::onednn::GpuEngineManager::Instance().get_engine( + device_id); + if (b_dims == bias_type_t::none) { + matmul_pd = dnnl::matmul::primitive_desc( + aengine, src_md, wei_md, dst_md, pattr); + } else { + matmul_pd = dnnl::matmul::primitive_desc( + aengine, src_md, wei_md, bias_md, dst_md, pattr); + } + + return cached.insert({pri_key, primitive_ext(dnnl::matmul(matmul_pd))}) + .first->second; + } else { + return iter->second; + } + } + + private: + static constexpr int max_cache_capacity = 512; + // if default constructor of primitive cache could read the environment + // variable then it'll save a lot of trouble + static inline thread_local std::array mappings; + + // this won't be needed if primitive_cache have good default constructor + static inline primitive_cache& get_cache(const int device_id) { + auto& mapping = mappings[device_id]; + if (mapping.max_size() == 0) { + mapping.resize(max_cache_capacity); + } + return mapping; + } +}; + +template +static inline primitive_ext& matmul_primitive_create_and_cache( + const trans_type_t Tt, + const bias_type_t b_dims, + const int m, + const int n, + const int k, + const int64_t lda, + const int64_t ldb, + const int64_t ldc, + const int device_id, + F attr, + const int64_t scale_group_size, + const int64_t zp_group_size) { + switch (Tt) { + case trans_type_t::nt: + return matmul_primitive_cache_t::get( + m, + n, + k, + lda, + ldb, + ldc, + b_dims, + device_id, + attr, + scale_group_size, + zp_group_size); + default: + TORCH_INTERNAL_ASSERT(false, "unsupported trans type ..."); + } +} + +template +static inline primitive_ext& matmul_primitive_create_and_cache( + const joint_dtypes_t Ts, + const trans_type_t Tt, + const bias_type_t b_dims, + const int m, + 
const int n, + const int k, + const int64_t lda, + const int64_t ldb, // is weight ldb necessary? + const int64_t ldc, + const int device_id, + F attr, + const int64_t scale_group_size = 0, + const int64_t zp_group_size = 0) { + switch (Ts) { + case joint_dtypes_t::f16_int4: + return matmul_primitive_create_and_cache( + Tt, + b_dims, + m, + n, + k, + lda, + ldb, + ldc, + device_id, + attr, + scale_group_size, + zp_group_size); + case joint_dtypes_t::bf16_int4: + return matmul_primitive_create_and_cache( + Tt, + b_dims, + m, + n, + k, + lda, + ldb, + ldc, + device_id, + attr, + scale_group_size, + zp_group_size); + default: + TORCH_INTERNAL_ASSERT(false, "Only support int4 ..."); + } +} + +} // namespace at::native::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/LRUCache.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/LRUCache.h new file mode 100644 index 0000000000000000000000000000000000000000..9229c10bc57a36b6c34cf8583acb50c8e8082e30 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/LRUCache.h @@ -0,0 +1,110 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native::onednn { + +template < + class key_t, + class value_t, + template class map_t = std::unordered_map> +class lru_cache { + public: + using value_type = std::pair; + using list_type = std::list; + using list_iter = typename list_type::iterator; + using map_type = map_t; + using const_list_iter = typename list_type::const_iterator; + using size_type = typename list_type::size_type; + + explicit lru_cache(size_type capacity) : capacity_(capacity) {} + lru_cache() : capacity_(0) {} + + [[nodiscard]] size_type size() const noexcept { + return map_.size(); + } + [[nodiscard]] size_type max_size() const noexcept { + return capacity_; + } + [[nodiscard]] bool empty() const noexcept { + return vlist_.empty(); + } + + void resize(size_type new_capacity) { 
+ capacity_ = new_capacity; + trim(); + } + + list_iter begin() noexcept { + return vlist_.begin(); + } + const_list_iter begin() const noexcept { + return vlist_.begin(); + } + list_iter end() noexcept { + return vlist_.end(); + } + const_list_iter end() const noexcept { + return vlist_.end(); + } + + void clear() noexcept { + map_.clear(); + vlist_.clear(); + } + + void swap(lru_cache& other) noexcept { + using std::swap; + swap(vlist_, other.vlist_); + swap(map_, other.map_); + swap(capacity_, other.capacity_); + } + + list_iter find(const key_t& key) { + auto it = map_.find(key); + if (it == map_.end()) + return end(); + vlist_.splice(vlist_.begin(), vlist_, it->second); + return it->second; + } + + std::pair insert(const value_type& value) { + auto it = map_.find(value.first); + if (it != map_.end()) { + // Move existing to front + vlist_.splice(vlist_.begin(), vlist_, it->second); + return {it->second, false}; + } + + // Insert new at front + vlist_.emplace_front(value); + map_[value.first] = vlist_.begin(); + + trim(); + + return {vlist_.begin(), true}; + } + + list_iter erase(list_iter pos) { + map_.erase(pos->first); + return vlist_.erase(pos); + } + + private: + void trim() { + while (map_.size() > capacity_) { + auto last = std::prev(vlist_.end()); + map_.erase(last->first); + vlist_.pop_back(); + } + } + + list_type vlist_; + map_type map_; + size_type capacity_; +}; + +} // namespace at::native::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..ac8645d3e4a50a382f49d4dff76966cbcd232e8e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/Utils.h @@ -0,0 +1,134 @@ +#pragma once +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + 
+#include + +#define ONEDNN_SUPPORT_DETERMINISTIC \ + (DNNL_VERSION_MAJOR >= 3 && DNNL_VERSION_MINOR >= 4) + +namespace at::native::onednn { + +dnnl::memory::format_tag get_dnnl_default_format( + int ndims, + bool is_channels_last = false, + bool allow_undef = false); + +dnnl::memory::data_type get_onednn_dtype( + const at::Tensor& tensor, + bool allow_undef = false); + +dnnl::memory::data_type get_onednn_dtype_include_double( + const at::Tensor& tensor, + bool allow_undef = false); + +bool is_supported_onednn_dtype(const at::Tensor& tensor); + +dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor); + +dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor); +dnnl::memory::desc get_onednn_md(const at::Tensor& tensor); + +bool onednn_strides_check(const at::Tensor& src); +bool is_broadcast(const at::Tensor& t); +void undo_broadcast_on_batch(at::Tensor& m1, at::Tensor& m2); +void undo_broadcast(at::Tensor& tensor); + +bool is_onednn_matmul_strides(const at::Tensor& tensor); + +bool is_broadcast_from_other_to_self( + const at::Tensor& self, + const at::Tensor& other); + +at::MemoryFormat get_cl_tag_by_ndim(const int64_t ndim); + +void apply_tf32_if_allowed(dnnl::primitive_attr& primitive_attr); + +bool binary_valid( + const at::Tensor& self, + const at::Tensor& other, + bool is_fusion = false); + +bool use_channels_last_for_conv( + const at::Tensor& src, + const at::Tensor& weight); + +dnnl::memory::format_tag conv_src_fmt( + const int64_t ndim, + const bool is_channels_last = false); + +dnnl::memory::dims compatible_weight_dims( + const int64_t ndim, + const int64_t groups, + const int64_t oc, + const int64_t ic, + const IntArrayRef wsizes); + +dnnl::memory::format_tag conv_weight_fmt( + const int64_t ndim, + const bool grouped = false, + const bool is_channels_last = false); + +template +dnnl::memory::dims compatible_dilation(Vec&& dilation) { + dnnl::memory::dims ret = dilation.vec(); + for (auto it = ret.begin(); it != ret.end(); it++) { + *it -= 1; 
+ } + return ret; +} + +template +dnnl::memory dnnl_memory_from_host_scalar( + T host_value, + Tensor& holder, + dnnl::engine& engine) { + auto options = at::TensorOptions() + .dtype(c10::CppTypeToScalarType::value) + .device(kXPU); + holder = at::empty({1}, options).fill_(host_value); + dnnl::memory::desc md = get_onednn_md(holder); + dnnl::memory mem = make_onednn_memory(md, engine, holder.data_ptr()); + return mem; +} + +struct PartitionCache { + std::unordered_map, dnnl::graph::partition> partition_map_{}; + + // The first 8 bits are reserved + // bit 0: is int8 + // bit 1: is uint8 + // bit 2: fp16(0) / bf16(1) + // bit 3: is fp32 + // bit 4: is sdp pattern + // bit 5-7: N/A + // The rest of the bits depend upon the arguments provided + // However, down the line, we might have different bitsets for different + // patterns + dnnl::graph::partition& insert_partition_cache( + std::bitset<32>& patternID, + dnnl::graph::partition& p) { + partition_map_[patternID] = std::move(p); + return partition_map_[patternID]; + } + std::optional> find_partition( + std::bitset<32>& patternID) { + auto iter = partition_map_.find(patternID); + if (iter != partition_map_.end()) { + return iter->second; + } + return std::nullopt; + } +}; + +} // namespace at::native::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNN.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNN.h new file mode 100644 index 0000000000000000000000000000000000000000..e73cb73e8b1e7b301e7d4ff2bb3eb85404099811 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNN.h @@ -0,0 +1,182 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native::onednn { + +TORCH_API sycl::event matmul( + at::Tensor& result, + const at::Tensor& mat1, + const at::Tensor& mat2, + const at::Tensor& b_raw, + bool m2_trans, + Attr attr, + const std::vector& deps = {}); + 
+TORCH_API sycl::event convolution( + at::Tensor& dst, + const at::Tensor& src, + const at::Tensor& weight, + const at::Tensor& bia, + IntArrayRef padding_front_top_left, + IntArrayRef padding_back_bottom_right, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + Attr& attr, + const std::vector& deps = {}); + +TORCH_API sycl::event convolution_backward_weights( + at::Tensor& diff_weight, + at::Tensor& diff_bia, + const at::Tensor& diff_dst, + const at::Tensor& src, + IntArrayRef diff_weight_aten_size, + IntArrayRef padding_front_top_left, + IntArrayRef padding_back_bottom_right, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + const std::vector& deps = {}); + +TORCH_API sycl::event convolution_backward_data( + at::Tensor& diff_src, + const at::Tensor& diff_dst, + const at::Tensor& weight, + IntArrayRef padding_front_top_left, + IntArrayRef padding_back_bottom_right, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool bias_defined, + const std::vector& deps = {}); + +TORCH_API sycl::event deconvolution( + at::Tensor& dst, + const at::Tensor& src, + const at::Tensor& weight, + const at::Tensor& bia, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dst_padding, + IntArrayRef dilation, + int64_t groups, + Attr& attr, + const std::vector& deps = {}); + +TORCH_API sycl::event deconvolution_backward_data( + at::Tensor& diff_src, + const at::Tensor& diff_dst, + const at::Tensor& weight, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + bool bias_defined, + const std::vector& deps = {}); + +TORCH_API sycl::event deconvolution_backward_weights( + at::Tensor& diff_weight, + at::Tensor& diff_bia, + const at::Tensor& diff_dst, + const at::Tensor& src, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + const std::vector& deps = {}); + +TORCH_API void woq_matmul_int4( + at::Tensor& result, // dst, [M, N] + const at::Tensor& mat1_, // src, [M, K] 
+ const at::Tensor& mat2_, // quantized weight, [K/8, N] + const at::Tensor& scale, // [K/group_size, N] + const at::Tensor& zp, // [k/group_size, N] + int64_t group_size, + bool pri_cache = true); + +dnnl::memory::dims conv_dst_size( + int64_t ndim, + IntArrayRef src_tz, + IntArrayRef wgh_tz, + IntArrayRef padding_front_top_left, + IntArrayRef padding_back_bottom_right, + IntArrayRef stride, + IntArrayRef dilation); + +dnnl::memory::dims deconv_dst_size( + IntArrayRef src_size, + IntArrayRef wgh_size, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + IntArrayRef dst_padding, + int64_t groups); + +at::Tensor quantized_convolution( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + bool transposed, + int64_t groups, + at::Tensor output, + double inv_output_scale, + int64_t output_zero_point, + std::optional accum, + double accum_scale, + int64_t accum_zero_point, + std::optional output_dtype, + std::optional binary_attr, + std::optional binary_alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +void quantized_matmul( + at::Tensor mat1, // act + double input_scale, + int64_t input_zero_point, + at::Tensor mat2, // weight + at::Tensor& weight_scales, + at::Tensor& weight_zero_points, + at::Tensor& b_raw, + at::Tensor result, // output + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::optional other, // extra input for binary-post-op + double other_scale, + int64_t other_zero_point, + const std::string_view& binary_post_op, + double binary_alpha, + const std::string_view& unary_post_op, + torch::List>& unary_post_op_args, + std::string_view unary_post_op_algorithm, + bool m2_trnas); + +void gpu_float_sdpa( + int batch_size, + int seq_len_q, + int seq_len_kv, + int num_head_q, 
+ int num_head_kv, + int head_dim_qk, + int head_dim_v, + const Tensor& query, + const Tensor& key, + const Tensor& value, + std::optional attn_mask, + bool is_causal, + float softmax_scale, + const Tensor& output); +} // namespace at::native::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNNContext.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNNContext.h new file mode 100644 index 0000000000000000000000000000000000000000..d919e87bfa5860263cb52a2adba77bf13183e801 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mkldnn/xpu/detail/oneDNNContext.h @@ -0,0 +1,90 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native::onednn { + +TORCH_XPU_API dnnl::memory make_onednn_memory( + dnnl::memory::desc md, + dnnl::engine& engine, + void* ptr); + +// Keep non-static and non-inline +bool set_onednn_verbose(int level); + +// GpuEngineManager singleton +struct TORCH_XPU_API GpuEngineManager { + static GpuEngineManager& Instance(); // Singleton + + dnnl::engine& get_engine( + DeviceIndex device_index = c10::xpu::current_device()) { + c10::xpu::check_device_index(device_index); + return *engine_pool[device_index]; + } + + dnnl::engine& get_engine(const Device& device) { + TORCH_INTERNAL_ASSERT(device.type() == kXPU); + return get_engine(device.index()); + } + + GpuEngineManager(GpuEngineManager const&) = delete; + GpuEngineManager& operator=(GpuEngineManager const&) = delete; + GpuEngineManager(GpuEngineManager&&) = default; + GpuEngineManager& operator=(GpuEngineManager&&) = default; + + protected: + GpuEngineManager(); + ~GpuEngineManager() = default; + + private: + std::vector> engine_pool; +}; + +// GpuStreamManager singleton +struct TORCH_XPU_API GpuStreamManager { + static GpuStreamManager& Instance(); // Singleton + + dnnl::stream& get_stream( + DeviceIndex device_index = 
c10::xpu::current_device()) { + auto stream = c10::xpu::getCurrentXPUStream(device_index); + auto priority = stream.priority(); + if (stream_pool[device_index][priority].find(stream) == + stream_pool[device_index][priority].end()) { + stream_pool[device_index][priority][stream] = + std::make_shared(dnnl::sycl_interop::make_stream( + GpuEngineManager::Instance().get_engine(device_index), + stream.queue())); + } + return *stream_pool[device_index][priority][stream]; + } + + GpuStreamManager(GpuStreamManager const&) = delete; + GpuStreamManager& operator=(GpuStreamManager const&) = delete; + GpuStreamManager(GpuStreamManager&&) = default; + GpuStreamManager& operator=(GpuStreamManager&&) = default; + + protected: + GpuStreamManager() { + c10::DeviceIndex device_count = c10::xpu::device_count_ensure_non_zero(); + stream_pool.resize(device_count); + } + ~GpuStreamManager() = default; + + private: + using stream_hash_map = + ska::flat_hash_map>; + std::vector< + std::array> + stream_pool; +}; + +} // namespace at::native::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/Copy.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..cd65d8ae00e655d05f15ef1f744d771fe0d4eadc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/Copy.h @@ -0,0 +1,14 @@ +// Copyright © 2022 Apple Inc. 
+ +#pragma once +#include + +namespace at::native::mps { + +at::Tensor& mps_copy_( + at::Tensor& dst, + const at::Tensor& src, + bool non_blocking); +void copy_blit_mps(void* dst, const void* src, size_t size); + +} // namespace at::native::mps diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MPSGraphSequoiaOps.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MPSGraphSequoiaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..94ea7b8734b3e699fc45077ba3d844b8277ab8a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MPSGraphSequoiaOps.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +#if !defined(__MAC_15_0) && (!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0)) + +@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel +- (MPSNDArray* __nullable)reshapeWithCommandBuffer:(__nullable id)cmdBuf + sourceArray:(MPSNDArray* __nonnull)sourceArray + shape:(MPSShape* __nonnull)shape + destinationArray:(MPSNDArray* __nullable)destinationArray; +@end + +@interface MPSNDArrayDescriptor () +@property(readwrite, nonatomic) BOOL preferPackedRows; +@end + +@interface MPSNDArray () +- (nonnull instancetype)initWithBuffer:(id _Nonnull)buffer + offset:(NSUInteger)offset + descriptor:(MPSNDArrayDescriptor* _Nonnull)descriptor; +- (MPSNDArray* __nullable)arrayViewWithShape:(MPSShape* _Nullable)shape strides:(MPSShape* _Nonnull)strides; +@end + +typedef NS_ENUM(NSInteger, MTLMathMode) { + MTLMathModeSafe = 0, + MTLMathModeRelaxed = 1, + MTLMathModeFast = 2, +}; + +typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions) { + MTLMathFloatingPointFunctionsFast = 0, + MTLMathFloatingPointFunctionsPrecise = 1, +}; + +@interface MTLCompileOptions () +@property(readwrite, nonatomic) MTLMathMode mathMode; +@property(readwrite, nonatomic) MTLMathFloatingPointFunctions mathFloatingPointFunctions; +@end + +#endif diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..6290245083a443ee2cd8109c81d270ec9674f9f7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h @@ -0,0 +1,48 @@ +#pragma once + +#include + +#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0)) + +typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode) { + MPSGraphFFTScalingModeNone = 0L, + MPSGraphFFTScalingModeSize = 1L, + MPSGraphFFTScalingModeUnitary = 2L, +}; + +@interface FakeMPSGraphFFTDescriptor : NSObject +@property(readwrite, nonatomic) BOOL inverse; +@property(readwrite, nonatomic) MPSGraphFFTScalingMode scalingMode; +@property(readwrite, nonatomic) BOOL roundToOddHermitean; ++ (nullable instancetype)descriptor; +@end + +@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor; + +@interface MPSGraph (SonomaOps) +- (MPSGraphTensor* _Nonnull)conjugateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)realPartOfTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)fastFourierTransformWithTensor:(MPSGraphTensor* _Nonnull)tensor + axes:(NSArray* _Nonnull)axes + descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)realToHermiteanFFTWithTensor:(MPSGraphTensor* _Nonnull)tensor + axes:(NSArray* _Nonnull)axes + descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)HermiteanToRealFFTWithTensor:(MPSGraphTensor* _Nonnull)tensor + axes:(NSArray* _Nonnull)axes + descriptor:(MPSGraphFFTDescriptor* _Nonnull)descriptor + name:(NSString* _Nullable)name; +@end + +// define 
BFloat16 enums for MacOS13 +#define MPSDataTypeBFloat16 ((MPSDataType)(MPSDataTypeAlternateEncodingBit | MPSDataTypeFloat16)) + +// define Metal version +#define MTLLanguageVersion3_1 ((MTLLanguageVersion)((3 << 16) + 1)) +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..5497c83f7b9a684cd70716cf605d66c2461ce9bf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h @@ -0,0 +1,196 @@ +#pragma once +#include + +// TODO: Remove me when moved to MacOS 13 +#if !defined(__MAC_13_2) && (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2)) + +@interface FakeMPSGraphConvolution3DOpDescriptor : NSObject + +@property(readwrite, nonatomic) NSUInteger strideInX; +@property(readwrite, nonatomic) NSUInteger strideInY; +@property(readwrite, nonatomic) NSUInteger strideInZ; +@property(readwrite, nonatomic) NSUInteger dilationRateInX; +@property(readwrite, nonatomic) NSUInteger dilationRateInY; +@property(readwrite, nonatomic) NSUInteger dilationRateInZ; + +@property(readwrite, nonatomic) NSUInteger paddingLeft; +@property(readwrite, nonatomic) NSUInteger paddingRight; +@property(readwrite, nonatomic) NSUInteger paddingTop; +@property(readwrite, nonatomic) NSUInteger paddingBottom; +@property(readwrite, nonatomic) NSUInteger paddingFront; +@property(readwrite, nonatomic) NSUInteger paddingBack; + +@property(readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle; +@property(readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout; +@property(readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout; + +@property(readwrite, nonatomic) NSUInteger groups; + +@end + +@compatibility_alias MPSGraphConvolution3DOpDescriptor FakeMPSGraphConvolution3DOpDescriptor; + +#endif + 
+@interface MPSGraph (VenturaOps) + +#if !defined(__MAC_13_0) && (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0)) + +typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) { + MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L, + MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L, + MPSGraphResizeNearestRoundingModeCeil = 2L, + MPSGraphResizeNearestRoundingModeFloor = 3L, + MPSGraphResizeNearestRoundingModeRoundToEven = 4L, + MPSGraphResizeNearestRoundingModeRoundToOdd = 5L, +}; + +// Define complex enums for MacOS 12 +#define MPSDataTypeComplexBit 0x01000000 +#define MPSDataTypeComplexFloat32 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64)) +#define MPSDataTypeComplexFloat16 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32)) +#endif + +- (MPSGraphTensor* _Nonnull)convolution3DWithSourceTensor:(MPSGraphTensor* _Nonnull)source + weightsTensor:(MPSGraphTensor* _Nonnull)weights + descriptor:(MPSGraphConvolution3DOpDescriptor* _Nonnull)descriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull) + convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient + weightsTensor:(MPSGraphTensor* _Nonnull)weights + outputShape:(MPSShape* _Nonnull)outputShape + forwardConvolutionDescriptor: + (MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull) + convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient + sourceTensor:(MPSGraphTensor* _Nonnull)source + outputShape:(MPSShape* _Nonnull)outputShape + forwardConvolutionDescriptor: + (MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* 
_Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + descending:(BOOL)descending + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axisTensor:(MPSGraphTensor* _Nonnull)axisTensor + descending:(BOOL)descending + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axisTensor:(MPSGraphTensor* _Nonnull)axisTensor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axis:(NSInteger)axis + descending:(BOOL)descending + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axisTensor:(MPSGraphTensor* _Nonnull)axisTensor + descending:(BOOL)descending + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)argSortWithTensor:(MPSGraphTensor* _Nonnull)tensor + axisTensor:(MPSGraphTensor* _Nonnull)axisTensor + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)inverseOfTensor:(MPSGraphTensor* _Nonnull)inputTensor name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor + sizeTensor:(MPSGraphTensor* _Nonnull)size + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + centerResult:(BOOL)centerResult + alignCorners:(BOOL)alignCorners + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor + sizeTensor:(MPSGraphTensor* _Nonnull)size + scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset + 
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeBilinearWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor + sizeTensor:(MPSGraphTensor* _Nonnull)size + centerResult:(BOOL)centerResult + alignCorners:(BOOL)alignCorners + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeBilinearWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor + sizeTensor:(MPSGraphTensor* _Nonnull)size + scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient + input:(MPSGraphTensor* _Nonnull)input + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + centerResult:(BOOL)centerResult + alignCorners:(BOOL)alignCorners + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient + input:(MPSGraphTensor* _Nonnull)input + scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient + input:(MPSGraphTensor* _Nonnull)input + centerResult:(BOOL)centerResult + alignCorners:(BOOL)alignCorners + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor* _Nonnull)gradient + input:(MPSGraphTensor* _Nonnull)input + scaleOffsetTensor:(MPSGraphTensor* _Nonnull)scaleOffset + layout:(MPSGraphTensorNamedDataLayout)layout + name:(NSString* 
_Nullable)name; + +- (MPSGraphTensor* _Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor* _Nonnull)source + coordinateTensor:(MPSGraphTensor* _Nonnull)coordinates + layout:(MPSGraphTensorNamedDataLayout)layout + normalizeCoordinates:(BOOL)normalizeCoordinates + relativeCoordinates:(BOOL)relativeCoordinates + alignCorners:(BOOL)alignCorners + paddingMode:(MPSGraphPaddingMode)paddingMode + samplingMode:(MPSGraphResizeMode)samplingMode + constantValue:(double)constantValue + name:(NSString* _Nullable)name; + +- (MPSGraphTensor* _Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor* _Nonnull)source + coordinateTensor:(MPSGraphTensor* _Nonnull)coordinates + layout:(MPSGraphTensorNamedDataLayout)layout + normalizeCoordinates:(BOOL)normalizeCoordinates + relativeCoordinates:(BOOL)relativeCoordinates + alignCorners:(BOOL)alignCorners + paddingMode:(MPSGraphPaddingMode)paddingMode + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode + constantValue:(double)constantValue + name:(NSString* _Nullable)name; +- (MPSGraphTensor* _Nonnull)truncateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name; + +@end diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MetalShaderLibrary.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MetalShaderLibrary.h new file mode 100644 index 0000000000000000000000000000000000000000..535edd29ebd7ada7e78fadf8665fefeb3844865b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/MetalShaderLibrary.h @@ -0,0 +1,178 @@ +#pragma once +#ifdef __OBJC__ +#include +typedef id MTLLibrary_t; +typedef id MTLFunction_t; +typedef id MTLComputePipelineState_t; +typedef id MTLComputeCommandEncoder_t; +#else +typedef void MTLCompileOptions; +typedef void* MTLLibrary_t; +typedef void* MTLFunction_t; +typedef void* MTLComputePipelineState_t; +typedef void* MTLComputeCommandEncoder_t; +#endif + +#include +#include +#include +#include +#include 
+#include +#include +#include + +// Forward declaration of TensorBase and TensorIteratorBase +namespace at { +class TensorBase; +struct TensorIteratorBase; +} // namespace at + +namespace at::native::mps { + +namespace detail { +template +class has_size_type { + template + static constexpr std::true_type check(typename U::size_type*); + template + static constexpr std::false_type check(...); + + public: + static constexpr bool value = decltype(check(nullptr))::value; +}; + +template +constexpr bool has_size_type_v = has_size_type::value; + +} // namespace detail + +// Returns `gpuAddress` of respective `id` plus storage offset +void* get_tensor_gpu_address(const at::TensorBase&); + +class MetalKernelFunction { + public: + MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_); + ~MetalKernelFunction(); + MetalKernelFunction(MetalKernelFunction&) = delete; + // Shader properties + uint64_t getMaxThreadsPerThreadgroup() const; + uint64_t getThreadExecutionWidth() const; + uint64_t getStaticThreadGroupMemoryLength() const; + void runCommandBlock(std::function f); + // Methods below should be called from runCommandBlock function + void startEncoding(); + void setArg(unsigned idx, const at::TensorBase& t); + void setArg(unsigned idx, const void* ptr, uint64_t size); + template < + typename T, + typename = std::enable_if_t< + std::is_integral_v || std::is_same_v || + (std::is_class_v && std::is_trivially_copyable_v && + !detail::has_size_type_v)>> + inline void setArg(unsigned idx, const T val) { + setArg(idx, &val, sizeof(T)); + } + + template < + typename Container, + typename = std::enable_if_t>> + inline void setArg(unsigned idx, const Container& values) { + setArg( + idx, + values.data(), + values.size() * sizeof(typename Container::value_type)); + } + void dispatch( + uint64_t length, + std::optional groupSize = std::nullopt); + void dispatch( + c10::ArrayRef length, + c10::OptionalArrayRef groupSize = std::nullopt); + + private: + 
MTLComputePipelineState_t cps; + MTLFunction_t func; + MTLComputeCommandEncoder_t encoder = nullptr; +}; + +class MetalShaderLibrary { + public: + MetalShaderLibrary(std::string src) + : shaderSource(std::move(src)), nparams(0), compile_options(nullptr) {} + MetalShaderLibrary(std::string src, unsigned nparams_) + : shaderSource(std::move(src)), + nparams(nparams_), + compile_options(nullptr) {} + MetalShaderLibrary( + std::string src, + unsigned nparams_, + MTLCompileOptions* compile_options_) + : shaderSource(std::move(src)), + nparams(nparams_), + compile_options(compile_options_) {} + MetalShaderLibrary(const MetalShaderLibrary&) = delete; + virtual ~MetalShaderLibrary(); + std::vector getFunctionNames(); + std::shared_ptr getKernelFunction( + const std::string& name); + inline MTLComputePipelineState_t getPipelineStateForFunc( + const std::string& fname) { + return getLibraryPipelineState(getLibrary(), fname).first; + } + MTLComputePipelineState_t getPipelineStateForFunc( + const std::string& fname, + const std::initializer_list& params) { + return getLibraryPipelineState(getLibrary(params), fname).first; + } + inline MTLFunction_t getMTLFunction(const std::string& fname) { + return getLibraryPipelineState(getLibrary(), fname).second; + } + MTLFunction_t getMTLFunction( + const std::string& fname, + const std::initializer_list& params) { + return getLibraryPipelineState(getLibrary(params), fname).second; + } + static MetalShaderLibrary& getBundledLibrary(); + void exec_unary_kernel( + TensorIteratorBase& iter, + const std::string& name, + const std::optional alpha = std::nullopt, + const std::optional scalar_arg_type = std::nullopt); + void exec_binary_kernel( + TensorIteratorBase& iter, + const std::string& name, + const std::optional alpha = std::nullopt, + const std::optional scalar_arg_type = std::nullopt); + + protected: + virtual MTLLibrary_t getLibrary(); + virtual MTLLibrary_t getLibrary( + const std::initializer_list& params); + MTLLibrary_t library = 
nullptr; + + private: + std::pair getLibraryPipelineState( + MTLLibrary_t lib, + const std::string& fname); + MTLLibrary_t compileLibrary(const std::string& src); + std::string shaderSource; + unsigned nparams; + MTLCompileOptions* compile_options; + std::unordered_map libMap; + std::unordered_map< + std::string, + std::pair> + cplMap; +}; + +class DynamicMetalShaderLibrary : public MetalShaderLibrary { + public: + DynamicMetalShaderLibrary(const std::string& src) : MetalShaderLibrary(src) { + // Compile right away + getLibrary(); + } + ~DynamicMetalShaderLibrary() override; +}; + +} // namespace at::native::mps diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/OperationUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/OperationUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..976b62c7ac4b42706e6804e942f01f8babc0c39a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/OperationUtils.h @@ -0,0 +1,656 @@ +// Copyright © 2022 Apple Inc. 
+ +#pragma once + +#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +#include + +@interface MPSGraph (PyTorchFixups) +- (MPSGraphTensor*)minimumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor + secondaryTensor:(MPSGraphTensor*)secondaryTensor + name:(NSString*)name; + +- (MPSGraphTensor*)maximumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor + secondaryTensor:(MPSGraphTensor*)secondaryTensor + name:(NSString*)name; +@end + +using namespace at::mps; + +namespace at::native::mps { + +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()); + +struct MPSScalar { + id getMTLBuffer() const { + return __builtin_bit_cast(id, buffer.get()); + } + + size_t size = 0; + ScalarType type = ScalarType::Undefined; + c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope) + union { + float f; // MPS doesn't support 'double' + at::Half h; + int64_t i; + bool b; + c10::complex cf; + c10::complex ch; + at::BFloat16 bf16; + } value{}; +}; + +void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results); + +MPSDataType getMPSDataType(ScalarType scalar_type); +static inline MPSDataType getMPSDataType(const TensorBase& t) { + return getMPSDataType(t.scalar_type()); +} +MPSDataType getMPSScalarType(ScalarType scalar_type); +static inline MPSDataType getMPSScalarType(const TensorBase& t) { + return getMPSScalarType(t.scalar_type()); +} +MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type); +std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false); +static inline std::string getMPSTypeString(const TensorBase& t, bool short_name = false) { + return getMPSTypeString(t.scalar_type(), short_name); +} 
+std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type); +static inline std::string scalarToMetalTypeString(const TensorBase& t) { + return scalarToMetalTypeString(t.scalar_type()); +} +NSArray* getTensorAxes(const TensorBase& t); +NSArray* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim); +std::string getMPSShapeString(MPSShape* shape); +std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false); +std::string getArrayRefString(const IntArrayRef s); +// use has_storage() on the returned tensor to determine if src actually is a view +Tensor gatherViewTensor(const Tensor& src, Tensor& dst); +Tensor& scatterViewTensor(const Tensor& src, Tensor& output); +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const TensorBase& input, + bool includesInt64 = false); +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const TensorBase& input, + bool includesInt64 = false); + +MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray); +MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {}); +MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil); +// The MPSShape could vary based on memory format +Tensor getTensorView(const Tensor& t, MPSShape* shape); +MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); +MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); + +static inline id getMTLBufferStorage(const TensorBase& tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +class Placeholder { + public: + Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {} + Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {} 
+ Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray); + Placeholder(MPSGraphTensor* mpsGraphTensor, + const Tensor& self, + MPSShape* mpsShape = nullptr, + bool gatherTensorData = true, + MPSDataType dataType = MPSDataTypeInvalid, + bool useMPSStridedAPI = true); + MPSGraphTensor* getMPSGraphTensor() { + return _placeholder; + } + MPSGraphTensorData* getMPSGraphTensorData() { + return _value; + } + bool isIntermediate() { + return _value == nullptr; + } + + private: + MPSGraphTensor* _placeholder; + MPSGraphTensorData* _value; + Tensor _tensor; +}; + +void resize_tensor(Tensor* output); +Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device); +MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor); +MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType); +MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType); +MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor); +MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar); + +MPSGraph* make_mps_graph(); +void printTensorNDArray(const TensorBase& t); +MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType); + +MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType); +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape); +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const TensorBase& tensor); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar); + +std::string get_mem_format_string(c10::MemoryFormat memory_format); + +using MPSCacheKey = uint64_t; + +struct MPSCachedKernel { + MPSCachedKernel(NSObject* object) : _object([object retain]) {} + 
virtual ~MPSCachedKernel() { + [_object release]; + _object = nullptr; + } + + // Delete copy constructor and assignment + MPSCachedKernel(const MPSCachedKernel&) = delete; + void operator=(const MPSCachedKernel&) = delete; + + template + inline T* kernel() const { + return (T*)_object; + } + + private: + NSObject* _object = nullptr; +}; + +// derive this class to cache a graph and its inputs/outputs +// can be used to store any NSObject +struct MPSCachedGraph { + MPSCachedGraph(NSObject* object) : _object([object retain]) {} + virtual ~MPSCachedGraph() { + [_object release]; + _object = nullptr; + } + + template + inline T* as() { + return static_cast(this); + } + + MPSGraph* graph() const { + return (MPSGraph*)_object; + } + NSObject* object() const { + return _object; + } + + private: + NSObject* _object = nullptr; +}; + +struct MPSUnaryCachedGraph : public MPSCachedGraph { + MPSUnaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; +}; + +struct MPSUnaryGradCachedGraph : public MPSCachedGraph { + MPSUnaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; // some backward input is actually the forward's output + MPSGraphTensor* gradInputTensor_ = nil; +}; + +struct MPSBinaryCachedGraph : public MPSCachedGraph { + MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* otherTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; +}; + +struct MPSBinaryGradCachedGraph : public MPSCachedGraph { + MPSBinaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* otherTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; +}; + +struct MPSKernelCache { + typedef MPSCachedKernel* 
(^CreateCachedKernelBlock)(); + + struct CacheEntry { + CacheEntry(const std::string& key, MPSCachedKernel* cachedKernel) : cachedKernel_(cachedKernel), key_(key) {} + MPSCachedKernel* cachedKernel_ = nullptr; + std::string key_; + }; + + public: + static MPSKernelCache* getInstance() { + if (_instance_cache == nullptr) { + _instance_cache = new MPSKernelCache(); + } + return _instance_cache; + } + + ~MPSKernelCache() { + dispatch_release(serialQueue_); + for (const auto& i : cache_) { + delete i.second.cachedKernel_; + } + } + + // Disallow the copy constructor and operator= functions + MPSKernelCache(const MPSKernelCache&) = delete; + void operator=(const MPSKernelCache&) = delete; + + MPSCachedKernel* CreateCachedKernel(const std::string& key, CreateCachedKernelBlock createCacheBlock) { + __block MPSCachedKernel* cachedKernel = nil; + MPSCacheKey hash = std::hash{}(key); + dispatch_sync_with_rethrow(serialQueue_, ^() { + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n"); + cachedKernel = entry.cachedKernel_; + } else { + cachedKernel = createCacheBlock(); + CacheEntry entry(key, cachedKernel); + cache_.emplace(hash, entry); + } + }); + return cachedKernel; + } + template + inline T* CreateCachedKernelAs(const std::string& key, CreateCachedKernelBlock createCacheBlock) { + return static_cast(CreateCachedKernel(key, createCacheBlock)); + } + + MPSCachedKernel* LookUp(const std::string& key) const { + __block MPSCachedKernel* cachedKernel = nil; + + MPSCacheKey hash = std::hash{}(key); + dispatch_sync_with_rethrow(serialQueue_, ^() { + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n"); + cachedKernel = entry.cachedKernel_; + } + }); + return cachedKernel; + } + + template + inline T* LookUpAs(const std::string& key) const { + return 
static_cast(LookUp(key)); + } + + private: + MPSKernelCache() { + serialQueue_ = dispatch_queue_create("kernel cache queue", DISPATCH_QUEUE_SERIAL); + } + + static MPSKernelCache* _instance_cache; + std::unordered_map cache_; + dispatch_queue_t serialQueue_ = nullptr; +}; + +// Common template for creating cached kernel if missing +template +inline T* LookUpOrCreateCachedKernel(const std::string& key, std::function instantiate) { + auto cache_ = MPSKernelCache::getInstance(); + if (auto rc = cache_->LookUpAs(key)) { + return rc; + } + return cache_->CreateCachedKernelAs(key, ^mps::MPSCachedKernel*() { + auto k_ = new mps::MPSCachedKernel(instantiate()); + return k_; + }); +} + +// TODO: Improve the overall design of MPSGraphCache. +// https://github.com/pytorch/pytorch/issues/77176 +// Cache holding various keys mapped to graphs +struct MPSGraphCache { + typedef MPSCachedGraph* (^CreateCachedGraphBlock)(); + + struct CacheEntry { + CacheEntry(const std::string& key, MPSCachedGraph* cachedGraph) : cachedGraph_(cachedGraph), key_(key) {} + MPSCachedGraph* cachedGraph_ = nullptr; + std::string key_; + }; + + public: + static MPSGraphCache* getInstance() { + if (_instance_cache == nullptr) { + _instance_cache = new MPSGraphCache(); + } + return _instance_cache; + } + + ~MPSGraphCache() { + dispatch_release(serialQueue_); + + for (const auto& i : cache_) { + delete i.second.cachedGraph_; + } + } + + // Disallow the copy constructor and operator= functions + MPSGraphCache(const MPSGraphCache&) = delete; + void operator=(const MPSGraphCache&) = delete; + + MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + __block MPSCachedGraph* cachedGraph = nil; + + MPSCacheKey hash = std::hash{}(key); + + dispatch_sync_with_rethrow(serialQueue_, ^() { + // verify the cached entry doesn't already exist + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision 
in the MPS cached graph!\n"); + cachedGraph = entry.cachedGraph_; + } else { + cachedGraph = createCacheBlock(); + CacheEntry entry(key, cachedGraph); + cache_.emplace(hash, entry); + profileCachedGraph(entry); + } + }); + return cachedGraph; + } + + template + inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + return static_cast(CreateCachedGraph(key, createCacheBlock)); + } + + MPSCachedGraph* LookUp(const std::string& key) const { + __block MPSCachedGraph* cachedGraph = nullptr; + + MPSCacheKey hash = std::hash{}(key); + + dispatch_sync(serialQueue_, ^() { + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n"); + cachedGraph = entry.cachedGraph_; + profileCachedGraph(entry); + } + }); + return cachedGraph; + } + + template + inline T* LookUpAs(const std::string& key) const { + return static_cast(LookUp(key)); + } + + private: + MPSGraphCache() { + serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL); + } + // this is defined in OperationUtils.mm to not include + // MPSProfiler.h in header OperationUtils.h + void profileCachedGraph(const CacheEntry& cacheEntry) const; + + static MPSGraphCache* _instance_cache; + std::unordered_map cache_; + dispatch_queue_t serialQueue_ = nullptr; +}; + +// Common template for creating graph with a specified cache if missing +template +inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function instantiate) { + auto cache_ = MPSGraphCache::getInstance(); + if (auto rc = cache_->LookUpAs(key)) { + return rc; + } + return cache_->CreateCachedGraphAs(key, ^mps::MPSCachedGraph*() { + T* newCachedGraph = nil; + @autoreleasepool { + // Initialize graph + auto mpsGraph = mps::make_mps_graph(); + newCachedGraph = new T(mpsGraph); + instantiate(mpsGraph, newCachedGraph); + } + return newCachedGraph; + }); +} + +// Common math operations 
+MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor); + +#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \ + if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \ + TORCH_WARN_ONCE( \ + "MPS: no support for int64 for ", \ + op_name, \ + ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \ + } + +/** + * Returns distance from lowest to highest element offset in given tensor. + */ +size_t compute_storage_numel_distance(const TensorBase& t); + +/** + * Checks whether tensor is mapped to a contiguous area in the storage. + */ +inline bool is_dense_in_storage(const TensorBase& t) { + return compute_storage_numel_distance(t) == static_cast(t.numel()); +} + +template , encoder_t> || + std::is_same_v, encoder_t>>> +static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) { + if (C10_UNLIKELY(t.device().type() == kCPU)) { + if constexpr (std::is_same_v, encoder_t>) { + TORCH_CHECK(t.dim() == 0, "Passed CPU tensor to MPS op"); + // MPS does not support doubles, silently downcast CPU scalar to float + if (C10_UNLIKELY(t.scalar_type() == kDouble)) { + auto val = static_cast(*reinterpret_cast(t.const_data_ptr())); + [encoder setBytes:&val length:sizeof(val) atIndex:idx]; + return; + } + if (C10_UNLIKELY(t.scalar_type() == kComplexDouble)) { + auto val = static_cast>(*reinterpret_cast*>(t.const_data_ptr())); + [encoder setBytes:&val length:sizeof(val) atIndex:idx]; + return; + } + [encoder setBytes:t.storage().data() length:t.element_size() atIndex:idx]; + } else { + TORCH_CHECK(false, "Passed CPU tensor to MPS op"); + } + return; + } + [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx]; +} + +// Implementation of setBytes for containers vs trivially copiable types must be separate +// Containers like `std::array` could have been uploaded directly, but `c10::ArrayRef`, +// while 
trivially copiable, includes padding which if copied as Metal shader parameters +// might overwrite other values +template < + typename T, + typename = std::enable_if_t || std::is_same_v || + (std::is_class_v && std::is_trivially_copyable_v && !detail::has_size_type_v)>> +static inline void mtl_setBytes(id encoder, const T val, unsigned idx) { + [encoder setBytes:&val length:sizeof(T) atIndex:idx]; +} + +template >> +static inline void mtl_setBytes(id encoder, const Container& values, unsigned idx) { + [encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx]; +} + +static inline void mtl_setBytes(id encoder, const MPSScalar& s, unsigned idx) { + [encoder setBytes:&s.value length:s.size atIndex:idx]; +} + +static size_t iter_tensor_offset(TensorIteratorBase& iter, unsigned idx) { + // At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to id object + // But TensorIterator constructs data_ptr as if base was just a raw pointer + // Workaround this problem by computing an offset from the start of the tensor, which works for both + // tensor views and sliced 64-bit iterators + return reinterpret_cast(iter.data_ptr(idx)) - + reinterpret_cast(iter.tensor_base(idx).storage().data()); +} + +static inline void bind_iter_tensors(id encoder, + TensorIteratorBase& iter, + std::optional ntensors = std::nullopt) { + for (auto idx : c10::irange(ntensors.value_or(iter.ntensors()))) { + auto& t = iter.tensor_base(idx); + // Handle CPU scalars + if (C10_UNLIKELY(t.device().type() == kCPU)) { + mtl_setBuffer(encoder, t, idx); + continue; + } + auto offs = iter_tensor_offset(iter, idx); + [encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx]; + } +} + +namespace detail { +template +inline void mtl_setArg(id encoder, const T& val, unsigned idx) { + mtl_setBytes(encoder, val, idx); +} + +inline void mtl_setArg(id encoder, id val, unsigned idx) { + [encoder setBuffer:val offset:0 atIndex:idx]; +} 
+ +template <> +inline void mtl_setArg(id encoder, const Tensor& val, unsigned idx) { + mtl_setBuffer(encoder, val, idx); +} + +template <> +inline void mtl_setArg(id encoder, const std::optional& val, unsigned idx) { + if (val.has_value()) { + mtl_setBuffer(encoder, val.value(), idx); + } +} + +template <> +inline void mtl_setArg(id encoder, const TensorBase& val, unsigned idx) { + mtl_setBuffer(encoder, val, idx); +} +// MPS does not support doubles, so cast it down to float before passing as an argument +template <> +inline void mtl_setArg(id encoder, const double& val, unsigned idx) { + float val_f = static_cast(val); + mtl_setBytes(encoder, val_f, idx); +} +} // namespace detail + +template +static inline void mtl_setArgs(id encoder, const T& val) { + detail::mtl_setArg(encoder, val, idx); +} + +template +static inline void mtl_setArgs(id encoder, const T& val, Args&&... args) { + detail::mtl_setArg(encoder, val, idx); + mtl_setArgs(encoder, std::forward(args)...); +} + +static inline void mtl_dispatch1DJob(id encoder, + id cplState, + NSUInteger length) { + static_assert(sizeof(NSUInteger) == sizeof(uint64_t)); + const auto maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup]; + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1); + [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; +} + +id generateKernelDataOffsets(id commandEncoder, + const TensorIteratorBase& iter, + bool use_64bit_index = false); + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) { + return @{p1.getMPSGraphTensor() : p1.getMPSGraphTensorData()}; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) { + return @{ + p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(), + }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) { + return @{ + 
p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(), + p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(), + }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) { + return @{ + p1.getMPSGraphTensor() : p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor() : p2.getMPSGraphTensorData(), + p3.getMPSGraphTensor() : p3.getMPSGraphTensorData(), + p4.getMPSGraphTensor() : p4.getMPSGraphTensorData(), + }; +} + +inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) { + runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result)); +} + +inline bool supportsComplex() { + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS); +} + +// MPS yet to support double types, but starting from MacOS 14, supports bfloat16 +inline bool supportedFloatingType(ScalarType dtype) { + return dtype == kFloat || dtype == kHalf || dtype == kBFloat16; +} + +inline bool supportedFloatingType(const TensorBase& t) { + return supportedFloatingType(t.scalar_type()); +} + +inline bool supportedFloatingOrComplexType(ScalarType dtype) { + if (dtype == kComplexFloat || dtype == kComplexHalf) { + return supportsComplex(); + } + return supportedFloatingType(dtype); +} +inline bool supportedFloatingOrComplexType(const TensorBase& t) { + return supportedFloatingOrComplexType(t.scalar_type()); +} + +inline void checkSupportsBFloat16() { + TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS), + "MPS bfloat16 type is supported on MacOS 14.0 or newer."); +} + +inline bool needsGather(const TensorBase& t) { + static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); + return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()); +} + +} // namespace at::native::mps diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/TensorFactory.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/TensorFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..22a49c1106f278cc6eb54fed59547f00551c8246 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/TensorFactory.h @@ -0,0 +1,14 @@ +// Copyright © 2022 Apple Inc. + +#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) AT_DISPATCH_CASE( \ + at::ScalarType::Half, \ + __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)) diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/UpSample.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/UpSample.h new file mode 100644 index 0000000000000000000000000000000000000000..c2c3c1d5d4587afea6c3324c0d0bfaa262573ec3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/UpSample.h @@ -0,0 +1,20 @@ +#pragma once + +#ifndef __METAL__ +#include +using ulong = unsigned long; +#define _ARRAY_NS std +#else +#include +#define _ARRAY_NS metal +#endif + +template +struct UpsampleParams { + _ARRAY_NS::array input_strides; + _ARRAY_NS::array input_sizes; + _ARRAY_NS::array output_strides; + _ARRAY_NS::array output_sizes; + _ARRAY_NS::array scales; + bool align_corners; +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/BinaryKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/BinaryKernel.h new file mode 100644 index 
0000000000000000000000000000000000000000..36d60913ea49fa8afeed122adbe282d5d5a8a5f5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/BinaryKernel.h @@ -0,0 +1,10 @@ +#pragma once + +namespace at::native::mps { +void binary_op_kernel( + const std::string func_name, + const Tensor& input, + const Tensor& other, + const Tensor& output, + const std::optional alpha = std::nullopt); +} // namespace at::native::mps diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..b43e38a77b18e87319966863c4ece413a593a963 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h @@ -0,0 +1,38 @@ +#pragma once +#include + +namespace at::native::mps { + +void _fused_adam_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native::mps diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..f7d393710ee2f7ff6bda1034bca3e63a190ab464 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h @@ -0,0 +1,35 @@ +#pragma once +#include + +namespace at::native::mps { + +void _fused_adam_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); +} // namespace at::native::mps diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..b2195a65e7bc788f070287a76541b4d907b53f60 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h @@ -0,0 +1,37 @@ +#pragma once +#include + +namespace at::native::mps { + +void _fused_adamw_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + 
const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); +} // namespace at::native::mps diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..40b4ddd1922318adeba4ac003c12590d659790b1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h @@ -0,0 +1,36 @@ +#pragma once +#include + +namespace at::native::mps { + +void _fused_adamw_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native::mps diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/Indexing.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/Indexing.h new file mode 100644 index 
0000000000000000000000000000000000000000..f52e5cd7334c3f0990bd5cbebeafb0483f66d79a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/Indexing.h @@ -0,0 +1,8 @@ +// Copyright © 2022 Apple Inc. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +using namespace at::mps; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/MultiTensorApply.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/MultiTensorApply.h new file mode 100644 index 0000000000000000000000000000000000000000..71575189a7584d099e0355d26fbc0f25c30a01d8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/MultiTensorApply.h @@ -0,0 +1,362 @@ +#pragma once +#include +#include +#include + +static_assert(sizeof(bool) == 1); + +namespace at::native::mps { + +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kmaxThreadGroups = 32; +static constexpr int64_t kmaxTensors = 32; + +struct MetadataArguments { // the size of this struct must be less than 4 kilobytes + uint64_t numels[kmaxTensors]; + uint64_t threadgroup_to_tensor[kmaxThreadGroups]; + uint64_t threadgroup_to_chunk[kmaxThreadGroups]; +}; + +struct FusedAdamEncodingFunctor { + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize) const { + mtl_setArgs( + computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize); + } + + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize) const { + mtl_setArgs( + computeEncoder, 
tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize); + } +}; + +template +struct FusedSgdEncodingFunctor {}; + +template <> +struct FusedSgdEncodingFunctor { + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const double momentum, + const double lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step) const { + mtl_setArgs(computeEncoder, + tensorArgumentBuffer, + metadata_arguments, + weight_decay, + momentum, + lr, + dampening, + nesterov, + maximize, + is_first_step); + } + + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const double momentum, + const at::Tensor& lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step) const { + mtl_setArgs(computeEncoder, + tensorArgumentBuffer, + metadata_arguments, + weight_decay, + momentum, + lr, + dampening, + nesterov, + maximize, + is_first_step); + } +}; + +template <> +struct FusedSgdEncodingFunctor { + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const double lr, + const bool maximize) const { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize); + } + + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const at::Tensor& lr, + const bool maximize) const { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize); + } +}; + +std::pair, id> getFusedAdamCPLState(const std::string& fname); +template +static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_name, + std::vector>& tensor_lists, + at::TensorList 
state_steps, + encoder_func_t encode, + ArgTypes... args) { + const auto num_tensors = tensor_lists[0].size(); + + if (num_tensors == 0) { + return; + } + + TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth"); + for (const auto& d : c10::irange(depth)) { + const auto scalar_type = tensor_lists[d][0].scalar_type(); + TORCH_CHECK(scalar_type == kFloat || scalar_type == kHalf || scalar_type == kBFloat16, + "Only float, bfloat and half are supported"); + } + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + // Remove comment for debugging + /* + mpsStream->addCompletedHandler(^(id cb) { + [cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) { + NSLog(@"MPSStream: %@", log); + } + ]; + }); + */ + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [fusedOptimizerPSO, fusedOptimizerFunc] = getFusedAdamCPLState(kernel_name); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]}); + + [computeEncoder setComputePipelineState:fusedOptimizerPSO]; + + // BufferIndex is the index in the kernel function + auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease]; + id tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength + options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int64_t tensor_loc = 0; + int64_t threadgroup_loc = 0; + MetadataArguments metadata_arguments; + + for (const auto tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + + for (const auto& d : c10::irange(depth)) { + mtl_setBuffer(tensorArgumentEncoder, 
tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) + usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + if (!state_steps.empty()) { + mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + } + metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel(); + + tensor_loc++; + + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + + for (const auto& chunk : c10::irange(chunks)) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + + threadgroup_loc++; + + const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1; + // Reach the maximum threadgroups per dispatch + const auto blocks_full = threadgroup_loc == kmaxThreadGroups; + + if (tensor_full || blocks_full) { + encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // Reset + threadgroup_loc = 0; + if (chunk == chunks - 1) { + // last chunk + tensor_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength + options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + } else { + // reuse the current tensor since the current one isn't done. 
+ metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1]; + + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength + options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + for (const auto& d : c10::irange(depth)) { + mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) + usage:MTLResourceUsageWrite | MTLResourceUsageRead]; + } + if (!state_steps.empty()) { + mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors); + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + } + tensor_loc = 1; + } + } + } + } + + if (threadgroup_loc != 0) { + encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + + getMPSProfiler().endProfileKernel(fusedOptimizerPSO); + } + }); +} + +std::pair, id> getAmpCPLState(const std::string& fname); +template +void multi_tensor_apply(const std::string& kernel_name, + std::vector>& tensor_lists, + ArgTypes... 
args) { + const auto num_tensors = tensor_lists[0].size(); + if (num_tensors == 0) { + return; + } + + TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists must match depth."); + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [pipeline, function] = getAmpCPLState(kernel_name); + [computeEncoder setComputePipelineState:pipeline]; + + id argumentEncoder = [function newArgumentEncoderWithBufferIndex:0]; + auto tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int tensor_loc = 0; + int threadgroup_loc = 0; + MetadataArguments metadata_arguments; + std::memset(&metadata_arguments, 0, sizeof(metadata_arguments)); + + for (size_t t = 0; t < num_tensors; t++) { + if (tensor_lists[0][t].numel() == 0) + continue; + + // bind each tensor in this list to the correct slots across depths + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + + // save number of elements for this tensor + metadata_arguments.numels[tensor_loc] = tensor_lists[0][t].numel(); + int currentTensorIndex = tensor_loc; + tensor_loc++; + + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + ((numel % kChunkSize) ? 
1 : 0); + + // process tensor in chunks based on max chunk size + for (uint chunk = 0; chunk < chunks; chunk++) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = currentTensorIndex; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + threadgroup_loc++; + + // dispatch when we've filled the threadgroup array or finished the chunks + const bool dispatch_now = (threadgroup_loc == kmaxThreadGroups) || (chunk == chunks - 1); + if (dispatch_now) { + // check for a partial dispatch (i.e. more chunks remain for the current tensor) + bool partial = (chunk != chunks - 1); + uint carried_numels = 0; + if (partial) { + carried_numels = metadata_arguments.numels[currentTensorIndex]; + } + + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, (uint32_t)64), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // prepare for the next batch: reset threadgroup count and create a new buffer + threadgroup_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + if (partial) { + // for a partial dispatch, rebind the partially processed tensor to slot 0 + // so that its metadata is in the correct location + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + 0); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + metadata_arguments.numels[0] = carried_numels; + // the currently processed tensor now lives at index 0 + currentTensorIndex = 0; + tensor_loc = 1; + } else { + tensor_loc = 0; + } + } + } + } + + if (threadgroup_loc != 0) { + 
mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, static_cast(64)), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + } + }); +} + +} // namespace at::native::mps diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mtia/EmptyTensor.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mtia/EmptyTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..afd1e58d40f07837677fd98617afd61e14ec867c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/mtia/EmptyTensor.h @@ -0,0 +1,42 @@ + +#pragma once +#include + +namespace at::detail { + +TensorBase empty_mtia( + IntArrayRef size, + ScalarType dtype, + std::optional device_opt, + std::optional memory_format_opt); + +TensorBase empty_mtia( + IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TensorBase empty_mtia(IntArrayRef size, const TensorOptions& options); + +TensorBase empty_strided_mtia( + IntArrayRef size, + IntArrayRef stride, + ScalarType dtype, + std::optional device_opt); + +TensorBase empty_strided_mtia( + IntArrayRef size, + IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +TensorBase empty_strided_mtia( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions& options); + +} // namespace at::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h new file mode 100644 index 
0000000000000000000000000000000000000000..c391efd7173e5d7e227f3a6c89492af0eddbac1a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace at::native { + +enum class NESTED_DENSE_OP : uint8_t { ADD, MUL }; + +using nested_dense_elementwise_fn = void (*)( + Tensor& result, + const Tensor& self, + const Tensor& other, + const NESTED_DENSE_OP& op); + +DECLARE_DISPATCH(nested_dense_elementwise_fn, nested_dense_elementwise_stub) + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h new file mode 100644 index 0000000000000000000000000000000000000000..d0a189a78013a099054fc8df9957f085bc8afd96 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +TORCH_API Tensor NestedTensor_to_padded_tensor_generic( + const Tensor& t, + double padding, + OptionalIntArrayRef output_size); + +template +Tensor map_nt(const Tensor& nt, Func f) { + auto* nt_impl = get_nested_tensor_impl(nt); + const auto& sizes = nt_impl->get_nested_sizes(); + return at::detail::make_tensor(f(nt_impl->get_buffer()), sizes); +} +template +Tensor map_nt_binary(const Tensor& nt_1, const Tensor& nt_2, Func f){ + auto* nt_impl_1 = get_nested_tensor_impl(nt_1); + auto* nt_impl_2 = get_nested_tensor_impl(nt_2); + const auto& sizes = nt_impl_1->get_nested_sizes(); + return at::detail::make_tensor(f(nt_impl_1->get_buffer(), nt_impl_2->get_buffer()), sizes); +} + +C10_ALWAYS_INLINE std::pair _check_nested_layer_norm_inputs( + const NestedTensorImpl& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */) { + + const size_t normalized_ndim 
= normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias.defined() || bias.sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias.sizes(), + " and normalized_shape = ", + normalized_shape); + + // Check that the normalized_shape has the exact same sizes as the last dimensions from the NestedTensor input + // Also, compute M and N considering the idiosyncracies of NestedTensors + int64_t N = 1; + for (const auto i: c10::irange(normalized_ndim)) { + TORCH_CHECK( + input.opt_size(-normalized_ndim + i).has_value(), + "normalized_shape extends into irregular dimensions for the nested tensor" + ); + TORCH_CHECK( + normalized_shape[i] == input.opt_size(-normalized_ndim + i), + "The shape at dimension ", + i, + "of normalized_shape doesn't match the input" + ); + N *= normalized_shape[i]; + } + + const int64_t M = input.numel() / N; + + return std::make_pair(M, N); +} + +Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..47119fdd4a1ab10cd7ef3e73491afdb03b5b3838 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h @@ -0,0 +1,103 @@ +/** + * Transformer-specific NestedTensor 
utility functions. + * + * Not co-located with NestedTensor core code yet because they only + * support specific cases needed in transformers. + */ +#pragma once + +#include + +#include +#include + +namespace c10 { +class Scalar; +} // namespace c10 + +namespace at { +class Tensor; +namespace native { +struct NestedTensorImpl; + +// Requires that self is a contiguous NestedTensor, other is not a +// NestedTensor, self.dim() == 3, and other.dim() == 2. Also, self +// must have a consistent last dimension across its included Tensors +// and that dimension must match other.size(0). +Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other); + +// Requires that mat1 is a contiguous NestedTensor, self & mat2 are +// not NestedTensors, mat1.dim() == 3, mat2.dim() == 2, and that mat1 +// has a consistent last dimension across its included Tensors that +// matches mat2.size(0). +Tensor NestedTensor_times_Tensor_plus_Tensor_addmm( + const Tensor& self, + const Tensor& mat1, + const Tensor& mat2, + const c10::Scalar& beta, + const c10::Scalar& alpha, + std::optional use_gelu = std::nullopt); + +Tensor NestedTensor_add_NestedTensor_in_place( + const Tensor& self, + const Tensor& other); + +TORCH_API Tensor NestedTensor_batch_offsets_from_size_tensor( + const Tensor& sizes, + int64_t extra_elements); + +Tensor NestedTensor_from_padded_tensor_cpu( + const Tensor& padded, + const NestedTensorImpl& nt); + +TORCH_API Tensor NestedTensor_to_mask(const Tensor& nt, std::optional mask_dim, std::optional mask_dim_length); + +template +void remove_padding_kernelLauncher( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size); + +template +void remove_padding_transform0213_kernelLauncher( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size); + +template +void 
add_padding_kernelLauncher( + T* input, + T* output, + T padding_value, + const int* offsets, + const int* input_sizes, + int input_dim, + const std::vector& output_sizes, + const int batch_size, + const int output_batch_size); + +TORCH_API Tensor flash_attention_helper( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool need_attn_weights, + bool is_causal); + +TORCH_API std::tuple mem_efficient_helper_nested_unpacked( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool need_attn_weights, + bool is_causal); +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..a9082a7dfa47c9f388f74af7fc137e27a40f639a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h @@ -0,0 +1,40 @@ +#pragma once +#include + +namespace at::native::preprocessing { + +/** + * This function will take nested query, key, and value + * and will preprocess it in order to run with either + * the flash-attention or efficient-attention kernels. + * @return A tuple containing all the necessary data for running the fused + * kernels + */ +std::tuple +sdpa_nested_preprocessing( + const Tensor& query, + const Tensor& key, + const Tensor& value); + +/** + * This function will take nested query, key, and value, grad_out, and out + * and will preprocess it in order to run with either + * the flash-attention or efficient-attention kernels backwards. 
+ * We use both functions to avoid having to do the same preprocessing + * for cumulative_sequence_length_q and cumulative_sequence_length_kv + * @return A tuple containing all the necessary data for running the fused + * kernels + */ +std::tuple +sdpa_nested_preprocessing_backward( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const Tensor& cumulative_sequence_length_q, + const Tensor& cumulative_sequence_length_kv, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_kv); + +} // namespace at::native::preprocessing diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..6a66a7f43e820af24aa392626b304a6059b9900d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h @@ -0,0 +1,449 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS + +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +#include +#include +#include + +namespace at::native { +struct NestedTensorImpl; + +// The following functions are used to construct nested tensors from buffers and +// metadata. 
+ +inline at::Tensor wrap_buffer(const at::Tensor& buffer, const at::Tensor& nested_sizes) { + TORCH_CHECK( + buffer.dim() == 1, + "Expected given buffer to be 1dim, but got ", + buffer.dim(), + " instead."); + TORCH_CHECK( + buffer.is_contiguous(), "Expected given buffer to be contiguous."); + return at::detail::make_tensor( + buffer, nested_sizes); +} + +// TODO: Figure out if we need a non-moving wrap_buffer() +inline at::Tensor wrap_buffer( + const at::Tensor& buffer, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + buffer.is_contiguous(), "Given buffer must be contiguous."); + return at::detail::make_tensor( + buffer, + std::move(nested_sizes), + std::move(nested_strides), + std::move(storage_offsets)); +} + +inline at::Tensor get_buffer(const at::Tensor& tensor) { + return get_nested_tensor_impl(tensor)->get_buffer(); +} + +/** + * Create a new nested tensor that is a view of a base nested tensor + * + * create_view_tensor calls a specialized constructor that copies the + * keys from base onto the new view tensor being created. + * The storage is shared between the base and the returned view tensor + * + * All callers of this helper must: + * - Only return a view of the input + * - Must be explicit and define a derivative + * + * @param base Base tensor to construct view from. + * @param nested_sizes View tensors' sizes. + * @param nested_strides View tensors' strides. + * @param storage_offsets View tensors' offsets. 
+ * @return A newly constructed view tensor + */ +inline at::Tensor create_nested_view_tensor( + const at::Tensor& base, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets) { + TORCH_INTERNAL_ASSERT( + base.is_nested(), + "This function can only be used to create nested tensor views"); + TORCH_INTERNAL_ASSERT( + c10::impl::tls_local_dispatch_key_set().excluded_.has( + c10::DispatchKey::AutogradFunctionality), + "Creating a non differentiable nested tensor view in a CompositeImplicit function is not allowed."); + return at::detail::make_tensor( + c10::TensorImpl::VIEW, + base, + std::move(nested_sizes), + std::move(nested_strides), + std::move(storage_offsets)); +} +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Helper functions for getting information about a nested tensor's shape. + +int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt); + +// The sizes of the underlying tensors +inline std::vector NestedTensor_get_sizes( + const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector sizes(ntensors); + if (ntensors == 0) { + return sizes; + } + const Tensor& sizemat = self_ptr->get_nested_sizes(); + int64_t orig_dim = sizemat.size(1); + // nesting scalars has empty sizes + if (orig_dim == 0) { + return sizes; + } + const int64_t* sizemat_ptr = sizemat.const_data_ptr(); + + for (const auto i : c10::irange(ntensors)) { + sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim); + sizemat_ptr += orig_dim; + } + return sizes; +} + +TORCH_API std::vector NestedTensor_get_max_size( + const NestedTensorImpl& nt); + +std::vector NestedTensor_get_max_size_from_size_tensor( + const Tensor& sizes); + +inline std::vector NestedTensor_get_sizes(const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_sizes(self_ptr); +} +// The strides of the underlying tensors +inline std::vector 
NestedTensor_get_strides( + const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector strides(ntensors); + if (ntensors == 0) { + return strides; + } + const Tensor& stridemat = self_ptr->get_nested_strides(); + int64_t orig_dim = stridemat.size(1); + // nesting scalars has empty strides + if (orig_dim == 0) { + return strides; + } + const int64_t* stridemat_ptr = stridemat.const_data_ptr(); + for (const auto i : c10::irange(ntensors)) { + strides[i] = IntArrayRef(stridemat_ptr, stridemat_ptr + orig_dim); + stridemat_ptr += orig_dim; + } + return strides; +} + +inline std::vector NestedTensor_get_strides( + const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_strides(self_ptr); +} + +inline void check_numel_equals_buffer_size(const at::Tensor& self) { + auto self_impl = get_nested_tensor_impl(self); + TORCH_CHECK( + self.numel() == static_cast(self_impl->get_buffer_size()), + "Number of elements in nested tensor must match number of elements in buffer."); +} + +inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) { + TORCH_CHECK( + self_ptr->numel() == static_cast(self_ptr->get_buffer_size()), + "Number of elements in nested tensor must match number of elements in buffer."); +} + +// Helper function to get size / stride / offset for a nested/normal tensor. 
+inline IntArrayRef get_size_for_index(const Tensor& tensor, int64_t i) { + if (tensor.is_nested()) { + std::vector tensor_sizes = + NestedTensor_get_sizes(get_nested_tensor_impl(tensor)); + return tensor_sizes[i]; + } else { + return tensor.sizes().slice(1); + } +} + +inline IntArrayRef get_stride_for_index(const Tensor& tensor, int64_t i) { + if (tensor.is_nested()) { + std::vector tensor_strides = + NestedTensor_get_strides(get_nested_tensor_impl(tensor)); + return tensor_strides[i]; + } else { + return tensor.strides().slice(1); + } +} + +inline int64_t get_offset_for_index(const Tensor& tensor, int64_t i) { + if (tensor.is_nested()) { + int64_t* offsets_ptr = get_nested_tensor_impl(tensor) + ->get_storage_offsets() + .data_ptr(); + return offsets_ptr[i]; + + } else { + int64_t offset = tensor.storage_offset(); + return offset + tensor.strides()[0] * i; + } +} +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Data structures and functions for generically applying a function on a nested +// tensor. 
+namespace impl { + +template +struct NestedNode { + NestedNode() = delete; + explicit NestedNode(std::vector children) + : _is_leaf(false), _children(std::move(children)) {} + explicit NestedNode(TensorList children) + : _is_leaf(false), _children(children.vec()) {} + explicit NestedNode(T payload) + : _is_leaf(true), _payload(std::move(payload)) {} + NestedNode(const NestedNode&) = delete; + NestedNode& operator=(const NestedNode&) = delete; + NestedNode(NestedNode&&) noexcept = default; + NestedNode& operator=(NestedNode&&) noexcept = default; + ~NestedNode() = default; + inline bool is_leaf() const { + return _is_leaf; + } + inline size_t degree() const { + return _children.size(); + } + inline const std::vector unbind() const { + return _children; + } + inline T children(size_t i) const { + return _children[i]; + } + inline const T& payload() const { + return _payload; + } + inline T& payload() { + return _payload; + } + + private: + bool _is_leaf; + std::vector _children; + T _payload{}; +}; + +using TensorNode = NestedNode; + +template +class _map; + +template +class _map> { + public: + static A function_one(const F& fn, const Args&... nested_node) { + return fn(nested_node...); + } + static NestedNode function( + const F& fn, + const NestedNode&... nested_node) { + size_t degree = 0; + bool all_leaf = true; + c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&all_leaf, °ree](auto n) { + all_leaf = all_leaf && (n.is_leaf()); + if (degree > 1 && n.degree() > 1) { + TORCH_CHECK( + degree == n.degree(), "NestedNodes must match in degree."); + } + if (n.degree() > degree) { + degree = n.degree(); + } + return nullptr; + }); + // All NestedNodes just wrap regular objects. + if (all_leaf) { + return NestedNode(std::forward(fn)(nested_node.payload()...)); + } + // Some NestedNodes wrap regular Tensors, some NestedTensors and some other + // types. 
+ std::vector result; + for (size_t i = 0; i < degree; i++) { + auto children = c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&i](auto a) { + static_assert( + c10::guts::is_instantiation_of::value, + "Internal error."); + // Broadcast regular arguments across NestedTensor constituents. + // This could be a Tensor, integer or anything else really. + if (a.is_leaf()) { + return a.payload(); + } + // Broadcast NestedTensors with one constituent. + if (a.degree() == 1 && !a.is_leaf()) { + return a.children(0); + } + TORCH_CHECK(a.degree() > 0, "Internal assert."); + return a.children(i); + }); + std::apply( + [&result, &fn](Args... filtered) { + result.emplace_back(function_one(fn, filtered...)); + }, + std::move(children)); + } + return NestedNode(std::move(result)); + } +}; + +// TODO: Add static assert to verify lambda arguments match nested_node types +template +static inline NestedNode< + typename c10::guts::infer_function_traits::type::return_type> +map(F&& fn, const NestedNode&... 
nested_node) { + return _map< + F, + typename c10::guts::infer_function_traits::type::return_type, + typename c10::guts::infer_function_traits::type::parameter_types>:: + function(std::forward(fn), nested_node...); +} + +inline TensorNode get_nested_tensor_structure(at::Tensor tensor) { + if (get_nested_tensor_impl_or_null(tensor) == nullptr) { + return TensorNode(std::move(tensor)); + } + return TensorNode(tensor.unbind()); +} + +inline Tensor wrap_tensor_node( + TensorNode tensor_node, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory) { + TORCH_CHECK( + !tensor_node.is_leaf(), "Expected TensorNode to wrap a list of Tensors."); + TensorOptions options_ = + TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( + pin_memory); + if (tensor_node.degree() == 0) { + return wrap_buffer(ones({0}, dtype, layout, device), ones({})); + } + + // Fast path: if all tensors are on CPU, have contiguous memory, and the same + // dtype, copying can be done much faster. 
+ bool all_tensors_cpu = true; + bool all_tensors_contiguous = true; + bool all_tensors_same_dtype = true; + auto first_dtype = tensor_node.children(0).dtype(); + std::vector start_offsets(tensor_node.degree()); + start_offsets[0] = 0; + long total_size = 0; + for (const auto i : c10::irange(tensor_node.degree())) { + all_tensors_cpu = all_tensors_cpu && tensor_node.children(i).is_cpu(); + all_tensors_contiguous = + all_tensors_contiguous && tensor_node.children(i).is_contiguous(); + all_tensors_same_dtype = all_tensors_same_dtype && + (first_dtype == tensor_node.children(i).dtype()); + if (!(all_tensors_cpu && all_tensors_contiguous && + all_tensors_same_dtype)) { + break; + } + if (i > 0) { + start_offsets[i] = + start_offsets[i - 1] + tensor_node.children(i - 1).numel(); + } + total_size += tensor_node.children(i).numel(); + } + + TensorOptions options; + Tensor nt_buffer, nt_sizes; + if (all_tensors_cpu && all_tensors_contiguous && all_tensors_same_dtype) { + nt_buffer = at::empty({total_size}, tensor_node.children(0).options()); + nt_sizes = at::empty( + {static_cast(tensor_node.degree()), + static_cast(tensor_node.children(0).sizes().size())}, + TensorOptions().dtype(kLong)); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, + at::ScalarType::Bool, + at::ScalarType::BFloat16, + c10::typeMetaToScalarType(first_dtype), + "create_nt_buffer", + [&]() { + at::parallel_for( + 0, tensor_node.degree(), 1, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + // Only try copying memory if there is more than 0 elements + // for a certain tensor + if (tensor_node.children(i).numel() > 0) { + memcpy( + nt_buffer.mutable_data_ptr() + start_offsets[i], + tensor_node.children(i).const_data_ptr(), + tensor_node.children(i).numel() * sizeof(scalar_t)); + } + } + }); + }); + long sizes_offset = 0; + for (size_t i = 0; i < tensor_node.degree(); ++i) { + auto tensor_sizes = tensor_node.children(i).sizes(); + for (int64_t tensor_size : 
tensor_sizes) { + nt_sizes.mutable_data_ptr()[sizes_offset++] = tensor_size; + } + } + options = nt_buffer.options().merge_in(options_); + } else { // Slow path + std::vector flat_tensors; + std::vector sizes; + for (const auto i : c10::irange(tensor_node.degree())) { + flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous()); + sizes.push_back( + tensor(c10::IntArrayRef(tensor_node.children(i).sizes()))); + } + options = flat_tensors[0].options().merge_in(options_); + nt_buffer = at::cat(flat_tensors); + nt_sizes = at::native::stack(sizes); + } + + return wrap_buffer(nt_buffer.to(options), nt_sizes); +} + +} // namespace impl + +// This function is meant to ease rapid operator coverage for +// NestedTensor kernels. It is not meant to be efficient. Use it judiciously. +template +inline at::Tensor map_nested_tensor(F&& fn, A... a) { + return wrap_tensor_node( + impl::map(std::forward(fn), impl::get_nested_tensor_structure(a)...), + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt); +} + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h new file mode 100644 index 0000000000000000000000000000000000000000..93af7669e117025784d5bf9f8f30e1e5a8a9458a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +TORCH_API Tensor& quantize_tensor_per_tensor_affine( + const Tensor& rtensor, + Tensor& qtensor, + double scale, + int64_t zero_point); +TORCH_API Tensor& quantize_tensor_per_channel_affine( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& scales, + Tensor zero_points, + int64_t axis); + +TORCH_API Tensor& quantize_tensor_per_channel_float_qparams( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& 
scales, + const Tensor& zero_points, + int64_t axis); + +TORCH_API Tensor& dequantize_tensor_per_tensor_affine( + const Tensor& qtensor, + Tensor& rtensor, + double scale, + int64_t zero_point); +TORCH_API Tensor& dequantize_tensor_per_channel_affine( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + Tensor zero_points, + int64_t axis); +TORCH_API Tensor& dequantize_tensor_per_channel_float_qparams( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using quantize_tensor_per_tensor_affine_fn = + void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point); + +using quantize_tensor_per_channel_affine_fn = void (*)( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using quantize_tensor_per_channel_float_qparams_fn = void (*)( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using dequantize_tensor_per_tensor_affine_fn = + void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point); + +using dequantize_tensor_per_channel_affine_fn = void (*)( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using dequantize_tensor_per_channel_float_qparams_fn = void (*)( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using quantize_tensor_per_tensor_affine_sub_byte_fn = + void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point); + +using dequantize_tensor_per_tensor_affine_sub_byte_fn = + void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point); + +DECLARE_DISPATCH( + quantize_tensor_per_tensor_affine_fn, + quantize_tensor_per_tensor_affine_stub) +DECLARE_DISPATCH( + quantize_tensor_per_channel_affine_fn, + 
quantize_tensor_per_channel_affine_stub) +DECLARE_DISPATCH( + quantize_tensor_per_channel_float_qparams_fn, + quantize_tensor_per_channel_float_qparams_stub) + +DECLARE_DISPATCH( + dequantize_tensor_per_tensor_affine_fn, + dequantize_tensor_per_tensor_affine_stub) +DECLARE_DISPATCH( + dequantize_tensor_per_channel_affine_fn, + dequantize_tensor_per_channel_affine_stub) +DECLARE_DISPATCH( + dequantize_tensor_per_channel_float_qparams_fn, + dequantize_tensor_per_channel_float_qparams_stub) + +DECLARE_DISPATCH( + quantize_tensor_per_tensor_affine_sub_byte_fn, + quantize_tensor_per_tensor_affine_sub_byte_stub) + +DECLARE_DISPATCH( + dequantize_tensor_per_tensor_affine_sub_byte_fn, + dequantize_tensor_per_tensor_affine_sub_byte_stub) + +template +TORCH_API Tensor quantize_tensor( + Tensor rtensor, + Tensor qtensor, + double scale, + int64_t zero_point); +template +TORCH_API Tensor dequantize_tensor( + Tensor qtensor, + Tensor rtensor, + double scale, + int64_t zero_point); + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h new file mode 100644 index 0000000000000000000000000000000000000000..a0cfafdb99054a4f68370934c78005367799c10d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h @@ -0,0 +1,45 @@ +#pragma once +#include +#include + +namespace at::native { + +// Quantize a float value into a uint value given scale and zero_point +template +TORCH_API T quantize_val(double scale, int64_t zero_point, float value); +// TODO combine this with quantize_val once the numerics for ARM are aligned +// with it +template +T quantize_val_arm( + const float scale, + const int32_t zero_point, + const float value); +template +void quantize_vec( + double scale, + int64_t zero_point, + const float* src, + T* dst, + size_t count = 8); +template +TORCH_API 
float dequantize_val(double scale, int64_t zero_point, T value); +template +TORCH_API float dequantize_vec( + double scale, + int64_t zero_point, + const T* src, + float* dst, + size_t count = 8); +template +TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src); + +// Given a multiplier and a zero_point, requantize int32_t computed values back +// to quantized values. See comment above +// make_per_tensor_affine_quantizer function for the usage of int64_t +template +TORCH_API DST_T +requantize_from_int(double multiplier, int64_t zero_point, int64_t src); + +int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/ConvUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/ConvUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..6f8ff918c1d2f3e421922650161aaa41eda9545f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/ConvUtils.h @@ -0,0 +1,62 @@ +#pragma once +#include +#include + +namespace at::native::quantized { +namespace { +// MakeConvOutputShape used from both CPU and CUDA libraries +// and exporting symbol from torch_cpu would probably take more storage +// than duplicating implementation which likely be inlined away +template +at::SmallVector MakeConvOutputShape( + int N, // mini-batch + int M, // output channels + const std::array& input_image_shape, + const std::vector& kernel, + const torch::List& stride, + const torch::List& padding, + const torch::List& dilation); + +#if defined(USE_CUDA) || defined(USE_PYTORCH_QNNPACK) +template <> +at::SmallVector MakeConvOutputShape<2>( + int N, // mini-batch + int M, // output channels + const std::array& input_image_shape, + const std::vector& kernel, + const at::List& stride, + const at::List& padding, + const at::List& dilation) { + const int H = 
input_image_shape[0]; + const int W = input_image_shape[1]; + const int64_t Y_H = + (H + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1; + const int64_t Y_W = + (W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1; + return {N, M, Y_H, Y_W}; +} + +template <> +at::SmallVector MakeConvOutputShape<3>( + int N, // mini-batch + int M, // output channels + const std::array& input_image_shape, + const std::vector& kernel, + const at::List& stride, + const at::List& padding, + const torch::List& dilation) { + const int D = input_image_shape[0]; + const int H = input_image_shape[1]; + const int W = input_image_shape[2]; + const int64_t Y_D = + (D + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1; + const int64_t Y_H = + (H + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1; + const int64_t Y_W = + (W + 2 * padding[2] - dilation[2] * (kernel[2] - 1) - 1) / stride[2] + 1; + return {N, M, Y_D, Y_H, Y_W}; +} + +#endif +} // anonymous namespace +} // namespace at::native::quantized diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/Copy.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..141174233be75a678370d605ff51c65b1575fcfd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/Copy.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace at::native { + +Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src); +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h new file mode 100644 index 0000000000000000000000000000000000000000..e107fb4c62f098e2cff847c9d17cbdea05f72234 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include + +namespace at { + +struct TensorIterator; + +namespace native { + +using fake_quant_tensor_cachemask_fn = void (*)( + Tensor& output, + Tensor& mask, + const Tensor& input, + float sc, + int64_t z_point, + int64_t quant_min, + int64_t quant_max); + +using fake_quant_tensor_cachemask_tensor_qparams_fn = void (*)( + Tensor& output, + Tensor& mask, + const Tensor& input, + const Tensor& sc, + const Tensor& z_point, + const Tensor& fake_quant_enabled, + int64_t quant_min, + int64_t quant_max); + +using fake_quant_learnable_grad_tensor_fn = void (*)( + TensorIterator& iter, + float scale, + float inv_scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + float grad_factor); + +DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub) +DECLARE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_fn, fake_quant_tensor_cachemask_tensor_qparams_stub) +DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub) + +using fake_quant_per_channel_fn = void (*)( + TensorIterator &iter, + int64_t quant_min, + int64_t quant_max); + +using fake_quant_per_channel_cachemask_fn = void (*)( + TensorIterator &iter, + TensorIterator &iter_mask, + int64_t quant_min, + int64_t quant_max); + +DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub) + +using fake_quant_learnable_per_channel_fn = void (*)( + TensorIterator &iter, + int64_t quant_min, + int64_t quant_max, + float grad_factor); + +DECLARE_DISPATCH(fake_quant_learnable_per_channel_fn, fake_quant_grad_learnable_channel_stub) + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/IndexKernel.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/IndexKernel.h new file mode 
100644 index 0000000000000000000000000000000000000000..7811a6db91653414a2ad8cd7b9bf99f20bd6c742 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/IndexKernel.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include + +namespace at::native { +using masked_fill_kernel_quantized_fn = void(*)(TensorIterator& iter, const Scalar& value, double scale, int zero_point); +using index_put_kernel_quantized_fn = void(*)(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point); + +DECLARE_DISPATCH(masked_fill_kernel_quantized_fn, masked_fill_kernel_quantized_stub) +DECLARE_DISPATCH(index_put_kernel_quantized_fn, index_put_kernel_quantized_stub) + + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/PackedParams.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/PackedParams.h new file mode 100644 index 0000000000000000000000000000000000000000..d73bc0adbc4ef953e0580585ab9261700374a45d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/PackedParams.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include + +struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { + virtual at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) = 0; + + // out variant of LinearPackedParamsBase::apply + virtual at::Tensor& apply_out( + const at::Tensor& /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/, + at::Tensor& output) { + throw std::runtime_error( + "apply_out is not implemented for this packed " + "parameter type"); + return output; + } + + virtual at::Tensor& apply_relu_out( + const at::Tensor& /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/, + at::Tensor& output) { + throw std::runtime_error( 
+ "apply_relu_out is not implemented for this packed " + "parameter type"); + return output; + } + + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32): + // input -> q* -> dq* -> linear* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // X: float32 Tensor, will be quantized to quint8 in the op + // W_prepack: packed qint8 quantized weight and bias + // Returns: + // Y: float32 Tensor + virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + throw std::runtime_error( + "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed " + "parameter type"); + return {}; + } + + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32): + // input -> q* -> dq* -> linear* -> relu* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // input: float32 Tensor, will be quantized to quint8 in the op + // Returns: + // float32 Tensor + virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + throw std::runtime_error( + "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed " + "parameter type"); + return {}; + } + + virtual at::Tensor apply_dynamic( + at::Tensor input, + bool reduce_range = false) = 0; + virtual at::Tensor apply_dynamic_relu( + at::Tensor input, + bool reduce_range = false) = 0; + + virtual at::Tensor& 
apply_dynamic_out( + const at::Tensor& /* input */, + at::Tensor& output, + bool /* reduce_range */) { + throw std::runtime_error( + "apply_dynamic_out is not implemented for this packed " + "parameter type"); + return output; + } + virtual at::Tensor& apply_dynamic_relu_out( + const at::Tensor& /* input */, + at::Tensor& output, + bool /* reduce_range */) { + throw std::runtime_error( + "apply_dynamic_relu_out is not implemented for this packed " + "parameter type"); + return output; + } + + virtual std::tuple> unpack() = 0; + + virtual std::optional bias() = 0; + + virtual void set_bias(std::optional /*bias*/) { + throw std::runtime_error( + "set_bias is not implemented for this packed " + "parameter type"); + } +}; + +template +struct ConvPackedParamsBase : public torch::jit::CustomClassHolder { + virtual at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) = 0; + + virtual std::tuple> unpack() = 0; + + virtual torch::List stride() const = 0; + virtual torch::List padding() const = 0; + virtual torch::List output_padding() const = 0; + virtual torch::List dilation() const = 0; + virtual int64_t groups() const = 0; + virtual bool transpose() const = 0; +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/ACLUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/ACLUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..c84406749528eb5bb706afd498198a05ca0a19eb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/ACLUtils.h @@ -0,0 +1,257 @@ +#pragma once + +#include +#if AT_MKLDNN_ACL_ENABLED() + +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include + +// Utilities for Arm Compute Library (ACL) quantized operations +// Provides interfaces to leverage ACL's accelerated kernels for statically and +// dynamically quantized matmuls (i.e. qlinear and qlinear_dynamic). These are +// utilized through PackedLinearWeightsACL which extends +// PackedLinearWeightsOnednn. Note that PackedLinearWeightsACL extends rather +// than replaces PackedLinearWeightsOnednn for AArch64 because ACL currently +// only supports per_tensor weight quantization. +namespace at::native::acl_utils { + +using QuantMatmulCacheKey = std::tuple< + int64_t, // M + bool, // FUSE_RELU + int64_t, // NUM_THREADS + double, // INPUT_SCALE + int64_t, // INPUT_OFFSET + double, // OUTPUT_SCALE + int64_t, // OUTPUT_OFFSET + bool // SIGNED_INPUT + >; + +enum class QuantMatmulCacheKeyIndex { + M, + FUSE_RELU, + NUM_THREADS, + INPUT_SCALE, + INPUT_OFFSET, + OUTPUT_SCALE, + OUTPUT_OFFSET, + SIGNED_INPUT +}; + +// Abstract interface to share common stuff between static/dynamic ACL matmuls. +struct QuantMatmul { + arm_compute::NEGEMMLowpMatrixMultiplyCore gemm; + // key for use in the cache + QuantMatmulCacheKey key; + + QuantMatmul( + int64_t weight_dim_0, + int64_t weight_dim_1, + double weight_scale, + int64_t weight_offset, + int8_t* weight_ptr, + std::optional bias_ptr, + const QuantMatmulCacheKey& cache_key); + + virtual ~QuantMatmul(); + virtual arm_compute::Status validate() = 0; + virtual void configure() = 0; + + protected: + arm_compute::Tensor wei_q_tensor_; + std::optional bia_tensor_; + arm_compute::GEMMInfo gemm_info_; + std::optional relu_info_; +}; + +struct DynamicQuantMatmul : public QuantMatmul { + arm_compute::Tensor src_q_tensor; + arm_compute::Tensor src_tensor; + arm_compute::Tensor dst_tensor; + arm_compute::NEQuantizationLayer quant; + // We need a ReLU layer here (unlike static quantization) because the ReLU + // cannot be "truly" fused with the GEMM through gemm_info in ACL dynamically + // quantized matmuls. 
+ std::optional relu; + + DynamicQuantMatmul( + int64_t weight_dim_0, + int64_t weight_dim_1, + double weight_scale, + int64_t weight_offset, + int8_t* weight_ptr, + std::optional bias_ptr, + const QuantMatmulCacheKey& cache_key); + + ~DynamicQuantMatmul() override; + + arm_compute::Status validate() override; + void configure() override; + + private: + at::Tensor src_q_tensor_orig_; +}; + +struct StaticQuantMatmul : public QuantMatmul { + arm_compute::Tensor src_q_tensor; + arm_compute::Tensor dst_q_tensor; + + StaticQuantMatmul( + int64_t weight_dim_0, + int64_t weight_dim_1, + double weight_scale, + int64_t weight_offset, + int8_t* weight_ptr, + std::optional bias_ptr, + const QuantMatmulCacheKey& cache_key); + + ~StaticQuantMatmul() override; + + arm_compute::Status validate() override; + void configure() override; + + private: + std::optional bia_q_tensor_; + std::optional bia_q_tensor_orig_; +}; + +struct QuantAdd { + arm_compute::Tensor qa_tensor; + arm_compute::Tensor qb_tensor; + arm_compute::Tensor qdst_tensor; + arm_compute::NEArithmeticAddition q_add; + + QuantAdd( + arm_compute::DataType dtype, + const std::vector& input_dims, + double qa_scale, + int64_t qa_offset, + double qb_scale, + int64_t qb_offset, + double dst_scale, + int64_t dst_offset); + + arm_compute::Status validate(); + void configure(); + + private: + arm_compute::ConvertPolicy policy{arm_compute::ConvertPolicy::SATURATE}; +}; + +} // namespace at::native::acl_utils +struct PackedLinearWeightsACL : public PackedLinearWeightsOnednn { + using ACLQuantMatmul = at::native::acl_utils::QuantMatmul; + using ACLDynamicQuantMatmul = at::native::acl_utils::DynamicQuantMatmul; + using ACLStaticQuantMatmul = at::native::acl_utils::StaticQuantMatmul; + using ACLQuantMatmulCacheKey = at::native::acl_utils::QuantMatmulCacheKey; + using ACLQuantMatmulCacheKeyIndex = + at::native::acl_utils::QuantMatmulCacheKeyIndex; + + PackedLinearWeightsACL( + std::unique_ptr weight, + std::optional bias, + 
at::Tensor orig_weight, + std::optional orig_bias); + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) + override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) + override; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + template + std::shared_ptr get_acl_quant_matmul( + const ACLQuantMatmulCacheKey& key) { + return std::dynamic_pointer_cast( + fetch_or_create_acl_quant_matmul(key)); + } + + private: + int64_t k_; + int64_t n_; + int64_t weight_zero_point_; + double weight_scale_; + + // A 2 element (per layer) cache. Given it's not intended to store more than 2 + // elements, we do not need a fancy implementation. The idea behind it is to + // allow for a (configuration free) fast path for autoregressive + // transformer-like models which usually involve 2 input tensor shapes; one + // for the prefill phase and another for the autoregressive phase + std::array, 2> cache_; + + template + std::shared_ptr fetch_or_create_acl_quant_matmul( + const ACLQuantMatmulCacheKey& key) { + // We're only maintaining a 2 element LRU cache + // hit first + if (cache_[0] != nullptr && cache_[0]->key == key) { + return cache_[0]; + } + // hit second + if (cache_[1] != nullptr && cache_[1]->key == key) { + // Update LRU + std::swap(cache_[0], cache_[1]); + return cache_[0]; + } + // miss -> replace Least Recently Used - i.e. 
element at index 1 + cache_[1] = create_acl_quant_matmul(key); + std::swap(cache_[0], cache_[1]); + return cache_[0]; + } + + template + std::shared_ptr create_acl_quant_matmul( + const ACLQuantMatmulCacheKey& key) { + std::optional bias_ptr; + if (bias_.has_value()) { + bias_ptr = (float*)bias_.value().get_data_handle(); + } + auto acl_gemm = std::make_shared( + k_, + n_, + weight_scale_, + weight_zero_point_, + (int8_t*)weight_.get()->get_data_handle(), + bias_ptr, + key); + + // validate + auto status = acl_gemm->validate(); + if (status.error_code() != arm_compute::ErrorCode::OK) { + TORCH_WARN( + "Arm Compute Library's Quantized Matmul Validation Failed: " + + status.error_description()); + return nullptr; + } + + // configure + acl_gemm->configure(); + return acl_gemm; + } + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false); + + template + at::Tensor apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point); +}; + +#endif // AT_MKLDNN_ACL_ENABLED() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..0643ae3b536d30a4840ccd72dc04857922b3c9b0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h @@ -0,0 +1,6 @@ +#include + +namespace at::native { +TORCH_API Tensor +quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point); +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h new file mode 100644 index 0000000000000000000000000000000000000000..e6f47d611a19f4bcf804b63f20fb06be9a2c1f44 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder { + virtual at::Tensor embeddingbag_byte( + const at::Tensor& indices, + const std::optional& offsets, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) = 0; + + virtual at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const std::optional& offsets, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) = 0; + + virtual at::Tensor unpack() = 0; + + virtual int64_t bit_rate() const = 0; + virtual int64_t version() const = 0; +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..7722272dfcc270914072602a2a32937fd1c1af72 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h @@ -0,0 +1,463 @@ +#pragma once + +#include +#if AT_MKLDNN_ENABLED() +#include +#include +#include +#if !defined(__powerpc__) +#include +#endif + +#include + +using PrimitiveCacheKey = std::tuple< + double, // input_scale + int64_t, // input_zero_point + std::vector, // input_shape + double, // output_scale + int64_t, // output_zero_point + int64_t, // OMP_number_of_threads + double, // accum_scale + int64_t>; // accum_zero_point + +enum CacheKeyIndex { + InputScale, + InputZeroPoint, + InputShape, + OutputScale, + OutputZeroPoint, + NumOfThreads, +}; + +// Base class of primitive cache +struct PrimitiveCache { + PrimitiveCacheKey key; + + bool hit(const 
PrimitiveCacheKey& key) { + return this->key == key; + } +}; + +using LinearParams = ideep::matmul_forward_params; +using Conv = dnnl::convolution_forward; +using ConvDesc = dnnl::convolution_forward::primitive_desc; +using ConvParams = ideep::convolution_forward_params; +using Deconv = dnnl::deconvolution_forward; +using DeconvDesc = dnnl::deconvolution_forward::primitive_desc; +using DeconvParams = ideep::deconv_forward_params; + +struct LinearPrimitiveCache : PrimitiveCache { + LinearPrimitiveCache() = default; + + LinearPrimitiveCache( + const PrimitiveCacheKey& key, + const LinearParams& param) { + this->key = key; + this->param = param; + } + + LinearParams param; + + // For dynamic qlinear, scale and zero point + // are set at execution time. So we only need to compare + // the rest part of key. + bool hit_dynamic(const PrimitiveCacheKey& new_key) { + auto const& cached_input_shape = std::get(this->key); + auto const& new_input_shape = std::get(new_key); + return ( + cached_input_shape == new_input_shape && + std::get(this->key) == std::get(new_key)); + } + + LinearParams& get_param() { + return param; + } +}; + +struct ConvPrimitiveCache : PrimitiveCache { + ConvPrimitiveCache() = default; + + ConvPrimitiveCache( + const PrimitiveCacheKey& key, + const ConvParams& params) { + this->key = key; + this->params = params; + } + + ConvParams params; + + ConvParams& get_params() { + return params; + } +}; + +struct DeconvPrimitiveCache : PrimitiveCache { + DeconvPrimitiveCache() = default; + + DeconvPrimitiveCache( + const PrimitiveCacheKey& key, + const DeconvParams& params) { + this->key = key; + this->params = params; + } + + DeconvParams params; + + DeconvParams& get_params() { + return params; + } +}; + +enum PostOps { + NoPostOp, + Relu, + LeakyRelu, + Tanh, + Gelu +}; + + +struct PackedLinearWeightsOnednn : public LinearPackedParamsBase { + PackedLinearWeightsOnednn( + std::unique_ptr weight, + std::optional bias, + at::Tensor orig_weight, + std::optional 
orig_bias) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)) { + cache_initialized_flag = std::make_unique(); + } + std::unique_ptr weight_; + std::optional bias_; + at::Tensor orig_weight_; + std::optional orig_bias_; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override; + + at::Tensor apply_leaky_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point, + double negative_slope); + + at::Tensor apply_tanh( + at::Tensor input, + double output_scale, + int64_t output_zero_point); + + std::tuple> unpack() override; + + std::optional bias() override { + return orig_bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + std::optional bias); + + private: + LinearPrimitiveCache prim_cache; + std::unique_ptr cache_initialized_flag; + + template + at::Tensor apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point, + torch::List post_op_args = torch::List()); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false); + + LinearPrimitiveCache& get_cache() { + return prim_cache; + } +}; + +template +struct PackedConvWeightsOnednn : public ConvPackedParamsBase { + PackedConvWeightsOnednn( + std::unique_ptr weight, + std::optional bias, + at::Tensor orig_weight, + std::optional orig_bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + uint8_t transpose) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)), + stride_(std::move(stride)), 
+ padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose) { + cache_initialized_flag = std::make_unique(); + } + + std::unique_ptr weight_; + std::optional bias_; + at::Tensor orig_weight_; + std::optional orig_bias_; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + uint8_t transpose_; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override; + + at::Tensor apply_add( + const at::Tensor& input, + const at::Tensor& accum, + double output_scale, + int64_t output_zero_point); + + at::Tensor apply_add_relu( + const at::Tensor& input, + const at::Tensor& accum, + double output_scale, + int64_t output_zero_point); + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return (bool)transpose_; + } + + private: + ConvPrimitiveCache conv_prim_cache; + DeconvPrimitiveCache deconv_prim_cache; + std::unique_ptr cache_initialized_flag; + + template + at::Tensor apply_impl( + const at::Tensor& input, + const std::optional& accum, + double output_scale, + int64_t output_zero_point); + + 
ConvPrimitiveCache& get_conv_cache() { + assert(!transpose()); + return conv_prim_cache; + } + + DeconvPrimitiveCache& get_deconv_cache() { + assert(transpose()); + return deconv_prim_cache; + } +}; + +namespace onednn_utils { + +inline ideep::attr_t create_attr_by_post_op( + const std::string_view& binary_post_op, + double binary_alpha, + double input1_scale, + int64_t input1_zero_point, + const ideep::tensor::desc& input1_desc, + const std::string_view& unary_post_op, + const torch::List>& unary_post_op_args, + const std::string_view& unary_post_op_algorithm) { + using ideep::tensor; + if (binary_post_op == "none") { + if (unary_post_op == "relu") { + return ideep::attr_t::fuse_relu(); + } else if (unary_post_op == "leaky_relu") { + TORCH_CHECK( + unary_post_op_args.size() == 1, + "onednn qlinear: expect one argument for post op leaky_relu but got ", unary_post_op_args.size(), " args"); + auto alpha = unary_post_op_args[0].value().to(); + return ideep::attr_t::fuse_relu_v2(alpha); + } else if (unary_post_op == "tanh") { + return ideep::attr_t::fuse_tanh(); + } else if (unary_post_op == "gelu") { + TORCH_CHECK( + unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh", + "onednn qlinear: algorithm for post op gelu must be none or tanh but got ", unary_post_op_algorithm); + auto post_algorithm = unary_post_op_algorithm == "none" ? 
+ dnnl::algorithm::eltwise_gelu_erf : + dnnl::algorithm::eltwise_gelu_tanh; + return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm); + } else if (unary_post_op == "hardtanh") { + TORCH_CHECK( + unary_post_op_args.size() == 2 && + unary_post_op_args[0].has_value() && + unary_post_op_args[1].has_value(), + "hardtanh is expected to have two scalar input: min_val and max_val"); + auto lower_bound_value = + unary_post_op_args[0].value().to(); + auto upper_bound_value = + unary_post_op_args[1].value().to(); + return ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value); + } else if (unary_post_op == "hardswish") { + return ideep::attr_t::fuse_hardswish(); + } else if (unary_post_op == "swish") { + return ideep::attr_t::fuse_swish(); + } else { + TORCH_CHECK( + unary_post_op == "none", + "onednn qlinear: unsupported unary post op ", unary_post_op); + } + } else if (binary_post_op == "sum") { + if (unary_post_op == "none") { + return ideep::attr_t::fuse_sum(input1_scale, input1_zero_point); + } else if (unary_post_op == "relu") { + return ideep::attr_t::residual_with_sum_zero_point(input1_scale, input1_zero_point); + } else { + TORCH_CHECK( + false, + "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op sum"); + } + } else if (binary_post_op == "add") { + if (unary_post_op == "none") { + return ideep::attr_t::fuse_binary(ideep::algorithm::binary_add, input1_desc); + } else if (unary_post_op == "relu") { + ideep::post_ops po; + po.append_binary(ideep::algorithm::binary_add, input1_desc); + po.append_eltwise(ideep::algorithm::eltwise_relu, 0, 0); + return ideep::attr_t::attr_post_ops(po); + } else { + TORCH_CHECK( + false, + "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op add"); + } + } else { + TORCH_CHECK( + false, + "onednn qlinear: unsupported binary post op ", binary_post_op); + } + return ideep::attr_t(); +} + +// ONEDNN requires symmetric quantization of weight +// Use this util function 
to check. +inline bool is_weight_symmetric_quant( + const at::Tensor& weight, + bool is_transposed_conv) { + bool is_symmetric = true; + const auto qtype = weight.qscheme(); + if (qtype == c10::kPerTensorAffine) { + is_symmetric &= (weight.q_zero_point() == 0); + } else if (qtype == c10::kPerChannelAffine) { + if (is_transposed_conv) { + // This case is currently not supported in PyTorch + // but we do not want to raise an error in this util function. + is_symmetric = false; + } else { + auto output_channels = weight.size(0); + for (int i = 0; i < output_channels; ++i) { + auto zp = weight.q_per_channel_zero_points()[i].item(); + is_symmetric &= (zp == 0); + } + } + } else { + // This case is currently not supported in PyTorch + // but we do not want to raise an error in this util function. + is_symmetric = false; + } + return is_symmetric; +} + +// When qengine is x86, use this util func to check if onednn kernel +// is preferred than fbgemm's to get better performance. +inline bool should_use_onednn_quant( + const at::Tensor& weight, + bool is_transposed_conv, + int groups, + torch::List output_padding) { + // Performance of onednn is only validated on Linux right now. + // Also, the heuristics for dispatching are based on perf data on Linux. + // So, for x86 qengine, we always use fbgemm kernels if OS is not Linux. + // TODO Support more OSs. 
+#if !defined(__linux__) + return false; +#else +#if defined(__powerpc__) + constexpr auto vnni_available = true; +#else + const auto vnni_available = cpuinfo_has_x86_avx512vnni(); +#endif + bool w_sym_quant = + is_weight_symmetric_quant(weight, is_transposed_conv); + bool opad_all_zero = + std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; }); + return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero; +#endif +} + +} // onednn_utils + +at::Tensor _qconv_prepack_onednn( + at::Tensor weight, // from CPU backend instead of QuantizedCPU + at::Tensor weight_scales, // Weight zero points must be 0 for onednn + double input_scale, + int64_t input_zero_point, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + std::optional> input_shape=std::nullopt); + +#endif // #if AT_MKLDNN_ENABLED() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..36f6140953f6a90e5baf85ab3fc5c5b5b3fad1ec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h @@ -0,0 +1,508 @@ +#pragma once + +#ifdef USE_PYTORCH_QNNPACK +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +inline int kPaddingChannels = 8; +struct QnnpackOperatorDeleter { + void operator()(pytorch_qnnp_operator_t op) { + pytorch_qnnp_delete_operator(op); + } +}; + +// PackedWeight struct for QNNPACK stores the original Weight and Bias as +// QNNPACK currently does not support an unpack function. 
+// For PyTorch Mobile, once the model is scripted and serialized we don't need +// to call unpack, so we can save some memory by checking for this case and free +// the original weights after packing. +// Input scale is set to null in pre-pack step. QNNPACK needs bias quantized +// with input scale which is available at runtime in pytorch. During runtime if +// input scale value changes then we requantize bias with the updated scale. For +// inference we expect the graph to be static so the input scale should not +// change across consecutive inference calls. +struct PackedLinearWeightsQnnp : public LinearPackedParamsBase { + PackedLinearWeightsQnnp( + std::unique_ptr w, + at::Tensor orig_weight, + at::Tensor bias, + std::optional input_scale, + at::Tensor w_scales, + std::vector&& w_zps) + : w(std::move(w)), + orig_weight(std::move(orig_weight)), + bias_(at::native::mobile::allocate_padded_contiguous_if_needed( + bias, bias.suggest_memory_format())), + per_channel_(this->orig_weight.qscheme() == at::kPerChannelAffine), + input_scale(std::move(input_scale)), + w_scales(std::move(w_scales)), + w_zero_points(std::move(w_zps)), + q_scheme(this->orig_weight.qscheme()) { + weight_sizes = this->orig_weight.sizes().vec(); + } + + std::unique_ptr w; + at::Tensor orig_weight; + at::Tensor bias_; + bool per_channel_; + std::optional input_scale; + at::Tensor w_scales; + std::vector w_zero_points; + std::vector requantization_scales; + std::vector weight_sizes; + c10::QScheme q_scheme; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override; + + std::tuple> unpack() override; + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr 
prepack( + at::Tensor weight, + std::optional bias); + + bool per_channel() const { + return per_channel_; + } + + private: + std::mutex qnnp_mutex_; + +#ifdef USE_XNNPACK + xnnpack_operator xnnp_linear_op; + + template + at::Tensor apply_impl_xnnp( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +#endif // USE_XNNPACK + + template + at::Tensor apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range); +}; + +template +struct PackedConvWeightsQnnp : public ConvPackedParamsBase { + PackedConvWeightsQnnp( + std::unique_ptr w, + at::Tensor orig_weight, + at::Tensor bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose, + std::optional input_scale, + std::vector kernel, + at::Tensor w_scale, + std::vector&& w_zps, + bool is_per_channel) + : w(std::move(w)), + orig_weight(std::move(orig_weight)), + bias(std::move(bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose), + is_per_channel_(is_per_channel), + input_scale(input_scale), + kernel_(std::move(kernel)), + w_scales(std::move(w_scale)), + w_zero_points(std::move(w_zps)) { + const bool any_padding = std::any_of( + padding_.begin(), padding_.end(), [](const auto& e) { return e != 0; }); + const size_t kernel_size = + std::accumulate(kernel_.begin(), kernel_.end(), 1, std::multiplies<>()); + + const size_t group_input_channels = transpose + ? this->orig_weight.size(0) / groups + : this->orig_weight.size(1); + const size_t group_output_channels = transpose + ? this->orig_weight.size(1) + : this->orig_weight.size(0) / groups; + + const size_t kernel_depth = kSpatialDim == 3 ? 
kernel_[0] : 1; + const size_t kernel_height = kernel_[kSpatialDim - 2]; + const size_t kernel_width = kernel_[kSpatialDim - 1]; + + pytorch_qnnp_ukernel_type ukernel_type; + if (transpose_) { + ukernel_type = pytorch_qnnp_ukernel_type_conv; + } else { + ukernel_type = pytorch_qnnp_ukernel_type_none; + + const bool has_depthwise_dimensions = + (kSpatialDim == 2 && + ((kernel_height == 3 && kernel_width == 3) || + (kernel_height == 5 && kernel_width == 5))) || + (kSpatialDim == 3 && kernel_height == 3 && kernel_width == 3 && + kernel_depth == 3); + const bool has_depthwise_grouping = + group_input_channels == 1 && group_output_channels == 1 && groups > 1; + + if (has_depthwise_dimensions && has_depthwise_grouping) { + ukernel_type = pytorch_qnnp_ukernel_type_dwconv; + } else if ( + kernel_size == 1 && + std::all_of( + stride_.begin(), + stride_.end(), + [](const auto& e) { return e == 1; }) && + !any_padding) { + ukernel_type = group_input_channels >= SIZE_MAX + ? pytorch_qnnp_ukernel_type_xzp_gemm + : pytorch_qnnp_ukernel_type_gemm; + } else { + ukernel_type = pytorch_qnnp_ukernel_type_conv; + } + } + + if (is_per_channel && ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) { + TORCH_INTERNAL_ASSERT( + false, "Per channel quantized weights are not supported for XZP kernels"); + } + + pytorch_qnnp_operator_t convolution{nullptr}; + // Initially all the params are set to zero. 
+ convolution = static_cast( + calloc(1, sizeof(struct pytorch_qnnp_operator))); + if (convolution == nullptr) { + TORCH_INTERNAL_ASSERT( + false, "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + } + + convolution_op = + std::unique_ptr( + convolution); + + // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) + convolution->ukernel_type = ukernel_type; + convolution->groups = groups; + convolution->group_input_channels = group_input_channels; + convolution->group_output_channels = group_output_channels; + convolution->kernel_depth = kernel_depth; + convolution->kernel_height = kernel_height; + convolution->kernel_width = kernel_width; + convolution->stride_depth = kSpatialDim == 3 ? stride_[0] : 1; + convolution->stride_height = stride_[kSpatialDim - 2]; + convolution->stride_width = stride_[kSpatialDim - 1]; + convolution->dilation_depth = kSpatialDim == 3 ? dilation_[0] : 1; + convolution->dilation_height = dilation_[kSpatialDim - 2]; + convolution->dilation_width = dilation_[kSpatialDim - 1]; + convolution->input_padding_height = padding_[kSpatialDim - 2]; + convolution->input_padding_width = padding_[kSpatialDim - 1]; + convolution->input_padding_depth = kSpatialDim == 3 ? 
padding_[0] : 0; + convolution->per_channel = is_per_channel_; + convolution->transpose = transpose_; + + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + + size_t zero_size = sizeof(uint8_t) * k_stride; + size_t zero_offset = 0; + + if (transpose_) { + convolution->adjustment_width = output_padding_[1]; + convolution->adjustment_height = output_padding_[0]; + if (group_input_channels < 8) { + zero_size += 8; + zero_offset = 8; + } + } else { + zero_buffer_size = 0; + if (any_padding) { + zero_size = 0; + zero_offset = 0; + if (ukernel_type == pytorch_qnnp_ukernel_type_dwconv) { + const uint32_t cr = pytorch_qnnp_params.q8dw9.cr; + const size_t group_stride = (groups + (cr - 1)) & -cr; + if (groups >= 8) { + zero_size = sizeof(uint8_t) * group_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * group_stride + 8; + zero_offset = sizeof(uint8_t) * 8; + } + } else if ( + ukernel_type == pytorch_qnnp_ukernel_type_conv || + ukernel_type == pytorch_qnnp_ukernel_type_gemm) { + if (group_input_channels >= 8) { + zero_size = sizeof(uint8_t) * k_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * k_stride + 8; + zero_offset = 8; + } + } + } + } + + // NOLINTNEXTLINE(clang-analyzer-optin.portability.UnixAPI) + void* zero_buffer = malloc(zero_size); + if (zero_buffer == nullptr) { + pytorch_qnnp_delete_operator(convolution); + TORCH_INTERNAL_ASSERT( + false, "failed to allocate %zu bytes for zero padding", + zero_size); + } + // Need to set to input zero point + // memset(zero_buffer, input_zero_point, zero_size); + zero_buffer_size = zero_size; + convolution->zero_buffer = zero_buffer; + convolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset); + } + + std::unique_ptr convolution_op; + #ifdef USE_XNNPACK + xnnpack_operator xnnp_convolution_op; + #endif // USE_XNNPACK + std::unique_ptr w; + at::Tensor orig_weight; + at::Tensor bias; + torch::List stride_; 
+ torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + bool transpose_; + bool is_per_channel_; + std::optional input_scale; + std::vector kernel_; + at::Tensor w_scales; + std::vector w_zero_points; + std::vector requantization_scales; + size_t zero_buffer_size; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range=false) override; + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return transpose_; + } + + bool per_channel() const { + return is_per_channel_; + } + + private: + std::mutex qnnp_mutex_; + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + +#ifdef USE_XNNPACK + template + at::Tensor apply_impl_xnnp( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +#endif // USE_XNNPACK +}; + +enum class Activation : uint8_t { NONE = 0, RELU = 1 }; + +template +inline T QuantizeValue(float scale, int32_t zero_point, float value) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + auto r = zero_point + static_cast(std::nearbyint(value / scale)); + r = 
std::max(r, qmin); + r = std::min(r, qmax); + return static_cast(r); +} + +template +inline std::pair activationLimits( + float scale, + int32_t zero_point, + Activation Ac) { + switch (Ac) { + case Activation::NONE: + return {std::numeric_limits::min(), + std::numeric_limits::max()}; + case Activation::RELU: + return {QuantizeValue(scale, zero_point, 0.0), + std::numeric_limits::max()}; + default: +#ifdef _MSC_VER + __assume(0); +#else + __builtin_unreachable(); +#endif + } +} + +namespace at::native::qnnp_avgpool_helper { +Tensor qnnpack_avg_pool2d( + Tensor input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override); +} // namespace at::native::qnnp_avgpool_helper + +namespace { +[[maybe_unused]] std::vector generate_requantization_scales( + const at::Tensor& weight_scales, + const float input_scale, + const float output_scale, + std::vector& requant_scales) { + // Since weight scale is allocated with padding + // weight_scales.numel() gives us padded num elements. + const auto num_output_channels_padded = weight_scales.numel(); + float *const weight_scales_data = weight_scales.data_ptr(); + if (static_cast(requant_scales.size()) < num_output_channels_padded) { + requant_scales.resize(num_output_channels_padded); + } + for (const auto i : c10::irange(num_output_channels_padded)) { + const auto inverse_output_scale = 1.f /output_scale; + requant_scales[i] = (weight_scales_data[i] * input_scale) * inverse_output_scale; + TORCH_CHECK( + (requant_scales[i] > 0.0f && std::isnormal(requant_scales[i])), + "failed to create op with requantization scale: ", + requant_scales[i], + ": requantization scale must be finite and positive"); + } + return requant_scales; +} + +[[maybe_unused]] std::pair, at::Tensor> +make_zero_points_and_scales_tensor( + const at::Tensor& weight_contig, + bool transpose = false, + uint32_t groups = 1) { + const int out_ch_idx = transpose ? 
1 : 0; + const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1); + // Add 8 to account for bufferring needed by QNNPACK. + const auto num_output_channels_padded = num_output_channels + kPaddingChannels; + const auto qtype = weight_contig.qscheme(); + std::vector weight_zp(num_output_channels_padded, 0); + // Adjust weight zero point, similar to weight data. + if (qtype == at::kPerTensorAffine) { + for (const auto i : c10::irange(num_output_channels)) { + weight_zp[i] = (uint8_t)(weight_contig.q_zero_point() + 128); + } + } else if (qtype == at::kPerChannelAffine) { + TORCH_CHECK( + weight_contig.q_per_channel_zero_points().scalar_type() == at::kLong, + "Per channel zero points dtype must be long int."); + const int64_t* per_channel_zero_points = + weight_contig.q_per_channel_zero_points().data_ptr(); + for (const auto i : c10::irange(num_output_channels)) { + weight_zp[i] = (uint8_t)(per_channel_zero_points[i] + 128); + } + } else { + TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme."); + } + at:: Tensor weight_scales = + at::empty( + {num_output_channels_padded}, + at::device(at::kCPU).dtype(at::kFloat)); + float *const weight_scales_data = weight_scales.data_ptr(); + if (qtype == at::kPerTensorAffine) { + for (const auto i : c10::irange(num_output_channels)) { + weight_scales_data[i] = weight_contig.q_scale(); + } + } else if (qtype == at::kPerChannelAffine) { + TORCH_CHECK( + weight_contig.q_per_channel_scales().scalar_type() == at::kDouble, + "Per channel scales dtype must be double."); + const double *const per_channel_scales = + weight_contig.q_per_channel_scales().data_ptr(); + for (const auto i : c10::irange(num_output_channels)) { + weight_scales_data[i] = static_cast(per_channel_scales[i]); + } + } else { + TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme."); + } + for (const auto i : c10::irange(num_output_channels, num_output_channels_padded)) { + weight_scales_data[i] = 1.f; + } + return 
{weight_zp, weight_scales}; +} +} // namespace + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..e81b0d87916b28b884dfc7edce4a25ad171babe8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h @@ -0,0 +1,240 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +namespace quant_utils { +namespace { + float RawUint16ToFp16(unsigned short value) { + // Convert raw 16 bits half precision floating point number + // to single precision floating point number. + const unsigned short sign_bits = value >> 15; + const unsigned short exponent_bits = value >> 10 & 0x1f; + const unsigned short significand_bits = value & 0x3ff; + + const float sign = sign_bits ? -1 : 1; + const float significand = + 1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10; + const float exponent = exponent_bits - 0xf; + + return sign * std::ldexp(significand, exponent); +} + +template +bool CheckAndSaturate(T max_val, T* element) { + if (*element > max_val) { + *element = max_val; + return true; + } + if (*element < -max_val) { + *element = -max_val; + return true; + } + return false; +} +} +using namespace std; +// A structure to hold quantization parameters 'scale' and 'zero_point'. +// The meaning of these values is as the constants in the quantization equation +// +// real_value = scale * (quantized_value - zero_point) +// +// In other words, 'zero_point' is the quantized value that corresponds +// to the real value 0, and 'scale' is the difference of real values +// corresponding to consecutive quantized values. 
+struct TensorQuantizationParams { + double scale; + std::int32_t zero_point; + int precision; +}; + +// Use fp16_min as the small scale cutoff because we don't want to use scales in +// fp16 subnormal range. This is to be consistent with Glow and FakeLowP +// implementation for NNPI. +constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + +// Following implementation should be identical to fbgemm::ChooseQuantizationParams +inline TensorQuantizationParams ChooseQuantizationParams( + float min, + float max, + int32_t qmin, + int32_t qmax, + bool preserve_sparsity = false, + bool force_scale_power_of_two = false, + bool reduce_range = false) { + TORCH_CHECK( + min <= max, + "In ChooseQuantizationParams, min should be less than or equal to max"); + + if (reduce_range) { + qmin = qmin/2; + qmax = qmax/2; + } + if (min < 0 && max > 0 && preserve_sparsity) { + int symmetric_qmin = -((qmax - qmin) / 2 + 1); + int symmetric_qmax = (qmax - qmin) / 2; + double max_scale = + std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax)); + min = max_scale * symmetric_qmin; + max = max_scale * symmetric_qmax; + } + + // We extend the [min, max] interval to ensure that it contains 0. + // Otherwise, we would not meet the requirement that 0 be an exactly + // representable value. + min = std::min(min, 0.f); + max = std::max(max, 0.f); + + TORCH_CHECK( + qmin < qmax, + "In ChooseQuantizationParams, qmin should be less than qmax"); + + // Use double precision for intermediate computation but use single precision + // in final number to reflect the actual number used during quantization. + double scale = (static_cast(max) - min) / (qmax - qmin); + // If scale is 0 or too small so its reciprocal is infinity, we arbitrary + // adjust the scale to 0.1 . We want to avoid scale's reciprocal being + // infinity because some of fbgemm code pre-computes scale's reciprocal to do + // multiplication instead of division in the time critical part of code. 
+ if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + TORCH_CHECK(scale > 0, "quantization scale should be > 0"); + + if (force_scale_power_of_two) { + if (scale < 1) { + scale = 1.0 / (1 << static_cast(floor(log(1.0 / scale) / log(2)))); + } else { + scale = 1 << static_cast(ceil(log(scale) / log(2))); + } + } + + // Cut off small scale + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust the min and max based on the new scale + if (min == 0.0f) { + max = SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else if (max == 0.0f) { + min = -SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min *= amplifier; + max *= amplifier; + } + } + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + double zero_point_from_min = qmin - min / static_cast(scale); + double zero_point_from_max = qmax - max / static_cast(scale); + double zero_point_from_min_error = + std::abs(qmin) - std::abs(min / static_cast(scale)); + double zero_point_from_max_error = + std::abs(qmax) - std::abs(max / static_cast(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // for symmetric quantization (preserve_sparsity == true), we force zero_point + // to be a middle value between qmin and qmax. + // If either min or max is 0, then we just use 0 as zero_point. 
+ if (min < 0 && max > 0 && preserve_sparsity) { + initial_zero_point = static_cast(qmin + qmax) / 2; + } + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with zero + // padding). + int32_t nudged_zero_point = 0; + if (initial_zero_point < qmin) { + nudged_zero_point = qmin; + } else if (initial_zero_point > qmax) { + nudged_zero_point = qmax; + } else { + nudged_zero_point = nearbyint(initial_zero_point); + } + + TensorQuantizationParams result; + result.scale = scale; + result.zero_point = nudged_zero_point; + return result; +} + +// This function helps to convert the Conv1D dimensions usable by the Conv2d op. +constexpr int64_t kConv1dSqueezeDim = 0; +[[maybe_unused]] static torch::List MakeArgForConv1d( + const torch::List& arg, + int64_t base_value) { + TORCH_CHECK(!arg.empty(), "Argument must have elements."); + torch::List result({arg.get(0), base_value}); + if (arg.size() == 1) { + result[1] = arg.get(0); + } else { + result[1] = arg.get(1); + } + result[kConv1dSqueezeDim] = base_value; + return result; +} + +// The range for using FP16 quantization of weights requires that the elements +// should be in the range of [5.96e-8, 65504]. If it is out of range, then the +// number will be saturated to max or min representable values by FP16. +inline void HandleWeightsSaturation(int64_t N, float* weight) { + const float kFp16Max = RawUint16ToFp16(0x7BFF); + bool found_out_of_range = false; + for (const auto i : c10::irange(N)) { + bool saturate = CheckAndSaturate(kFp16Max, weight + i); + if (saturate) { + found_out_of_range = true; + } + } + if (found_out_of_range) { + TORCH_WARN("FOUND weight out of range "); + } +} + +// Util function for quantizing bias. 
+inline at::Tensor QuantizeBias( + bool is_per_channel, + const at::Tensor& bias, + const at::Tensor& weight_contig, + double input_scale) { + at::Tensor qbias; + if (is_per_channel) { + auto bias_quant_scales = + weight_contig.q_per_channel_scales() * input_scale; + auto bias_zp = at::zeros(bias_quant_scales.sizes(), c10::kInt); + qbias = at::native::quantize_per_channel( + bias, bias_quant_scales, bias_zp, 0, c10::kQInt32); + } else { + qbias = at::native::quantize_per_tensor( + bias, weight_contig.q_scale() * input_scale, 0, c10::kQInt32); + } + return qbias; +} + +} // namespace quant_utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h new file mode 100644 index 0000000000000000000000000000000000000000..c7fcedb68a1e0b01293434d581b15afd75b9a7b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h @@ -0,0 +1,282 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace at::native { + +using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, + const Scalar& /*negval_*/); +using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, GeluType /* approximate */); +using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point); +using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qclamp_fn = void (*)( + const at::Tensor& /*qx*/, + const Scalar& min, + const Scalar& max, + at::Tensor& /*qy*/); +using qclamp_minmax_fn = void (*)( + const at::Tensor& /*qx*/, + const Scalar& /*min or max*/, + at::Tensor& /*qy*/); +using qthreshold_fn = void (*)( + const at::Tensor& /*qx*/, + const Scalar& threshold, + const Scalar& value, + at::Tensor& /*qy*/); 
+using qtanh_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qelu_fn = void(*)( + const at::Tensor& /*qx*/, + const Scalar& /*alpha*/, + const Scalar& /*scale*/, + const Scalar& /*input_scale*/, + at::Tensor& /*qy*/); +using qbinary_fn = + void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/); +using qadd_scalar_fn = + void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Scalar& other /*other*/); +using qhardswish_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qdropout_fn = void(*)( + const at::Tensor& /*qx*/, + const Scalar& /*p*/, + bool training /*training*/, + at::Tensor& /*qy*/); +using qmaxpool_2d_fn = void (*)( + const Tensor& qx, + int64_t iC, // input/output channels + int64_t iH, + int64_t iW, // input sizes + int64_t oH, + int64_t oW, // output sizes + int64_t kH, + int64_t kW, // kernel size + int64_t sH, + int64_t sW, // strides + int64_t pH, + int64_t pW, // padding + int64_t dH, + int64_t dW, // dilation + Tensor& qy); +using qmaxpool_3d_fn = void (*)( + const Tensor& qx, + int64_t iC, // input/output channels + int64_t iT, + int64_t iH, + int64_t iW, // input sizes + int64_t oT, + int64_t oH, + int64_t oW, // output sizes + int64_t kT, + int64_t kH, + int64_t kW, // kernel size + int64_t sT, + int64_t sH, + int64_t sW, // strides + int64_t pT, + int64_t pH, + int64_t pW, // padding + int64_t dT, + int64_t dH, + int64_t dW, // dilation + Tensor& qy); +using qadaptive_avg_pool2d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t sizeB, + int64_t sizeC, + int64_t isizeH, + int64_t isizeW, + int64_t osizeH, + int64_t osizeW, + int64_t istrideB, + int64_t istrideC, + int64_t istrideH, + int64_t istrideW); +using qadaptive_avg_pool3d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t sizeB, + int64_t sizeC, + int64_t isizeD, + int64_t isizeH, + int64_t isizeW, + int64_t osizeD, + int64_t osizeH, + int64_t osizeW, + int64_t istrideB, + int64_t istrideC, + int64_t istrideD, + 
int64_t istrideH, + int64_t istrideW); +using qavg_pool2d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t nBatch, + int64_t nInputPlane, + int64_t inputWidth, + int64_t inputHeight, + int64_t outputWidth, + int64_t outputHeight, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + bool count_include_pad, + std::optional divisor_override); + +using qavg_pool3d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t nBatch, + int64_t nInputPlane, + int64_t inputWidth, + int64_t inputHeight, + int64_t inputDepth, + int64_t outputWidth, + int64_t outputHeight, + int64_t outputDepth, + int kW, + int kH, + int kD, + int dW, + int dH, + int dD, + int padW, + int padH, + int padD, + bool count_include_pad, + std::optional divisor_override); + +using qupsample_bilinear2d_fn = void (*)( + Tensor& output, + const Tensor& input, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners, + std::optional scales_h, + std::optional scales_w); + +using qcat_nhwc_fn = Tensor (*)( + const MaterializedITensorListRef& qxs, + int64_t dim, + double scale, + int64_t zero_point); +using qtopk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool); + +using qbatch_norm_fn = void(*)(int64_t, int64_t, int64_t, int64_t, int64_t, const Tensor&, const Tensor&, const Tensor&, Tensor&); + +using qnormalize_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + bool /* affine_per_channel */, + int /* num_channels */, + int /* num_groups */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */); + +using qmean_inner_dim_fn = void (*)( + const Tensor& /* X */, + OptionalIntArrayRef /* opt_dim */, + bool /* keepdim */, + std::optional /* opt_dtype */, + Tensor& /* Y */); + +using qstd_inner_dim_fn = void (*)( + const Tensor& /* X */, + OptionalIntArrayRef /* dim */, + const std::optional& /* 
correction */, + bool /* keepdim */, + Tensor& /* Y */); + +using qnormalize_nhwc_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + bool /* affine_per_channel */, + int /* num_channels */, + int /* num_groups */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */); + +using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, + const Tensor& /*qw*/); + +using qbinary_eltwise_cpu_fn = void (*)( + Tensor& /*out*/, + const Tensor& /*qx*/, + double /*qx_scale*/, + int64_t /*qx_zero_point*/, + const Tensor& /*qy*/, + double /*qy_scale*/, + int64_t /*qy_zero_point*/, + double /*output_scale*/, + int64_t /*output_zero_point*/); + +using qbatch_norm_cpu_fn = void(*)( + int64_t /*N*/, + int64_t /*C*/, + int64_t /*H * W*/, + int64_t /*in_zero_point*/, + int64_t /*out_zero_point*/, + const Tensor& /*input*/, + const Tensor& /*a*/, + const Tensor& /*b*/, + Tensor& /*output*/); + +DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub) +DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub) +DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub) +DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_stub) +DECLARE_DISPATCH(qavg_pool2d_fn, qavg_pool2d_nhwc_stub) +DECLARE_DISPATCH(qavg_pool3d_fn, qavg_pool3d_nhwc_stub) +DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_relu_stub) +DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub) +DECLARE_DISPATCH(qbinary_fn, qadd_relu_stub) +DECLARE_DISPATCH(qbinary_fn, qadd_stub) +DECLARE_DISPATCH(qbinary_fn, qmul_relu_stub) +DECLARE_DISPATCH(qbinary_fn, qmul_stub) +DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub) +DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub) +DECLARE_DISPATCH(qclamp_fn, qclamp_stub) +DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_min_stub) +DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_max_stub) +DECLARE_DISPATCH(qelu_fn, qelu_stub) +DECLARE_DISPATCH(qhardsigmoid_fn, qhardsigmoid_stub) +DECLARE_DISPATCH(qhardswish_fn, 
qhardswish_stub) +DECLARE_DISPATCH(qdropout_fn, qdropout_stub) +DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub) +DECLARE_DISPATCH(qmaxpool_3d_fn, qmaxpool_3d_nthwc_stub) +DECLARE_DISPATCH(qnormalize_fn, quantized_normalize_stub) +DECLARE_DISPATCH(qnormalize_nhwc_fn, quantized_groupnorm_nhwc_stub) +DECLARE_DISPATCH(qrelu_fn, qrelu_stub) +DECLARE_DISPATCH(qrelu_leaky_fn, qrelu_leaky_stub) +DECLARE_DISPATCH(qgelu_fn, qgelu_stub) +DECLARE_DISPATCH(qsigmoid_fn, qsigmoid_stub) +DECLARE_DISPATCH(qtanh_fn, qtanh_stub) +DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub) +DECLARE_DISPATCH(qtopk_fn, qtopk_stub) +DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub) +DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub) +DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub) +DECLARE_DISPATCH(qprelu_fn, qprelu_stub) +DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qmul_tensor_cpu_stub) +DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qadd_tensor_cpu_stub) +DECLARE_DISPATCH(qbinary_eltwise_cpu_fn, qadd_relu_tensor_cpu_stub) +DECLARE_DISPATCH(qbatch_norm_cpu_fn, qbatch_norm_cpu_stub) + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..ea91cdffdf8a1335cd8a91f3af9db8026f4d5649 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h @@ -0,0 +1,17 @@ +#pragma once + +#ifdef USE_RUY_QMATMUL + +#include + +namespace at::native::ruy_utils { + +ruy::Context* get_ruy_context(); + +void quantize_multiplier(double scale, + int* multiplier_fixedpoint, + int* multiplier_exponent); + +} // namespace at::native::ruy_utils + +#endif // USE_RUY_QMATMUL diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..05616337dc58d23513cd3423dafa440b4e19ec6d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h @@ -0,0 +1,331 @@ +#pragma once + +#ifdef USE_XNNPACK +#include + +#include +#include + +using xnnpack_operator = at::native::xnnpack::Operator; + +namespace at::native::xnnp_utils { + +/* + * Return shape in the same order as the memory format + * e.g. channels_last will return NHWC instead of NCHW + */ +std::vector get_mem_format_aware_shape(const at::Tensor& in); + +/* + * Input is always int8_t, output can be [int8_t, uint8_t]. + * input + offset = output + * int8_t + 128 = uint8_t + * int8_t + 0 = int8_t + */ +template +void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out); + +template +Tensor convert_conv_weights_to_channel_last_tensor( + const at::Tensor& src, + int groups, + bool transpose); + +/* + * Series of create wrapper functions to call xnn_create_[de]conv* functions. + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_create_convolution2d_nhwc( + uint32_t pad_top, + uint32_t pad_right, + uint32_t pad_bottom, + uint32_t pad_left, + uint32_t kernel_h, + uint32_t kernel_w, + uint32_t stride_h, + uint32_t stride_w, + uint32_t dilation_h, + uint32_t dilation_w, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t ip_chan_stride, + size_t op_chan_stride, + int8_t izp, + float ip_scale, + int8_t kzp, + const float* k_scales, + const int8_t* kernel, + const int32_t* bias, + int8_t ozp, + float op_scale, + int8_t op_min, + int8_t op_max, + uint32_t flags, + xnn_operator_t* op, + bool per_channel, + bool transpose) { + /* Symmetric quantization forces kzp = 0 */ + TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero." 
+ "But got: ", kzp); + + if (transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + return xnn_create_deconvolution2d_nhwc_qs8( + pad_top, /* uint32_t output_padding_top */ + pad_right, /* uint32_t output_padding_right */ + pad_bottom, /* uint32_t output_padding_bottom */ + pad_left, /* uint32_t output_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t stride_height */ + stride_w, /* uint32_t stride_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels */ + ip_chan_stride, /* size_t input_pixel_stride */ + op_chan_stride, /* size_t output_pixel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales[0], /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* deconvolution_op_out */ + + } + + if (!per_channel) { + return xnn_create_convolution2d_nhwc_qs8( + pad_top, /* uint32_t input_padding_top */ + pad_right, /* uint32_t input_padding_right */ + pad_bottom, /* uint32_t input_padding_bottom */ + pad_left, /* uint32_t input_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t subsampling_height */ + stride_w, /* uint32_t subsampling_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + 
group_output_channels, /* size_t group_output_channels*/ + ip_chan_stride, /* size_t input_channel_stride */ + op_chan_stride, /* size_t output_channel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales[0], /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* convolution_op_out */ + } else { /* per_channel */ + return xnn_create_convolution2d_nhwc_qs8_qc8w( + pad_top, /* uint32_t input_padding_top */ + pad_right, /* uint32_t input_padding_right */ + pad_bottom, /* uint32_t input_padding_bottom */ + pad_left, /* uint32_t input_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t subsampling_height */ + stride_w, /* uint32_t subsampling_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels*/ + ip_chan_stride, /* size_t input_channel_stride */ + op_chan_stride, /* size_t output_channel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales, /* const float* kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* convolution_op_out */ + } +} + +/* + * Series of reshape wrapper functions to call 
xnn_reshape_[de]conv* functions. + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_reshape_convolution2d_nhwc( + xnn_operator_t op, + size_t batch, + size_t in_h, + size_t in_w, + pthreadpool_t pt_pool, + bool per_channel = false, + bool transpose = false, + uint32_t adj_h = 0, + uint32_t adj_w = 0) { + if(transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + return xnn_reshape_deconvolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + adj_h, /* uint32_t adjustment_height */ + adj_w, /* uint32_t adjustment_width */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } + + size_t workspace_size = SIZE_MAX; + size_t workspace_alignment = SIZE_MAX; + + if (!per_channel) { + return xnn_reshape_convolution2d_nhwc_qs8( + op, /* xnn_operator_t convolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + &workspace_size, /* size_t* workspace_size */ + &workspace_alignment, /* size_t* workspace_alignment */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } else { /* per_channel */ + return xnn_reshape_convolution2d_nhwc_qs8_qc8w( + op, /* xnn_operator_t convolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + &workspace_size, /* size_t* workspace_size */ + &workspace_alignment, /* size_t* workspace_alignment */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } +} + + +/* + * Series of setup wrapper functions to call xnn_setup_[de]conv* functions. 
+ */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_setup_convolution2d_nhwc( + xnn_operator_t op, + const int8_t* inp, + int8_t* outp, + bool per_channel = false, + bool transpose = false) { + if(transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + + return xnn_setup_deconvolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } + + if (!per_channel) { + return xnn_setup_convolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + nullptr, /* void workspace */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } else { /* per_channel */ + return xnn_setup_convolution2d_nhwc_qs8_qc8w( + op, /* xnn_operator_t deconvolution_op */ + nullptr, /* void workspace */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } +} + + +/* + * Series of wrapper functions to call xnn_create* and xnn_setup* + * functions for linear + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_create_fully_connected_nc( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + int8_t input_zero_point, + float input_scale, + int8_t kernel_zero_point, + float kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_operator_t* fully_connected_op_out) { + /* Symmetric quantization forces kzp = 0 */ + TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero." 
+ "But got: ", kernel_zero_point); + return xnn_create_fully_connected_nc_qs8( + input_channels, /* size_t input_channels */ + output_channels, /* size_t output_channels */ + input_stride, /* size_t input_stride */ + output_stride, /* size_t output_stride */ + input_zero_point, /* int8_t input_zero_point */ + input_scale, /* float input_scale */ + kernel_scale, /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + output_zero_point, /* int8_t output_zero_point */ + output_scale, /* float output_scale */ + output_min, /* int8_t output_min */ + output_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t */ + fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */ +} + +C10_ALWAYS_INLINE +enum xnn_status xnnp_reshape_fully_connected_nc( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool) { + return xnn_reshape_fully_connected_nc_qs8( + fully_connected_op, /* xnn_operator_t fully_connected_op */ + batch_size, /* size_t batch_size */ + threadpool); /* pthreadpool_t threadpool */ +} + +C10_ALWAYS_INLINE +enum xnn_status xnnp_setup_fully_connected_nc( + xnn_operator_t fully_connected_op, + const int8_t* input, + int8_t* output) { + return xnn_setup_fully_connected_nc_qs8( + fully_connected_op, /* xnn_operator_t fully_connected_op */ + input, /* const int8_t* input */ + output /* int8_t* output */ + ); +} + +} // namespace at::native::xnnp_utils + +#endif // USE_XNNPACK diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h new file mode 100644 index 0000000000000000000000000000000000000000..3edd398fa789a7399700ce12d320025e6af4cfca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h @@ -0,0 
+1,417 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#if !defined(__s390x__) && !defined(__powerpc__) +#include +#endif + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + + +#include + +/* Convolution prepacked parameters serialization. + * + * Version 1 + * + * - Fields: + * 1. weight + * 2. bias + * 3. stride x kSpatialDim + * 4. padding x kSpatialDim + * 5. dilation x kSpatialDim + * 6. groups + * + * Version 2 + * + * - Fields: + * 0. version (string) + * 1. list of non-optional tensors + * 0: packed parameters (int16_t) + * - kSpatialDim + * - stride x kSpatialDim + * - padding x kSpatialDim + * - dilation x kSpatialDim + * - output_padding x kSpatialDim + * - groups + * - transpose (0 or 1) + * 1: weight + * 2. list of optional tensors + * 0: bias + * + * Version 3 + * + * - Fields: + * 0. version (int64_t) + * 1. list of int64_t configuration values + * - kSpatialDim + * - stride x kSpatialDim + * - padding x kSpatialDim + * - dilation x kSpatialDim + * - output_padding x kSpatialDim + * - groups + * - flags (bitmask) + * - (1 << 0) transpose (1 = yes) + * 2. list of optional tensors + * 0: None (helps with type inference) + * 1: weight (this must be present) + * 2: bias + */ + +using ConvParamsSerializationTypeV2 = std::tuple< + // version, for versions 2 and up + std::string, + // non-optional tensors + std::vector, + // optional tensors + std::vector>>; + +using ConvParamsSerializationTypeV3 = std::tuple< + // version, int for versions 3 and up + int64_t, + // configuration values + std::vector, + // optional tensors + std::vector>>; + +// Parses any historical conv packed params format into +// the current format. 
+template +ConvParamsSerializationTypeV3 parse_conv_serialized_state(const c10::IValue& v) { + + // determine the version based on IValue contents + int version = -1; + if (v.isTuple()) { + const auto& elements = v.toTupleRef().elements(); + if (!elements.empty()) { + auto firstElement = elements[0]; + if (firstElement.isTensor()) { + version = 1; + } else if (firstElement.isString()) { + const std::string& version_str = firstElement.toStringRef(); + // note: not parsing the string to automatically handle bad + // inputs + if (version_str == "2") { + version = 2; + } + } else if (firstElement.isInt()) { + auto raw_version = firstElement.toInt(); + if (raw_version == 3) { + version = 3; + } + } + } + } + TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version"); + + if (version == 1) { + // version 1 - convert to version 3 manually + + const auto& elements = v.toTupleRef().elements(); + + at::Tensor weight = elements[0].toTensor(); + std::optional bias = elements[1].toOptional(); + torch::List stride_x_kSpatialDim = elements[2].toTensorList(); + torch::List padding_x_kSpatialDim = elements[3].toTensorList(); + torch::List dilation_x_kSpatialDim = elements[4].toTensorList(); + at::Tensor groups = elements[5].toTensor(); + + std::vector config_vals; + config_vals.reserve( + stride_x_kSpatialDim.size() + padding_x_kSpatialDim.size() + + dilation_x_kSpatialDim.size() + kSpatialDim + 3); + config_vals.push_back(kSpatialDim); + for (const auto i : c10::irange(stride_x_kSpatialDim.size())) { + auto const & stride = stride_x_kSpatialDim.get(i); + config_vals.push_back(stride[0].item()); + } + for (const auto i : c10::irange(padding_x_kSpatialDim.size())) { + auto const &padding = padding_x_kSpatialDim.get(i); + config_vals.push_back(padding[0].item()); + } + for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) { + auto const &dilation = dilation_x_kSpatialDim.get(i); + config_vals.push_back(dilation[0].item()); + } + // output_padding does not 
exist in v1, so we fill in a default value + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + config_vals.push_back(0); + } + config_vals.push_back(groups[0].item()); + // transpose does not exist in v1, so we fill in a default value + config_vals.push_back(0); + + std::vector> tensors; + tensors.emplace_back(); + tensors.emplace_back(weight); + tensors.emplace_back(bias); + + int64_t version = 3; + return std::tie(version, config_vals, tensors); + } else if (version == 2) { + // version 2 + const auto& elements = v.toTupleRef().elements(); + std::vector non_optional = elements[1].toTensorList().vec(); + std::vector> optional; + + if (elements[2].isTensorList()) { + for (const auto& elem : elements[2].toTensorList()) { + optional.emplace_back(static_cast(elem)); + } + } else { + for (const auto& elem : elements[2].toList()) { + optional.emplace_back(static_cast(elem).toOptional()); + } + } + // create default optional value for bias + if (optional.empty()) { + optional.emplace_back(); + } + + auto config_a = non_optional[0].accessor(); + std::vector config_vals; + config_vals.reserve(config_a.size(0)); + for (const auto i : c10::irange(config_a.size(0))) { + config_vals.emplace_back(config_a[i]); + } + + auto weight = non_optional[1]; + auto bias = optional[0]; + + std::vector> tensors; + tensors.emplace_back(); + tensors.emplace_back(weight); + tensors.emplace_back(bias); + + int64_t version = 3; + return std::tie(version, config_vals, tensors); + } else if (version == 3) { + return v.to(); + } else { + TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ", + version); + } +} + +#define QCONV_SERIALIZATION_VERSION 2 + +#if QCONV_SERIALIZATION_VERSION == 2 +using ConvParamsSerializationType = ConvParamsSerializationTypeV2; + +template +ConvParamsSerializationTypeV2 serialize_conv( + const c10::intrusive_ptr>& params) { + + std::string version = "2"; + std::vector non_optional; + std::vector> optional; + + // create a packed int8_t 
tensor for conv params + std::vector params_vec; + params_vec.push_back(kSpatialDim); + auto stride = params->stride().vec(); + params_vec.insert(params_vec.end(), stride.begin(), stride.end()); + auto padding = params->padding().vec(); + params_vec.insert(params_vec.end(), padding.begin(), padding.end()); + auto dilation = params->dilation().vec(); + params_vec.insert(params_vec.end(), dilation.begin(), dilation.end()); + auto output_padding = params->output_padding().vec(); + params_vec.insert(params_vec.end(), output_padding.begin(), + output_padding.end()); + params_vec.push_back(params->groups()); + params_vec.push_back(params->transpose()); + int64_t vec_size = params_vec.size(); + at::Tensor params_tensor = at::from_blob( + params_vec.data(), {vec_size}, + at::TensorOptions().dtype(at::kShort)) + // clone to retain ownership of the data + .clone(); + + auto [weight, bias] = params->unpack(); + + non_optional.emplace_back(std::move(params_tensor)); + non_optional.emplace_back(std::move(weight)); + optional.emplace_back(std::move(bias)); + + return std::tie(version, non_optional, optional); +} + +#elif QCONV_SERIALIZATION_VERSION == 3 +using ConvParamsSerializationType = ConvParamsSerializationTypeV3; + +template +ConvParamsSerializationTypeV3 serialize_conv( + const c10::intrusive_ptr>& params) { + std::vector config_vals; + config_vals.push_back(kSpatialDim); + auto stride = params->stride().vec(); + config_vals.insert(config_vals.end(), stride.begin(), stride.end()); + auto padding = params->padding().vec(); + config_vals.insert(config_vals.end(), padding.begin(), padding.end()); + auto dilation = params->dilation().vec(); + config_vals.insert(config_vals.end(), dilation.begin(), dilation.end()); + auto output_padding = params->output_padding().vec(); + config_vals.insert(config_vals.end(), output_padding.begin(), + output_padding.end()); + config_vals.push_back(params->groups()); + config_vals.push_back(params->transpose()); + + auto [weight, bias] = 
params->unpack(); + + std::vector> tensors; + tensors.emplace_back(); + tensors.emplace_back(weight); + tensors.emplace_back(bias); + + int64_t version = 3; + return std::tie(version, config_vals, tensors); +} + +#else +#error "Invalid qconv serialization version." +#endif + +template +c10::intrusive_ptr> deserialize_conv( + ConvParamsSerializationTypeV3 state) { + auto & [version, config_vals, tensors] = state; + TORCH_INTERNAL_ASSERT(version == 3, "Unexpected serialized qconv version: ", version); + + TORCH_CHECK(tensors.size() == 3, "Wrong number of tensors", tensors.size()); + auto & weight = tensors[1]; + auto & bias [[maybe_unused]] = tensors[2]; + TORCH_INTERNAL_ASSERT(weight.has_value(), "Weight should always be present in serialized qconv."); + + torch::List stride, padding, output_padding, dilation; + // skip kSpatialDim + int idx = 1; + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + stride.emplace_back(config_vals.at(idx)); + idx++; + } + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + padding.emplace_back(config_vals.at(idx)); + idx++; + } + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + dilation.emplace_back(config_vals.at(idx)); + idx++; + } + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + TORCH_INTERNAL_ASSERT( + idx < static_cast(config_vals.size()), + "Unexpected index = ", + idx, + " for config_vals of size ", + config_vals.size()); + output_padding.emplace_back(config_vals.at(idx)); + idx++; + } + int64_t groups [[maybe_unused]] = config_vals.at(idx); + idx++; + int64_t flags [[maybe_unused]] = config_vals.at(idx); + idx++; + TORCH_INTERNAL_ASSERT(idx == static_cast(config_vals.size()), + "Unexpected length of config_vals, expected ", + idx, + " got ", + config_vals.size()); + + bool transpose [[maybe_unused]] = flags & (1 << 0); + + int64_t other_flags = flags & ~(1 << 0); + TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, "."); + + auto& ctx 
= at::globalContext(); + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::X86) { +#if AT_MKLDNN_ENABLED() + bool use_onednn = onednn_utils::should_use_onednn_quant( + weight.value(), transpose, groups, output_padding); + if (use_onednn) { + return PackedConvWeightsOnednn::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif + return PackedConvWeight::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } // x86 +#endif + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return PackedConvWeight::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // USE_FBGEMM +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + TORCH_CHECK( + kSpatialDim == 2, + "prepack/__setstate__: QNNPACK only supports Conv2d " + "now."); + return PackedConvWeightsQnnp::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + return PackedConvWeightsOnednn::prepack( + std::move(weight.value()), + std::move(bias), + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // AT_MKLDNN_ENABLED() +TORCH_CHECK( + false, + "Didn't find engine for when deserializing ConvPackedParams: ", + toString(ctx.qEngine())); +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..e6d86cf03df130ba4a42efa61ab557a90f4fd898 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h @@ -0,0 +1,408 @@ +#pragma once + +#include +#include +#include +#include +#include + +#ifdef USE_FBGEMM +#include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Winconsistent-missing-destructor-override") +#include +C10_DIAGNOSTIC_POP() +#include + +// The struct for the packed weight matrix (PackBMatrix) and the corresponding +// column offsets used for the fully connect layer, which are both prepared in +// the prepacking step to save the computations in the inference. Note the +// column offsets include the sum of the B columns as well as the scalar term +// B_zero_point * K, whereas the row offsets created by +// PackAWithQuantRowOffset/PackAWithIm2Col/PackAWithRowOffset are only the sum +// of the A rows. The column offsets are needed for the asymmetric quantization +// (affine quantization) of input matrix. +// Note that in JIT mode we can think of a way to fuse col_offsets with bias. +struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase { + PackedLinearWeight( + std::unique_ptr> w, + std::optional bias, + std::vector col_offsets, + std::vector w_scale, + std::vector w_zp, + c10::QScheme q_scheme) + : w(std::move(w)), + bias_(std::move(bias)), + col_offsets(std::move(col_offsets)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + q_scheme(std::move(q_scheme)) {} + std::unique_ptr> w; + std::optional bias_; + std::vector col_offsets; + std::vector w_scale; + std::vector w_zp; + c10::QScheme q_scheme; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor& apply_out( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point, + at::Tensor& output) override; + + at::Tensor& apply_relu_out( + const at::Tensor& input, + double output_scale, + int64_t 
output_zero_point, + at::Tensor& output) override; + + at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) override; + + at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) + override; + + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) + override; + + std::tuple> unpack() override; + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + std::optional bias); + + private: + template + at::Tensor& apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point, + at::Tensor& output); + + template + at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32_impl( + const at::Tensor& input, + double input_scale, + int64_t input_zero_point); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false); +}; + +struct TORCH_API PackedLinearWeightFp16 : public LinearPackedParamsBase { + PackedLinearWeightFp16( + std::unique_ptr w, + std::optional bias) + : w(std::move(w)), bias_(std::move(bias)) {} + + std::unique_ptr w; + std::optional bias_; + + at::Tensor apply( + at::Tensor /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/) override { + TORCH_INTERNAL_ASSERT(false); + } + at::Tensor apply_relu( + at::Tensor /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/) override { + TORCH_INTERNAL_ASSERT(false); + } + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) + override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) + override; + + at::Tensor& apply_dynamic_out( + const at::Tensor& input, + at::Tensor& output, + bool reduce_range = false) override; + at::Tensor& apply_dynamic_relu_out( + const at::Tensor& 
input, + at::Tensor& output, + bool reduce_range = false) override; + + std::tuple> unpack() override; + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + std::optional bias); + + void set_bias(std::optional bias) override; + + private: + template + at::Tensor& apply_dynamic_impl(const at::Tensor& input, at::Tensor& output); +}; + +template +struct TORCH_API PackedConvWeight : public ConvPackedParamsBase { + PackedConvWeight( + std::unique_ptr> w, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + uint8_t transpose, + std::vector col_offsets, + std::vector kernel, + std::vector w_scale, + std::vector w_zp, + c10::QScheme q_scheme) + : w(std::move(w)), + bias(std::move(bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose), + col_offsets(std::move(col_offsets)), + kernel(std::move(kernel)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + q_scheme(q_scheme) {} + + std::unique_ptr> w; + std::optional bias; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + uint8_t transpose_; + std::vector col_offsets; + std::vector kernel; + std::vector w_scale; + std::vector w_zp; + c10::QScheme q_scheme; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override; + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + 
torch::List dilation, + int64_t groups, + bool transpose); + + const float* GetBiasData(at::Tensor* bias); + + void GetQuantizationParams( + float act_scale, + float out_scale, + std::vector* output_multiplier_float, + std::vector* act_times_w_scale); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return (bool)transpose_; + } + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +}; + +// PackWeight: Convert the weight from uint8 to int8. +inline void convert_uint8_int8( + int len, + const uint8_t* src_uint8, + int8_t* dst_int8) { + for (const auto i : c10::irange(len)) { + dst_int8[i] = static_cast(static_cast(src_uint8[i]) - 128); + } +} + +// UnpackWeight: Convert the weight from int8 to uint8. +inline void convert_int8_uint8( + int len, + const int8_t* src_int8, + uint8_t* dst_uint8) { + for (const auto i : c10::irange(len)) { + dst_uint8[i] = + static_cast(static_cast(src_int8[i]) + 128); + } +} + +namespace at::native::fbgemm_utils { + +template +fbgemm::conv_param_t MakeFbgemmConvParam( + int N, + int C, + int M, + const std::vector& image_shape, + int groups, + const std::vector& kernels, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + const std::vector& output_padding = std::vector(kSpatialDim, 0), + bool transposed = false); + +// TODO: Remove functions below when ChannelsLast3d is ready. 
+Tensor MakeStridedQTensorCPU( + const IntArrayRef& sizes, + const IntArrayRef& strides, + const TensorOptions& options, + QuantizerPtr quantizer); + +Tensor MakeEmptyAffineQuantizedChannelsLast3dTensor( + int64_t N, + int64_t C, + int64_t D, + int64_t H, + int64_t W, + const TensorOptions& options, + double scale, + int64_t zero_point); + +Tensor MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor( + int64_t N, + int64_t C, + int64_t D, + int64_t H, + int64_t W, + const TensorOptions& options, + const Tensor& scales, + const Tensor& zero_points); + +Tensor ConvertToChannelsLast3dTensor(const Tensor& src); + +template +Tensor TransposeConvTensorUnpackConversion(const Tensor& src, int groups); + +template +Tensor ConvertConvWeightsToChannelLastTensor( + const at::Tensor& src, + int groups, + bool transpose); +} // at::native::namespace fbgemm_utils + +#endif // USE_FBGEMM + +struct TORCH_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { + PackedEmbeddingBagWeight( + at::Tensor packed_w, + std::vector w_scale, + std::vector w_zp, + int64_t bit_rate, + c10::QScheme q_scheme, + int64_t version) + : packed_w(std::move(packed_w)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + bit_rate_(bit_rate), + q_scheme(q_scheme), + version_(version) { + if (!this->packed_w.is_contiguous()) { + this->packed_w = this->packed_w.contiguous(); + } + } + + at::Tensor packed_w; + std::vector w_scale; + std::vector w_zp; + int64_t bit_rate_; + c10::QScheme q_scheme; + int64_t version_; + + at::Tensor unpack() override; + static c10::intrusive_ptr prepack( + at::Tensor weight); + + int64_t bit_rate() const override { + return bit_rate_; + } + + int64_t version() const override { + return version_; + } + + at::Tensor embeddingbag_byte( + const at::Tensor& indices, + const std::optional& offsets, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) 
override; + + at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const std::optional& offsets, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) override; +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h new file mode 100644 index 0000000000000000000000000000000000000000..96dd2b3b274f31015c6e738a077e11bb1d38f6a1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h @@ -0,0 +1,11 @@ +#pragma once + +#ifdef USE_PYTORCH_QNNPACK + +namespace at::native { + +void initQNNPACK(); + +} // namespace at::native + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qconv.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qconv.h new file mode 100644 index 0000000000000000000000000000000000000000..0be160aa3522e80dd378729e8bb787d3619f683e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qconv.h @@ -0,0 +1,100 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +class QConvoneDNN final { + public: + + C10_API static at::Tensor run_pointwise( + at::Tensor act, // contains quantized values but not QTensor + double act_scale, + int64_t act_zero_point, + at::Tensor weight, // contains quantized values but not QTensor + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::string_view attr, + torch::List> scalars, + std::optional algorithm); + + C10_API static at::Tensor run_pointwise_tensor( + at::Tensor act, // 
contains quantized values but not QTensor + at::Tensor act_scale, + at::Tensor act_zero_point, + at::Tensor weight, // contains quantized values but not QTensor + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::string_view attr, + torch::List> scalars, + std::optional algorithm); + + C10_API static at::Tensor run_pointwise_binary( + at::Tensor act, // contains quantized values but not QTensor + double act_scale, + int64_t act_zero_point, + at::Tensor weight, // contains quantized values but not QTensor + at::Tensor weight_scales, + at::Tensor weight_zero_points, + at::Tensor accum, // contains quantized values but not QTensor + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double accum_scale, + int64_t accum_zero_point, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + + C10_API static at::Tensor run_pointwise_binary_tensor( + at::Tensor act, // contains quantized values but not QTensor + at::Tensor act_scale, + at::Tensor act_zero_point, + at::Tensor weight, // contains quantized values but not QTensor + at::Tensor weight_scales, + at::Tensor weight_zero_points, + at::Tensor accum, // contains quantized values but not QTensor + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double accum_scale, + int64_t accum_zero_point, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm); + +}; + +} // 
namespace native +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h new file mode 100644 index 0000000000000000000000000000000000000000..a489c0dc3b387cfa4ec366a1d7d370e25c4cc38c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h @@ -0,0 +1,32 @@ +#pragma once +#include +#include + +namespace at::native { +Tensor& embedding_bag_byte_rowwise_offsets_out( + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const std::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset); + +Tensor& embedding_bag_4bit_rowwise_offsets_out( + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const std::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const std::optional& per_sample_weights_, + const std::optional& compressed_indices_mapping, + bool include_last_offset); + +Tensor& qembeddingbag_byte_unpack_out(Tensor& output, const Tensor& packed_weight); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h new file mode 100644 index 0000000000000000000000000000000000000000..e157405c107b81442392502d5ccb54d6c0c2cd1d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h @@ -0,0 +1,12 @@ +#pragma once +#include + +namespace at::native { + +Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight); + +Tensor 
qembeddingbag_byte_prepack(const Tensor& weight); + +Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qlinear.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qlinear.h new file mode 100644 index 0000000000000000000000000000000000000000..60501c6b5373853fb723fe71d14bd4461fc6f6f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cpu/qlinear.h @@ -0,0 +1,51 @@ +#pragma once +#include +#include + +namespace at::native { + +class QLinearOnednn final { + public: + C10_API static Tensor run_pointwise_tensor( + Tensor act, // int8 CPU tensor, not QTensor + Tensor act_scale, + Tensor act_zero_point, + Tensor onednn_weight, // int8 tensor from MkldnnCPU + Tensor weight_scales, + Tensor weight_zero_points, + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::string_view post_op_name, + c10::List> post_op_args, + std::string_view post_op_algorithm); + +C10_API static Tensor run_pointwise_binary_tensor( + Tensor act, // int8 CPU tensor, not QTensor + Tensor act_scale, + Tensor act_zero_point, + Tensor onednn_weight, // int8 tensor from MkldnnCPU + Tensor weight_scales, + Tensor weight_zero_points, + std::optional other, // extra input for binary post-op + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double other_scale, + int64_t other_zero_point, + std::string_view binary_post_op, // e.g. "none", "sum", "add" + double binary_alpha, + std::string_view unary_post_op, // e.g. 
"none", "relu" + c10::List> unary_post_op_args, + std::string_view unary_post_op_algorithm); +}; + +C10_API Tensor _weight_int4pack_mm_cpu_tensor( + const Tensor& A, + const Tensor& B, + const Tensor& qGroupSize, + const Tensor& qScaleAndZeros); + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cudnn/utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cudnn/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..824694d363a01bbe62cedcc781873f7a0eee243c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/cudnn/utils.h @@ -0,0 +1,318 @@ +#pragma once +/* +This file contains some of the auxiliary functions used by both Conv.cpp & Linear.cpp (introduced in a later PR) +*/ + +#ifdef USE_CUDA +#include // for the definition of AT_CUDNN_ENABLED + +#if AT_CUDNN_ENABLED() + +#include +#include +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") +#include +C10_DIAGNOSTIC_POP() + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +struct PackedLinearWeightCudnn : public LinearPackedParamsBase { + PackedLinearWeightCudnn( + at::Tensor orig_weight, + std::optional bias, + c10::QScheme q_scheme) + : orig_weight(std::move(orig_weight)), + bias_(std::move(bias)), + q_scheme(std::move(q_scheme)) {} + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) override { + throw std::runtime_error( + "apply_dynamic is not implemented for this packed " + "parameter type"); + } + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) override { + throw std::runtime_error( + "apply_dynamic_relu is not implemented for this packed " + 
"parameter type"); + } + + std::tuple> unpack() override; + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + std::optional bias); + + private: + at::Tensor orig_weight; + std::optional bias_; + c10::QScheme q_scheme; + + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + + template + void apply_impl_helper( + const at::Tensor& quantized_output, + const at::Tensor& input, + double output_scale); +}; + +template +struct PackedConvWeightCudnn : public ConvPackedParamsBase { + PackedConvWeightCudnn( + at::Tensor orig_weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose, + c10::QScheme q_scheme, + int64_t output_channels) + : maybe_padded_weight_(std::move(orig_weight)), + bias_(std::move(bias)), + stride_(stride), + padding_(padding), + output_padding_(output_padding), + dilation_(dilation), + groups_(groups), + transpose_(transpose), + q_scheme_(q_scheme), + num_unpadded_output_channels_(output_channels) {} // output channels needs to be stored when we have to pad this dimension + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override { + TORCH_CHECK(false, "apply_dynamic is currently not reported"); + } + + at::Tensor apply_dynamic_relu( + const at::Tensor& input, + bool reduce_range) { + TORCH_CHECK(false, "apply_dynamic_relu is currently not reported"); + } + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + 
int64_t groups, + bool transpose); + + const float* GetBiasData(at::Tensor* bias); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return transpose_; + } + + private: + // cudnn v8.4.0 expects conv2d's int8 weight tensor's input and output channels to be a multiple of 4. if it is not + // we need to explicitly pad it to a multiple of 4 ourselves as cudnn does not currently support padding, hence the naming + // convention "maybe"_padded_weight. + // TODO: when and if cudnn enables padding in their operators, we can remove padding on our end and rename this to orig_weight_ + at::Tensor maybe_padded_weight_; + std::optional bias_; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + bool transpose_; + c10::QScheme q_scheme_; + int64_t num_unpadded_output_channels_; + + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + + template + void apply_impl_helper( + const at::Tensor& quantized_output, + const at::Tensor& input, + double output_scale); +}; + +namespace cudnn_utils { + +// TODO: we can remove this function when cuDNN enables pass by value support for +// pointwise multiplication operations. 
the only reason why we need this right now is +// we use broadcasting scalar multiplication in conv, linear, and add ops, and cuDNN requires +// the scalar to be a scalar tensor with the same number of dimensions (num_dim) as the tensor we're multiplying to +inline at::Tensor getRequantMultiplierTensor(double requant_multiplier, uint8_t num_dim) { + at::SmallVector requantize_multiplier_tensor_size(num_dim, 1); + at::Tensor requantize_multiplier_tensor = at::empty(requantize_multiplier_tensor_size, at::device(at::kCUDA).dtype(at::kFloat)); + requantize_multiplier_tensor.fill_(requant_multiplier); + return requantize_multiplier_tensor; +} + +inline uint8_t getAlignment(const at::Tensor &t) { + // alignment are in bytes + uint8_t alignment = 1; + uintptr_t address = reinterpret_cast(t.data_ptr()); + for (; alignment < 16; alignment *= 2) { + if (address % (alignment * 2)) { + return alignment; + } + } + return alignment; +} + +// For the two getTensorDescriptor functions, there is a is_virtual parameter. This parameter is used to set the cudnn +// tensor as virtual or not. Setting the tensor as virtual is expected to have some performance benefits as the cudnn +// backend cudnn will no longer directly save to the tensor, allowing us to omit this tensor from the variant pack. 
+// See third_party/cudnn_frontend/samples/fusion_sample.cpp for other examples + +inline cudnn_frontend::Tensor getTensorDescriptor(const at::Tensor &t, int64_t id, uint8_t alignment, bool is_virtual = false) { + auto shape = t.sizes(); + auto strides = t.strides(); + if (is_virtual) { + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(alignment) + .setVirtual() + .setDataType(at::native::getCudnnDataType(t)) + .build(); + } + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(alignment) + .setDataType(at::native::getCudnnDataType(t)) + .build(); +} + +inline cudnn_frontend::Tensor getTensorDescriptor(const c10::IntArrayRef& shape, const c10::IntArrayRef& strides, cudnnDataType_t cudnn_dtype, int64_t id, uint8_t alignment, bool is_virtual = false) { + if (is_virtual) { + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(alignment) + .setVirtual() + .setDataType(cudnn_dtype) + .build(); + } + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(alignment) + .setDataType(cudnn_dtype) + .build(); +} + +// TODO: there is a table from input dtype to operator dtype, we can derive +// the operator dtype based on input dtype +inline cudnn_frontend::PointWiseDesc_v8 getPointWiseMulDescriptor(cudnnDataType_t dataType) { + return cudnn_frontend::PointWiseDescBuilder() + .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_MUL) + .setMathPrecision(dataType) + .build(); +} + +// TODO: there is a table from input dtype to operator dtype, we can derive +// the operator dtype based on input dtype +inline cudnn_frontend::PointWiseDesc_v8 getPointWiseAddDescriptor(cudnnDataType_t dataType) 
{ + return cudnn_frontend::PointWiseDescBuilder() + .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_ADD) + .setMathPrecision(dataType) + .build(); +} + +// TODO: there is a table from input dtype to operator dtype, we can derive +// the operator dtype based on input dtype +inline cudnn_frontend::PointWiseDesc_v8 getPointWiseReluDescriptor(cudnnDataType_t dataType) { + return cudnn_frontend::PointWiseDescBuilder() + .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(dataType) + .build(); +} + + +inline void filterEngineConfigs( + cudnn_frontend::EngineConfigList &from, + cudnn_frontend::EngineConfigList &to, + bool deterministic, bool allow_tf32, c10::ScalarType scalar_type) +{ + auto filter = [=](cudnnBackendDescriptor_t c) { + if (deterministic) { + if (cudnn_frontend::hasNumericalNote(c)) return true; + } + if (scalar_type == at::kFloat || scalar_type == at::kChar || !allow_tf32) { + if (cudnn_frontend::hasNumericalNote(c)) return true; + if (cudnn_frontend::hasNumericalNote(c)) return true; + } + return false; + }; + cudnn_frontend::filter(from, to, filter); +} + +} // cudnn_utils + +#endif // AT_CUDNN_ENABLED +#endif // USE_CUDA diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/library.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/library.h new file mode 100644 index 0000000000000000000000000000000000000000..09fa2f626603567e50ac305c03a53326ce4a0879 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/quantized/library.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +TORCH_API int register_linear_params(); +int register_embedding_params(); + +template TORCH_API int register_conv_params(); diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/attention.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/attention.h new file mode 100644 index 
0000000000000000000000000000000000000000..7bb75f00acff497a0edea81d962224c059496565 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/attention.h @@ -0,0 +1,70 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace at::native { + +using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value, + const std::optional& attn_mask_, double dropout_p, bool is_causal, std::optional scale, bool enable_gqa); + +DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub) + +TORCH_API Tensor bmm_nt(const Tensor& a, const Tensor& b); +TORCH_API Tensor masked_softmax( + Tensor& attn_scores, + std::optional attn_mask, + const Tensor& query, + std::optional mask_type = {}); + +using transform_bias_rescale_qkv_fn = void(*)( + at::ScalarType type, + void* _q_k_v, + const void* _qkv, + const void* _qkv_bias, + int64_t B, + int64_t T, + int64_t D, + int64_t num_head); + +DECLARE_DISPATCH(transform_bias_rescale_qkv_fn, transform_bias_rescale_qkv_stub) + +TORCH_API Tensor transform0213_gemm_nt_bias( + const Tensor& a, + const Tensor& b, + const Tensor& c, + const Tensor& query); + +TORCH_API Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b); + +TORCH_API void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape); + +TORCH_API Tensor qkv_projection( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const int64_t embed_dim, + const Tensor& qkv_weight); + +using flash_attention_fn = void (*)( + const Tensor& output, const Tensor& logsumexp, + const Tensor& query, const Tensor& key, const Tensor& value, + double dropout_p, bool is_causal, + std::optional attn_mask, + std::optional scale); + +using flash_attention_backward_fn = void (*)( + const Tensor& grad_q, const Tensor& grad_k, + const Tensor& grad_v, const Tensor& grad_out, + const Tensor& query, const Tensor& key, + const Tensor& value, const Tensor& out, const Tensor& logsumexp, + 
double dropout_p, bool is_causal, + std::optional attn_mask, + std::optional scale); + +DECLARE_DISPATCH(flash_attention_fn, flash_attention_kernel) +DECLARE_DISPATCH(flash_attention_backward_fn, flash_attention_backward_kernel) + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/flash_api.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/flash_api.h new file mode 100644 index 0000000000000000000000000000000000000000..f5ba2c117d99be6d9c165f02ba2200625186bfaa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/flash_api.h @@ -0,0 +1,96 @@ +#pragma once +#include + +#include +#include +#include + +namespace FLASH_NAMESPACE { + +TORCH_API +std::tuple +mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
+ std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + + +std::tuple +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional 
&dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +} // namespace FLASH_NAMESPACE diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/static_switch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..ca12fa171bf989c3f7158d85811d26d346c27dfb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/flash_attn/static_switch.h @@ -0,0 +1,107 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` + +#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_DROPOUT + #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI + #define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + #define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else + #define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL + #define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define LOCAL_SWITCH BOOL_SWITCH +#endif + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) 
\ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7b91889012e4ebb42a3819f828e61a7fe1f482bc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h @@ -0,0 +1,210 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Debugging functions +//////////////////////////////////////////////////////////////////////////////// +// Nans & inf detection +#define NANCHECK(frag) \ + { \ + for (int _i = 0; _i < frag.size(); ++_i) { \ + assert(std::isfinite(float(frag[_i]))); \ + assert(!std::isnan(float(frag[_i]))); \ + } \ + } + +// Print on the first thread of the first block +#if 1 +#define PRINT_WARP_ID 0 +#define PRINT_LANE_ID 0 +#define PRINT_B0_T0(msg, ...) \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ + threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_T0(msg, ...) \ + if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_TX_LX(msg, ...) \ + for (int bx = 0; bx < gridDim.x; ++bx) { \ + for (int by = 0; by < gridDim.y; ++by) { \ + for (int bz = 0; bz < gridDim.z; ++bz) { \ + for (int tx = 0; tx < blockDim.x; ++tx) { \ + for (int ty = 0; ty < blockDim.y; ++ty) { \ + for (int tz = 0; tz < blockDim.z; ++tz) { \ + __syncthreads(); \ + if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \ + threadIdx.x == tx && threadIdx.y == ty && \ + threadIdx.z == tz) { \ + printf( \ + "[%d,%d,%d][%d,%d,%d]" msg "\n", \ + bx, \ + by, \ + bz, \ + tx, \ + ty, \ + tz, \ + ##__VA_ARGS__); \ + } \ + } \ + } \ + } \ + } \ + } \ + } +#else +#define PRINT_B0_T0 +#define PRINT_TX_LX +#endif + +struct __string_view { + char const* data; + std::size_t size; +}; +#if __cplusplus >= 201402L +template +constexpr __string_view __get_type_name() { + char const* p = __PRETTY_FUNCTION__; + while (*p++ != '=') + ; + for (; *p == ' '; ++p) + ; + char const* p2 = p; + int count = 1; + for (;; ++p2) { + switch (*p2) { + case '[': + ++count; + break; + case ']': + 
--count; + if (!count) + return {p, std::size_t(p2 - p)}; + } + } + return {}; +} +#else +template +constexpr __string_view __get_type_name() { + return {"unsupported", 11}; +} +#endif + +// Print a given array +#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ + PRINT_B0_T0( \ + "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ + name, \ + int(start), \ + int(start + 8), \ + float(accum[start + 0]), \ + float(accum[start + 1]), \ + float(accum[start + 2]), \ + float(accum[start + 3]), \ + float(accum[start + 4]), \ + float(accum[start + 5]), \ + float(accum[start + 6]), \ + float(accum[start + 7])); +#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) +#define PRINT_FRAG_T0_L0(name, frag) \ + { \ + auto typeStr = __get_type_name(); \ + PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \ + for (int _start = 0; _start < frag.size(); _start += 8) { \ + PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ + } \ + /*__syncthreads(); \ + NANCHECK(frag); */ \ + } +#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ + { \ + PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \ + for (int _start = 0; _start < length; _start += incr) { \ + PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ + } \ + } +#define PRINT_ARRAY_T0_L0(name, array, length) \ + PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) + +// Print a 4x4 matrix +#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ + PRINT_B0_T0( \ + "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ + name, \ + int(start_x), \ + int(start_x + 4), \ + int(start_y), \ + int(start_y + 4), \ + float(ref.at({start_x + 0, start_y + 0})), \ + float(ref.at({start_x + 0, start_y + 1})), \ + float(ref.at({start_x + 0, start_y + 2})), \ + float(ref.at({start_x + 0, start_y + 3})), \ + float(ref.at({start_x + 1, start_y + 0})), \ + float(ref.at({start_x + 1, start_y + 1})), \ + float(ref.at({start_x + 1, start_y + 2})), \ + float(ref.at({start_x + 1, 
start_y + 3})), \ + float(ref.at({start_x + 2, start_y + 0})), \ + float(ref.at({start_x + 2, start_y + 1})), \ + float(ref.at({start_x + 2, start_y + 2})), \ + float(ref.at({start_x + 2, start_y + 3})), \ + float(ref.at({start_x + 3, start_y + 0})), \ + float(ref.at({start_x + 3, start_y + 1})), \ + float(ref.at({start_x + 3, start_y + 2})), \ + float(ref.at({start_x + 3, start_y + 3}))); +#define PRINT_TENSOR4x4_T0_L0(name, ref) \ + PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) + +#define PRINT_PROBLEM_SIZE(name, ps) \ + PRINT_B0_T0( \ + "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ + name, \ + int(ps.m()), \ + int(ps.n()), \ + int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum( + AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (row == accum_m && col == accum_n && + (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..70320a599c4ab7f00822b64bddd3d930554f08bb --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h @@ -0,0 +1,631 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + File copied from + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: + ///< gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting + ///< accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing + 
///< accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading + ///< from SMEM + typename OutputOp_, ///< Output operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank + ///< conflicts (concept: MatrixShape) + int FragmentsPerPartition = + 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is + ///< large + (!IsEpilogueFunctorHeavy::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator reading tensors + > +class EpiloguePipelined : public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + public: + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference 
to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + using SourceAccessType = Array< + typename OutputTileSourceIterator::Element, + OutputTileSourceIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array< + typename WarpTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert( + OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert( + OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert( + !(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + public: + /// Constructor + CUTLASS_DEVICE + 
EpiloguePipelined( + typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator) { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_( + output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators) { ///< Complete warp-level accumulator tile + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + + private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p 
< Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper( + iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert( + kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators ///< Complete warp-level accumulator tile + ) { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll( \ + IterationsUnroll \ + ? 
OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i 
= 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator ///< Threadblock tile coordinate in GEMM (in units of + ///< threadblock tiles) + ) { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { + __syncthreads(); + } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed< + cutlass::make_index_sequence>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) { + 
OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } + + constexpr int CUTLASS_HOST_DEVICE getRowOffset(int i) { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row 
* ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { + return row_offset; + } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h new file mode 100644 index 0000000000000000000000000000000000000000..fd7982e5f699e3ab66171a19781c12460cbde1ba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h @@ -0,0 +1,238 @@ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, as uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. 
+*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha / m_prime (renormalize the output when the max changes) +// source is the current output +template < + typename ElementOutput_, ///< Data type used to store tensors + typename ElementSource_, //< Data type for source (usually matches + //`ElementOutput`) + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { + public: + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + + private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize( + FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return !isFirst; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + int row, + FragmentAccumulator const& accumulator, + FragmentSource const& source) const { + assert(!isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + 
destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + // Row sums for full masked out rows are 0, we set them to 1 + // In order to avoid NaNs in the output and instead sem them to 0. + ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row]; + ElementCompute alpha = isLast ? (1 / denom) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) + const { + assert(isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + // Row sums for full masked out rows are 0, we set them to 1 + // In order to avoid NaNs in the output and instead sem them to 0. + ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row]; + ElementCompute alpha = isLast ? 
(1 / denom) : 1; + + intermediate = mul_accumulator( + alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template < + typename EO, + typename ES, + int Count, + typename EA, + typename EC, + bool F, + bool L, + typename FAB, + FloatRoundStyle R> +struct ApplyEpilogueOp> { + using Op = thread:: + MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 0000000000000000000000000000000000000000..e3dc0778e46b14d1aa64d0e628ebaa3aa5745904 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,175 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. 
+*/ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayExponential { + CUTLASS_HOST_DEVICE + Array operator()( + Array const& input) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = expf(input[i]); + } + + return result; + } +}; + +template +struct ArrayExponential { + CUTLASS_DEVICE + Array operator()( + Array const& input) const { + Array result; + + int const kVectorCount = ElementsPerAccess / 2; + + __half2 const* input_ptr = + reinterpret_cast<__half2 const*>(input.raw_data()); + __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = h2exp(input_ptr[i]); + } + + return result; + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies: +/// output <- (input - lse).exp() +template < + typename ElementOutput_, // output + typename ElementLSE_, // accumulator from LSE + typename ElementAccumulator_, // accumulator from matmul + typename ElementCompute_, // intermediate compute (and exp calculation) + int ElementsPerAccess> +class ApplyLogSumExp { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementLSE = ElementLSE_; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + static const ScaleType::Kind kScale = + cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + using FragmentOutput = Array; + using 
FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentLSE = Array; + using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h + + public: + // + // Methods + // + + CUTLASS_HOST_DEVICE + ApplyLogSumExp() {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& AB, + FragmentLSE const& scale_unused, + // bias used as LSE + FragmentLSE const& bias) const { + FragmentCompute frag_AB = NumericArrayConverter< + ElementCompute, + ElementAccumulator, + kElementsPerAccess>()(AB); + FragmentCompute frag_lse_compute = + NumericArrayConverter()( + bias); + FragmentCompute frag_compute; + + minus minus_lse; + detail::ArrayExponential apply_exp; + frag_compute = minus_lse(frag_AB, frag_lse_compute); + frag_compute = apply_exp(frag_compute); + + return NumericArrayConverter< + ElementOutput, + ElementCompute, + kElementsPerAccess>()(frag_compute); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..b476742f9f2745266e5e9c2c64759d7833d44829 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h @@ -0,0 +1,100 @@ +/* + * Copyright 
(c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include + +#include +#include +template +struct MakeCustomMma; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + cutlass::arch::CacheOperation::Kind CacheOpA, + typename IteratorB, + typename SmemIteratorB, + cutlass::arch::CacheOperation::Kind CacheOpB, + typename ElementC, + typename LayoutC, + typename Policy, + int Stages, + cutlass::gemm::SharedMemoryClearOption SharedMemoryClear, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + Stages, + SharedMemoryClear>, + kMaxK> { + // Reduce the number of stages if we don't need that many + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() + ? 
Stages + : cutlass::const_min( + Stages, + (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + kStages, + SharedMemoryClear, + kMaxK>; +}; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + typename IteratorB, + typename SmemIteratorB, + typename ElementC, + typename LayoutC, + typename Policy, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>, + kMaxK> { + using Mma = cutlass::gemm::threadblock::CustomMmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>; +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..229c59d68347a043001842472c50458037b562e3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template + struct OperandSharedStorage { + AlignedBuffer buffer; + using TensorRef = TensorRef; + + CUTLASS_DEVICE + static OperandLayout Layout() { + return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); + } + + /// Returns a TensorRef to the operand + CUTLASS_HOST_DEVICE + TensorRef ref() { + return TensorRef{buffer.data(), Layout()}; + } + }; + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape< + Shape::kM + Policy::SmemPaddingA::kRow, + Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + using SharedStorageA = OperandSharedStorage< + typename Operator::ElementA, + 
ShapeA, + typename Operator::LayoutA>; + using SharedStorageB = OperandSharedStorage< + typename Operator::ElementB, + ShapeB, + typename Operator::LayoutB>; + using TensorRefA = typename SharedStorageA::TensorRef; + using TensorRefB = typename SharedStorageB::TensorRef; + + struct SharedStorage { + /// Buffer for A operand + SharedStorageA operand_A; + + /// Buffer for B operand + SharedStorageB operand_B; + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorageA& shared_storageA, + SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), + warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..91f7597b6ba356d11c591f6eb73e97ec070dadf1 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h @@ -0,0 +1,768 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Upper boundon the K dimension + int kMaxK = 
cutlass::platform::numeric_limits::max(), + /// Used for partial specialization + typename Enable = bool> +class CustomMmaMultistage : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
+ struct Detail { + static_assert( + Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireMat ? 
Stages : Stages - 1; + + private: + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + bool prologue_done_; + + // Set to `True` to ensure the accumulator will be zero outside the GEMM + // footprint + bool zero_outside_bounds_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx), + prologue_done_(false), + zero_outside_bounds_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of 
warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaMultistage( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + prologue_done_ = value; + } + + CUTLASS_DEVICE + void set_zero_outside_bounds(bool value) { + zero_outside_bounds_ = value; + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); + SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); + int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; + _prologue( + iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int 
group_start_B = 0) { + iterator_A.set_iteration_index( + group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index( + group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } 
+ } + + template + CUTLASS_DEVICE static void _prologue( + IteratorA& iterator_A, + IteratorB& iterator_B, + int32_t& gemm_k_iterations, + SmemIteratorA& smem_iterator_A_, + SmemIteratorB& smem_iterator_B_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? 
kSrcBytes : 0); + + if (kLoadA) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + } + + ++iterator_A; + } + + ++smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (kLoadB) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } + + ++iterator_B; + } + + ++smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. 
+ cutlass::arch::cp_async_fence(); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + if (!prologue_done_) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else if (!kSmemContainsEntireMat) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else { + gemm_k_iterations -= kNumStagesConcurrentLoad; + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. 
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. 
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform( + warp_transformed_frag_A[0], + warp_transformed_frag_B[0], + warp_loaded_frag_A[0], + warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // In case of a non-circular buffer ("kSmemContainsEntireMat") + // make sure we don't load out of bounds data. + if (!kSmemContainsEntireMat || + gemm_k_iterations > (-kNumStagesConcurrentLoad) || + warp_mma_k < Base::kWarpGemmIterations - 1) { + this->warp_tile_iterator_A_.load( + warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform( + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma( + tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (!kSmemContainsEntireMat && + warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + if (!kSmemContainsEntireMat) { + int 
group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (!kSmemContainsEntireMat && + smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform( + warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 
1) % 2]); + } + } + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..6c2ed2dd8ee3dda76422aed612d3bb233fb1aeac --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h @@ -0,0 +1,402 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaPipelined : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using 
SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // statically assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + static bool const kSmemContainsEntireMat = false; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaPipelined( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx) { + // Compute warp location within threadblock tile by 
mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaPipelined( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + // NOT IMPLEMENTED FOR PIPELINED + } + + CUTLASS_DEVICE + void set_zero_outside_bounds(bool value) { + // NOT NEEDED FOR PIPELINED + // shared memory will always be zero-filled + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + 
///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + // NOT IMPLEMENTED FOR PIPELINED + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + TransformA transform_A = + TransformA(), ///< transformation applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first 
warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was 
the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma( + accum, + warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..c1b93bb06ba6d9b5c1d030f5cb885dedf940efed --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h @@ -0,0 +1,167 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +/*! \file + \brief Cutlass provides helper template functions to figure out the right + data structures to instantiate to run a GEMM with various parameters (see + `cutlass/gemm/threadblock/default_mma.h`). However, due to template + instantiation priority rules, it will only create an MmaMultiStage with + kStages=3 (otherwise creates an MmePipelined - which is not compatible with + FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, + so we just copy-pasted some code from `default_mma.h` and + `default_mma_core.h` files and wrapped this template to allow our use case. + + This is really only for the FastF32 case - aka using TensorCores with fp32. 
+*/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + typename Enable_ = void> +struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = + SharedMemoryClearOption::kNone; + using DefaultMma = cutlass::gemm::threadblock::DefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + AccumulatorsInRowMajor, + SharedMemoryClear>; +}; + +/// Specialization for sm80 / FastF32 / multistage with kStages=2 +template < + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + typename ElementB_, + /// Layout type for B matrix operand + 
typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + int kStages, + typename Operator> +struct FindDefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 3, + Operator>; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore_::Shape, + typename DefaultMma_::IteratorA, + typename MmaCore_::SmemIteratorA, + MmaCore_::kCacheOpA, + typename DefaultMma_::IteratorB, + typename MmaCore_::SmemIteratorB, + MmaCore_::kCacheOpB, + ElementAccumulator, + LayoutC, + typename MmaCore_::MmaPolicy, + kStages>; + }; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h new file 
mode 100644 index 0000000000000000000000000000000000000000..afc3c9b7142e6e4a050915e8e9470b320ab2bb9b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h @@ -0,0 +1,354 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. +*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord( + quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * 
OpDelta::kRow + + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if 
(cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord( + accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + static_assert( + cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + 
for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n 
+ + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + static_assert( + cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + return lane_offset + + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template < + typename S1, + typename S2, + typename S3, + typename accum_t, + int kWarpSize> +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, 
+ S2, + S3>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..df2c92326f5876e8dd4c422ea8a0eba2169460b1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h @@ -0,0 +1,1948 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template < + typename Shape_, + typename Element_, + typename Layout_, + typename Padding_> +class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = cutlass:: + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for accumulator + cutlass::AlignedBuffer accum; + + public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + 
CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // Maximum K dimension - also the dimension of the shared-memory + // holding `OperandA` + int kMaxK_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Layout in shared-memory of operand A + typename SmemLayoutA, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + static constexpr int kMaxK = kMaxK_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// If this is true, we fill the entire shmem buffer at start + /// and don't need to iterate through it in a circular fashion + static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + // typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB 
warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + TensorRefB& b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_B_(b_tile, lane_idx) {} +}; + +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { + public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // Since arrays of zero-sized objects are not allowed, using size as 1. + // The compiler will most likely wipe it out anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset( + typename TensorRef::TensorCoord const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { + return *this; + } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. 
+template +class FragmentElementwiseScaler { + public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) { + Fragment converted_scale_frag = cutlass::NumericArrayConverter< + typename Fragment::Element, + typename FragmentScale::Element, + FragmentScale::kElements>()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { + public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { + return frag; + } +}; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Max GEMM problem size in K dimension + int MaxK, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout>; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + using WarpIteratorA = WarpIteratorA_; + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. 
+ using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // statically assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + + using WarpFragmentB = typename Operator::FragmentB; + + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. 
+ using FragmentAScaler = FragmentElementwiseScaler< + WarpFragmentA, + WarpFragmentAScale, + ScaleOperandA>; + + protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, // Operand A in shared memory + typename Base::TensorRefA a_scale, // Operand A_scale in shared memory + typename Base::TensorRefB + b_staging, // staging memory for loading tiles of B + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + warp_tile_iterator_A_scale_(a_scale, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + 
this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, ///< Operand A in shared memory + typename Base::TensorRefB b_staging, ///< staging memory for loading B + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx) ///< ID of each thread within a warp + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + // For API compatibility with MmaMultistageFromSharedMemory + // but not supported as it worsens perf: older gpus < sm80 don't + // support async transfers and have to waste registers + CUTLASS_DEVICE + void set_prologue_done(bool value) {} + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) {} + + /// Perform a 
threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.set_residual_tile(gemm_k_iterations == 1); + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. 
+ + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.set_residual_tile(gemm_k_iterations == 2); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency + // requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ bool hasNext = true; + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + if (gemm_k_iterations > 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + } + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + hasNext = gemm_k_iterations > 1; + } + + // Only read the next if we need to + if (hasNext) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load( + warp_frag_A_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.set_residual_tile(gemm_k_iterations == 3); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + } + + warp_mma( + accum, + FragmentAScaler::apply( + warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]), + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + 
+/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages_, + int kMaxK_, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout>; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< 
warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert( + Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / + Base::kWarpGemmIterations1; + }; + + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireB ? Base::kStages : Base::kStages - 1; + + private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. 
+ using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpLoadedFragmentA1, + WarpLoadedFragmentA1Scale, + ScaleOperandA>; + + private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. + WarpIteratorAScale warp_tile_iterator_A1_scale_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + bool prologue_done_; + + public: + /// constructor for MMA with operand A scaling enabled. 
+ CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefA a_scale, + typename Base::TensorRefB b_tile, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + warp_tile_iterator_A1_scale_(a_scale, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefB b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the 
threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + prologue_done_ = value; + } + + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) { + SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); + _prologue( + iterator_B1, + (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, + smem_iterator_B1); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1& iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index( + group_start_B1 * IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async_zfill( + 
dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + CUTLASS_DEVICE + static void _prologue( + IteratorB& iterator_B1, + int32_t gemm_k_iterations_1, + SmemIteratorB1& smem_iterator_B1_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. 
+ cutlass::arch::cp_async_fence(); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + if (!prologue_done_) { + _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); + } else if (!kSmemContainsEntireB) { + // Restore the iterators increments + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + iterator_B1.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); + iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. 
+ + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform( + warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + FragmentAScaler::apply( + warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
+ plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + // Load warp-level tile from accumulator fragment (A) + // or shared memory (operand B) + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations1); + // skip warp tile loading for the last kgroup (we are out of the buf) + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma1( + tmp_accum, + 
warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (!kSmemContainsEntireB) { + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + } + + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template < + typename Mma_, + int kMaxK, + typename WarpIteratorA_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA = false> +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// 
Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + // Max MMA problem size K + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>, + kMaxK, + WarpIteratorA_, + kScaleOperandA, + kTransposeA> { + using RegularMma = MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + static constexpr bool kIsTransposedA = false; + using WarpIteratorA = WarpIteratorA_; + using IteratorB = 
+ typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + + using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + kMaxK, + IteratorB, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_>; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>, + kMaxK, + WarpIteratorA_, + 
kScaleOperandA, + kTransposeA> { + using RegularMma = MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = + WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = typename platform::conditional< + kIsTransposedA, + typename WarpIteratorTranspose::Iterator, + WarpIteratorA_>::type; + + // Reduce the number of stages if we don't need that many + static int constexpr kStagesMax = + (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); + + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + using Mma = + typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + IteratorB, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + kStages, + kMaxK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename IteratorC, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) 
+template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = + cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + scalar_t, // accum_t, + SmemAccumulatorLayout>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + typename SmemIteratorD0::Element, + typename SmemIteratorD0::TensorLayout, + typename SmemIteratorD0::Padding>; + // We need to provide an operation for the epilogue. 
Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator - not used + OutputOpNoOp>; + + // Epilogue 2: with LSE (for backwards pass) + static int const kElementsPerAccess = 2; // TODO: Why 2? + using IteratorAccumulatorLSE = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + // Shape + cutlass::MatrixShape, + // WarpShape + cutlass::MatrixShape, + lse_scalar_t, + cutlass::layout::RowMajor, + kElementsPerAccess>>; + using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< + scalar_t, // ElementOutput_ + lse_scalar_t, // ElementLSE_ + accum_t, // ElementAccumulator_ + accum_t, // ElementCompute_ + 128 / cutlass::sizeof_bits::value + // FragmentIteratorAccumulator::Fragment::kElements + // InstructionShape::kM * InstructionShape::kN / 32 + >; + using EpilogueWithLSE = + cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + IteratorAccumulatorLSE, + EpilogueOpApplyLSE>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue 
epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC& accum, + lse_scalar_t const* lse, + int32_t lse_extents, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + constexpr int32_t kAlignLSE = 32; + IteratorAccumulatorLSE iterator_lse( + lse, + {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, + thread_id, + warp_id, + cutlass::MatrixCoord{0, 0} // offset + ); + + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + EpilogueWithLSE epilogue; + EpilogueOpApplyLSE minus_lse_exp({}); + epilogue( + minus_lse_exp, + smem_iterator_attn, + accum, + // scale - unused + iterator_lse, + // bias + iterator_lse); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using SmemAccumulatorLayout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using AccumulatorSharedStorage = + 
cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + SmemAccumulatorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using Element = accum_t; + // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields + // Let's copy their values + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + using AccessType = cutlass::Array; + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int 
tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2; + int r = (accum_m + lane_offset.row()); + AccessType to_store; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int c = (accum_n + n + lane_offset.column()); + to_store[n] = scalar_t(accum[idx]); + } + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + *reinterpret_cast( + ref_.data() + ref_.offset({r, c})) = to_store; + } + } + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + 
lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template < + typename Operator, + typename OperatorPolicy, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::ColumnMajor, + cutlass::MatrixShape<0, 0> // Padding + >; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + auto ref_ = 
shared_storage.accum_ref(); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = + Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + + m; + int c = mma_n * Delta::kColumn + n; + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, 
int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..14efce61920c4fb8ecf938343cb6905777d1159a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h @@ -0,0 +1,209 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +// Some helper functions +//////////////////////////////////////////////////////////////////////////////// +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (query.scalar_type() == at::ScalarType::Float) { \ + using scalar_t = float; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + func(); \ + } else { \ + TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + F(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + F(); \ + } \ + } +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if (CC >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func(); \ + } else if (CC >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func(); \ + } else if (CC >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func(); \ + } else if (CC >= 50) { \ + using ArchTag = cutlass::arch::Sm50; \ + func(); \ + } else { \ + TORCH_CHECK( \ + false, \ + "Your device is too old. 
We require compute capability >= 50"); \ + } \ + } + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) { + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) 
+// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below +template +struct DefaultGemmType { + static constexpr int ThreadK = 8; + static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using OpClass = cutlass::arch::OpClassSimt; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f32 +template +struct DefaultGemmType< + ArchTag, + float, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 80>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAddFastF32; +}; + +// Specialization for tensorcores with f16/bf16 - Sm75+ +template +struct DefaultGemmType< + ArchTag, + scalar_t, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 75 && + cutlass::sizeof_bits::value == 16>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Enables to do +// `auto x = kCondition ? 
fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(ta(arg)) { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(tb(arg)) { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE T warp_uniform(T value) { + struct { + union { + T value; + uint32_t asInt; + }; + } p; + p.value = value; + p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0); + return p.value; +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) { + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..10f1ddf34aff91d24aaa1570b799464b18ec5740 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Instantiates the right WarpIterator to read from shared memory + The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading + data dumped with `B2bGemm::accumToSmem`. 
+*/ + +#pragma once + +#include +#include +#include + +#include + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + typename WarpShape, + typename InstructionShape, + typename RegularWarpIterator, + typename Policy, + typename Enable = void> +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, kInstrK>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::MatrixShape>; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 16, 4>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using 
OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<1, 1, 1>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. + using WarpIterator = RegularWarpIterator; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..c7a9915fed6d8c78681994113531e8300856df1e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,752 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Epilogue iterator that supports prefetching + + Mostly copied from +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + bool UseCUDAStore = false> +class PredicatedTileIteratorPrefetch { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( + ThreadMap::Iterations::kRow > 0, + "ThreadMap::Iterations::kRow must be > 0"); + static_assert( + ThreadMap::Iterations::kGroup > 0, + "ThreadMap::Iterations::kGroup must be > 0"); + static_assert( + ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert( + ThreadMap::Iterations::kColumn > 0, + "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = 
Array< + Element, + ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert( + sizeof(PredicatedTileIteratorParams::stride) == 8, + "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch( + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) { + TensorCoord thread_offset = + ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < + extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) 
+ + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = (uint64_t)((void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / + kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += 
params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { + load_with_byte_offset(frag, 0); + } + + /// 
Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += 
params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { + store_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / 
kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) + row_add_P = 0; + if (output_Q > convolution_Q - 2) + row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + 
reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() { + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + 
thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * + ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { + mask_ = mask; + } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch< + typename IT::ThreadMap, + typename IT::Element>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..119389deea2bf8c0110f58fb9b6bfb1d7f1b3a3e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include +#include + + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessSize, + Gather>; +}; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType, + Gather>; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..a97737bfe2a1e8d09ba249d52544df4b242e5c5b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,2115 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. 
+ + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather = false> +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. 
+/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base( + layout.stride(0), + MakePredicatedTileAccessIteratorDesc< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap>()()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + private: + /// 
Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + indices_(indices) { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. 
+ // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset( + layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += 
sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + if (Gather) { + assert(indices_); + + if (!valid()) { + return nullptr; + } + + LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value / + 8) + + the_predicates.iteration_vector_; + int strided_index = gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * + LongIndex(params_.stride_) * sizeof_bits::value / 8; + + return reinterpret_cast( + pointer_ + contiguous_offset + strided_offset); + } + + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * + sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + 
Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< 
Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + layout::PitchLinear, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// 
contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. + LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() + : stride_(0), + inc_contiguous_(0), + inc_strided_(0), + inc_next_(0), + inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = + (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * + sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = + Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer 
pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent) { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + 
CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(pointer_) + + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of 
each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of 
each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters 
object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters 
object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..00e1b0ac2001a2c331b1c829b7daed978f19dd3c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,2120 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. 
+ So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. +/// This is assumed to be the last tile in the iteration sequence. Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. 
+/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Access out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess, + bool Gather = false> +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend 
PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset, + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; 
++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of 
PredicatedTileIteratorResidualLast for column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize,
    bool Gather>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::ColumnMajor,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    Gather> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may along advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::ColumnMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  /// Underlying pitch-linear iterator: for column-major data the rows form
  /// the contiguous dimension, so Shape is mapped as (kRow, kColumn).
  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
      Element,
      layout::PitchLinear,
      (kAdvanceRank == 0 ? 0 : 1),
      ThreadMap,
      AccessSize,
      Gather>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const& base)
        : params_(base) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id, ///< ID of each participating thread
      TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
      int const* indices =
          nullptr ///< gather/scatter indices, note no support for
                  ///< gather/scatter at this specialization
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(extent.row(), extent.column()),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.row(),
                threadblock_offset.column()),
            indices) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for row-major data.
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + 
Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename 
TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < 
ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of 
PredicatedTileIteratorResidualLast for affine rank 2
/// column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::AffineRank2ColumnMajor,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may along advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::AffineRank2ColumnMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  // Map to the underlying AffineRankN<2> layout
  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
      Element,
      layout::AffineRankN<2>,
      (kAdvanceRank == 0 ? 0 : 1),
      ThreadMap,
      AccessSize>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given an AffineRankN<2> tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying AffineRankN<2> tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id, ///< ID of each participating thread
      TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
      int const* indices =
          nullptr ///< gather/scatter indices, note no support for
                  ///< gather/scatter at this specialization
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(extent.row(), extent.column()),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.row(),
                threadblock_offset.column())) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2
///
row-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::AffineRank2RowMajor,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may along advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::AffineRank2RowMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  // Map to the underlying AffineRankN<2> layout
  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
      Element,
      layout::AffineRankN<2>,
      (kAdvanceRank == 0 ? 1 : 0),
      ThreadMap,
      AccessSize>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given an AffineRankN<2> tensor's layout
    /// (strides swapped so the contiguous dimension comes first)
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying AffineRankN<2> tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id, ///< ID of each participating thread
      TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
      int const* indices =
          nullptr ///< gather/scatter indices, note no support for
                  ///< gather/scatter at this specialization
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(extent.column(), extent.row()),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.column(),
                threadblock_offset.row())) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for interleaved data.
+/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// 
offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// 
offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h new file mode 
100644 index 0000000000000000000000000000000000000000..85d043250aac9ef29e2eee5d68921d4b88dae69c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + typename InstructionShape, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp:: + WarpIteratorFromSmem> { + using Iterator = cutlass::gemm::warp:: + WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..42609d7431feaa5494eef7f7df04da40f6925340 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h @@ -0,0 +1,284 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. 
+*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + typename InstructionShape_, + bool kTranspose = false> +class WarpIteratorFromSmem { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + static_assert( + kOperand == Operand::kA, + "No support for OperandB at the moment"); + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16"); + static_assert( + InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16, + "Only supports 16x8x8 / 16x8x16"); + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 
1 + : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape< + Shape::kRow / InstructionShape::kRow, + Shape::kColumn / InstructionShape::kColumn>; + + static int const kIterations = (kOperand == Operand::kA) + ? InstructionCount::kColumn + : InstructionCount::kRow; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + (kOperand == Operand::kA) + ? (Shape::kRow* InstructionShape::kColumn / kThreads) + : (Shape::kColumn* InstructionShape::kRow / kThreads)>; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? InstructionShape::kColumn + : InstructionShape::kRow); + static int constexpr kAccessesInner = + (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + // Number of 32bits tiles to load per `ldmatrix` + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8"); + + private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + + public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {} + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) { + // See also: + // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688 + // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4) + // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4) + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert( + InstructionCount::kRow * kTilesPerInstruction == 4, + "can't use 
ldmatrix.x4"); + int access_m_idx = ldsm_vec_num % kTilesPerInstruction; + int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner; + int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner); + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } else { + // XXX: This is not tested or used + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; + ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset( + inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) { + TensorCoord coord_offset( + tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + if (kTranspose) { + coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; + } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() { + iterations_++; + + if (iterations_ >= kIterations) + 
advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_DEVICE + void load(Fragment& frag) const { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = typename platform:: + conditional::type; + + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < + (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4; + ++access_m_idx) { + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord( + access_m_idx * 16, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm( + access_ptr[access_m_idx], ref_.data() + ref_.offset(offset)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..ae649e99c4cd8047d094e8e9fdf54e2a95affb74 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -0,0 +1,2614 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +using namespace gemm_kernel_utils; + +namespace PyTorchMemEffAttention { +namespace { + +template +struct GmemTile { + /* + Helper functions to efficient store/load RF to gmem + + GEMM accumulators have a particular format on A100, and + it takes some compute/shared-memory to rearrange them to + a RowMajor or ColumnMajor format in global memory through + an Epilogue. The same complexity goes for loading into RF. + + This class loads/stores RF as they are, and can be used for + efficient accumulation across gemms for instance: + + ``` + GmemTile tile; + for (int i = 0; i < N; ++i) { + // ... + + Fragment accum; + if (i == 0) { + accum.clear(); + } else { + tile.load(accum); + } + mma(accum, ...); + if (i < N-1) { + // Store for next GEMM + tile.store(accum); + } else { + // Store in tensor (eg RowMajor) + epilogue(accum); + } + + // ... 
+ } + ``` + */ + + // 128bits per thread + using AccessType = cutlass::Array; + static constexpr int32_t kBytes = sizeof(AccessType); + static constexpr int32_t kStride = kNumThreads * AccessType::kElements; + static constexpr int32_t kNumIters = + FragmentType::kElements / AccessType::kElements; + static constexpr int32_t kElementsStored = + kNumThreads * FragmentType::kElements; + static_assert( + FragmentType::kElements % AccessType::kElements == 0, + "fragment not aligned on 128 bits"); + + float* ptr; + + CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + cutlass::arch::global_load( + sub_fragment, gmem_ptr, true); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + fragment[i * AccessType::kElements + j] = sub_fragment[j]; + } + } + } + + CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + sub_fragment[j] = fragment[i * AccessType::kElements + j]; + } + cutlass::arch::global_store( + sub_fragment, gmem_ptr, true); + } + } + + CUTLASS_DEVICE void storeAtomicAdd( + FragmentType const& fragment, + int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + float val = fragment[i * AccessType::kElements + j]; + float* ptr = gmem_ptr + j; + atomicAdd(ptr, val); + } + } + } +}; + +struct AtomicLock { + CUTLASS_DEVICE static void acquire( + int32_t* 
lock, + int set_val, + int thread_id) { + if (thread_id == 0) { + while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + __nanosleep(40); +#endif + } + } + __syncthreads(); + } + CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) { + if (thread_id == 0) { + int status = 0; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.gpu.b32 [%0], %1;\n" + : + : "l"(lock), "r"(status)); +#else + asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); +#endif + } + } +}; + +template +constexpr int getWarpsPerSmBw() { + bool is_half = !cutlass::platform::is_same::value; + if (Arch::kMinComputeCapability >= 80) { + return is_half ? 12 : 8; + } + return 8; +} +} // namespace + +template < + // which arch we target (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // input/output type + typename scalar_t_, + // run optimized kernel because memory accesses will be aligned + bool kIsAligned_, + // use dropout if enabled + bool kApplyDropout_, + // when doing a GEMM, preload the next one (uses more shmem) + bool kPreload_, + // block dimensions + int kBlockSizeI_, + int kBlockSizeJ_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // assumes that `cu_seqlen` is None, and + // (1) `num_queries % kBlockSizeI == 0` + // (2) `num_keys % kBlockSizeJ == 0` + bool kKeysQueriesAlignedToBlockSize_ = false> +struct AttentionBackwardKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + using scalar_t = scalar_t_; + using output_t = scalar_t; + using output_accum_t = float; + using lse_scalar_t = float; + using accum_t = float; + using ArchTag = ArchTag_; + static constexpr bool kIsAligned = kIsAligned_; + static constexpr bool kApplyDropout = kApplyDropout_; + static constexpr bool kPreload = kPreload_; + 
static constexpr int kBlockSizeI = kBlockSizeI_; + static constexpr int kBlockSizeJ = kBlockSizeJ_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kKeysQueriesAlignedToBlockSize = + kKeysQueriesAlignedToBlockSize_; + + static constexpr int64_t kWarpSize = 32; + + // If this is true, we store and accumulate dK/dV in RF + // rather than going back to gmem everytime + static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; + static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; + static_assert( + !kPreload || + (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF), + "preload MMA not supported"); + static constexpr bool kPrologueQK = kPreload; + static constexpr bool kPrologueGV = kPreload; + static constexpr bool kPrologueDOV = kPreload; + static constexpr bool kPrologueGQ = kPreload; + static constexpr bool kPrologueGK = kPreload; + + static constexpr int64_t kNumWarpsPerBlock = + (kBlockSizeI * kBlockSizeJ) / (32 * 32); + + // Compute delta for the f16 kernels + // TODO: Figure out why it's slower on the f32 kernels + // (something due to RF pressure?) + // TODO: Remove condition on `kOutputInRF` - this is needed to work + // around a compiler bug on V100, not exactly sure why but I spent + // too much time on this already. 
Reproducible with + // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance + static constexpr bool kKernelComputesDelta = + kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); + + // Launch bounds + static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int64_t kMinBlocksPerSm = + getWarpsPerSmBw() / kNumWarpsPerBlock; + + using GemmType = DefaultGemmType; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + typename GemmType::OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr auto kOptimalAlignement = cutlass::platform::max( + DefaultConfig::kAlignmentA, + DefaultConfig::kAlignmentB); + static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; + + struct MatmulQK { + /* + attn_T = k_j @ q_i.transpose(-2, -1) # matmul + attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, + -1)).exp() # epilogue + + with attn_T.shape = (kBlockSizeJ, kBlockSizeI) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + accum_t, // ElementC + cutlass::layout::RowMajor, // LayoutC + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + DefaultConfig::kStages, + typename GemmType::Operator, + false, // AccumulatorsInRowMajor = false, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using MmaCore = typename DefaultMma::MmaCore; + using Mma = + typename MakeCustomMma::Mma; + + // used for efficient load of bias tile (Bij) from global memory to shared + // memory + using BiasLoader = TileSmemLoader< + scalar_t, + // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded + // row-major but needs to have transposed shape so we get the same + // elements. + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradV { + /* + grad_v[j_start:j_end] += attn_T @ do_i # matmul + + Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) + (we might need to iterate multiple times on K) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + 
cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + // if dropout: + // for computing dVj += (Pij.T * Zij) @ dOi + // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of + // Pij.T are loaded in. The reason we do it this way is because Pij.T and + // Zij are reused in later steps, while Pij_dropped.T is only needed in + // this step. computing Pij_dropped.T on the fly allows us to avoid + // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the + // same time. + // if no dropout: + // for computing dVj += Pij.T @ dOi + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Operator:: + InstructionShape, // InstructionShape + typename DefaultGemm::Mma::Operator:: + IteratorA, // RegularWarpIterator + typename DefaultGemm::Mma::Policy // Policy + >::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + kApplyDropout>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename 
DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + struct MatmulDOIVJ { + /* + doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul + tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue? + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + + using ElementC = output_t; + using ElementAccum = accum_t; + + // no-op output op - epilogue just stores result to global memory + using BiasGradEpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + ElementC, + DefaultConfig::EpilogueOutputOp::kCount, + typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, + typename DefaultConfig::EpilogueOutputOp::ElementCompute, + cutlass::epilogue::thread::ScaleType::Nothing>; + + using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + ElementC, // ElementC + cutlass::layout::RowMajor, // LayoutC + ElementAccum, // ElementAccumulator + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + BiasGradEpilogueOutputOp, // EpilogueOutputOp + void, // ThreadblockSwizzle (not used) + // multiple preloads, dropout Zij tile, and 3 stages push us over shared + // memory capacity on A100. set a ceiling on number of stages to save + // shared memory if dropout is in use. + kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) + ? 
cutlass::const_min(2, DefaultConfig::kStages) + : DefaultConfig::kStages, // Stages + false, // SplitKSerial + typename GemmType::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using Mma = typename MakeCustomMma::Mma; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + ElementAccum, + kWarpSize>::Iterator; + + // epilogue used to write bias gradient, which is just the output of this + // matmul with some operations applied to the fragment + using BiasGradEpilogue = typename DefaultGemm::Epilogue; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename DefaultGemm::Mma::Operator::IteratorC, + typename DefaultGemm::Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradQ { + // grad_q <- tmp @ k_j + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + struct MatmulGradK { + // grad_k <- tmp.transpose(-2, -1) @ q_i + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmemN = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using DefaultMmaFromSmemT = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK + WarpIteratorA, + false, // kScaleOperandA + kPreload>; // kTransposeA + using DefaultMmaFromSmem = typename cutlass::platform::conditional< + DefaultMmaFromSmemT::kIsTransposedA, + DefaultMmaFromSmemT, + DefaultMmaFromSmemN>::type; + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + // NOTE: nvcc 12.4 has correctness errors with this on M60 (sm52) + // when there is an attention bias. Let's just disable it for now. 
+ static constexpr auto kMinSm = ArchTag::kMinComputeCapability; + static constexpr bool kEnableSplitKeys = kMinSm >= 70; + + static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys || + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradK = !kOutputInRF && + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradV = !kOutputInRF && + !cutlass::platform::is_same::value; + + struct GradQTempStorage { + int32_t lock; + int32_t counter; + int32_t pad[2]; // pad to 128bits + output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored]; + }; + + struct Params { + // Input tensors + const scalar_t* query_ptr = nullptr; // [Mq, nH, K] + const scalar_t* key_ptr = nullptr; // [Mk, nH, K] + const scalar_t* value_ptr = nullptr; // [Mk, nH, Kv] + const scalar_t* bias_ptr = nullptr; + const lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq] + const scalar_t* output_ptr = nullptr; // [Mq, nH, Kv] + const scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv] + accum_t* delta_ptr = nullptr; // [nH, Mq] + const int32_t* cu_seqlens_q_ptr = nullptr; + const int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* grad_query_ptr = nullptr; // [Mq, nH, K] + output_t* grad_key_ptr = nullptr; // [Mk, nH, K] + output_t* grad_value_ptr = nullptr; // [Mk, nH, Kv] + output_t* grad_bias_ptr = nullptr; + + // Accumulators + output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] + output_accum_t* workspace_gv = + nullptr; // (will be calculated by the kernel) + GradQTempStorage* workspace_gq = + nullptr; // (will be calculated by the kernel) + + // Sliding window. 
ignored if == 0 + int32_t window_size = 0; + + // Scale + accum_t scale = 1.0f; + + // Dimensions/strides + int32_t head_dim = -1; + int32_t head_dim_value = -1; + int32_t num_queries = -1; + int32_t num_keys = -1; + int32_t num_heads = -1; + uint8_t custom_mask_type = NoCustomMask; + + int64_t q_strideM = -1; + int64_t k_strideM = -1; + int64_t v_strideM = -1; + int64_t bias_strideM = 0; + int64_t gO_strideM = -1; + int64_t gB_strideM = -1; + int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise + + at::PhiloxCudaState rng_engine_inputs = {0, 0}; + + // RNG sequence offset based on batch_id and head_id + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + + CUTLASS_HOST_DEVICE int64_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int64_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int64_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int64_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH = -1; + int32_t q_strideH = -1; + int32_t k_strideH = -1; + int32_t v_strideH = -1; + int64_t bias_strideH = 0; + int64_t o_strideB = -1; + int64_t q_strideB = -1; + int64_t k_strideB = -1; + int64_t v_strideB = -1; + int64_t bias_strideB = 0; + int64_t lse_strideB = -1; + int64_t lse_strideH = -1; + int64_t delta_strideB = -1; + int64_t delta_strideH = -1; + int32_t num_batches = -1; + int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel + + int64_t gO_strideB = 0; + int64_t gQ_strideB = 0; + int64_t gK_strideB = 0; + int64_t gV_strideB = 0; + int64_t gB_strideB = 0; + int64_t gO_strideH = 0; + int64_t gQ_strideH = 0; + int64_t gK_strideH = 0; + int64_t gV_strideH = 0; + int64_t gB_strideH = 0; + + CUTLASS_HOST_DEVICE 
int16_t num_splits_key_device() const { +#ifdef __CUDA_ARCH__ + return kEnableSplitKeys ? gridDim.x : 1; +#else + return num_splits_key; // for host-side tests +#endif + } + CUTLASS_HOST_DEVICE int16_t split_key_device() const { +#ifdef __CUDA_ARCH__ + return kEnableSplitKeys ? blockIdx.x : 0; +#else + return 0; // for host-side tests +#endif + } + + CUTLASS_DEVICE bool advance_to_block() { + int64_t batch_id = blockIdx.z; + int32_t head_id = blockIdx.y; + + if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { + assert(workspace_size() == 0 || workspace != nullptr); + + workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); + workspace = warp_uniform(workspace); + workspace_gv = workspace + workspace_elements_gk(); + workspace_gq = + (GradQTempStorage*)(workspace_gv + workspace_elements_gv()); + if (kEnableSplitKeys) { + workspace_gv += workspace_elements_gv() * split_key_device() / + num_splits_key_device(); + workspace += workspace_elements_gk() * split_key_device() / + num_splits_key_device(); + } + } else { + workspace = nullptr; + } + + // Advance pointers that depend on the total concatenated + // number of queries, as `num_queries` is modified in the block + // below + dropout_batch_head_rng_offset = + batch_id * (num_heads * num_queries * num_keys) + + head_id * (num_queries * num_keys); + logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; + + if (cu_seqlens_q_ptr != nullptr) { + assert(cu_seqlens_k_ptr != nullptr); + cu_seqlens_q_ptr += batch_id; + cu_seqlens_k_ptr += batch_id; + int32_t q_start = cu_seqlens_q_ptr[0]; + int32_t k_start = cu_seqlens_k_ptr[0]; + int64_t q_next_start = cu_seqlens_q_ptr[1]; + int64_t k_next_start = cu_seqlens_k_ptr[1]; + assert(q_next_start - q_start <= num_queries); + assert(k_next_start - k_start <= num_keys); + num_queries = q_next_start - q_start; + num_keys = k_next_start - k_start; + + // Jump manually + batch_id = 0; + + query_ptr += q_start * q_strideM; + key_ptr += k_start * 
k_strideM; + value_ptr += k_start * v_strideM; + assert(bias_ptr == nullptr); + assert(grad_bias_ptr == nullptr); + output_ptr += q_start * o_strideM(); + grad_output_ptr += q_start * gO_strideM; + delta_ptr += q_start; + + grad_query_ptr += q_start * gQ_strideM(); + grad_key_ptr += k_start * gK_strideM(); + grad_value_ptr += k_start * gV_strideM(); + } + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + if (bias_ptr != nullptr) { + bias_ptr += batch_id * bias_strideB + head_id * bias_strideH; + } + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + if (grad_bias_ptr != nullptr) { + grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH; + } + + // Some values are modified above + // Signal to the compiler that they are the same in all threads + // and can be stored in warp-uniform registers (Sm75+) + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + custom_mask_type = warp_uniform(custom_mask_type); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + bias_ptr = warp_uniform(bias_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); + grad_bias_ptr = warp_uniform(grad_bias_ptr); + +#if 0 + PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f 
LSE:%f", + int(blockIdx.z), int(blockIdx.y), + float(delta_ptr[0]), + float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]), + float(logsumexp_ptr[0]) + ) +#endif + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3(num_splits_key, num_heads, num_batches); + } + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const { + if (!kNeedsAccumGradK) { + return 0; + } + return num_splits_key * kBlockSizeJ * + align_up(head_dim, kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { + if (!kNeedsAccumGradV) { + return 0; + } + return num_splits_key * kBlockSizeJ * + align_up(head_dim_value, kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { + if (!kNeedsAccumGradQ) { + return 0; + } + int num_blocks = ceil_div(num_queries, kBlockSizeI); + int num_cols = ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN); + return num_blocks * num_cols * sizeof(GradQTempStorage) / + sizeof(output_accum_t); + } + CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const { + // Aligned on 128bits + return align_up( + workspace_elements_gk() + workspace_elements_gv() + + workspace_elements_gq(), + int64_t(4)); + } + CUTLASS_HOST_DEVICE int64_t workspace_size() const { + // Returns size of buffer we need to run this kernel + return num_batches * num_heads * workspace_strideBH() * sizeof(float); + } + CUTLASS_HOST_DEVICE bool should_zero_workspace() const { + return num_splits_key > 1 || window_size > 0; + } + }; + + // shared storage for keeping Zij matrix. not needed if we aren't using + // dropout, in which case we use an empty array to save shared memory + using ZijSharedStorage = typename cutlass::platform::conditional< + kApplyDropout, + typename MatmulQK::AccumulatorSharedStorage, + // dummy shared storage object that takes up no space. 
+ typename cutlass::gemm::threadblock::AccumulatorSharedStorage< +#ifdef _WIN32 + // windows builds throw the error: + // "type containing an unknown-size array is not allowed" + // if we try to make Zij shared storage zero-sized. + // To get around this just make it sized 1 on windows. + typename cutlass::gemm::GemmShape<1, 1, 0>, +#else + typename cutlass::gemm::GemmShape<0, 0, 0>, +#endif + typename MatmulQK::AccumulatorSharedStorage::Element, + typename MatmulQK::AccumulatorSharedStorage::Layout, + typename cutlass::MatrixShape<0, 0>>>::type; + + struct SharedStoragePrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + } persistent; + union { + struct { + // part1 - after Q.K / dV / dO.V + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 4. store Pij. it is needed: + // - in dVj += (Pij.T * Zij) @ dOi + // - in dSij = Pij * (dPij - Di) + // 6. dVj += (Pij.T * Zij) @ dOi + // 10. write to fragment + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi + ZijSharedStorage zij; + + union { + // 2. prologue for dVj + // 6. workspace for dVj += (Pij.T * Zij) @ dOi + typename MatmulGradV::Mma::SharedStorage mm_gradV; + // 7. dVj epilogue + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + + // 3. prologue for dPij_dropped + // 8. 
used in dPij_dropped = dOi @ Vj.T + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + } part1; + + struct { + // part2 - dQ + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload) + union { + // store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + }; + + } part2; + + struct { + // part3 - after last iteration on dQ's epilogue / dK + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + } part3; + + struct { + // part4 - after last iteration on dK's epilogue / preload next K.Q_t + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + + // If we reach end of current key, dump RF->gmem with "final" epilogues + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part4; + }; + static void print_size() { + // Field size +#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f))) + + printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue))); + printf(" persistent: %db\n", FSZ(persistent)); + printf(" mm_qk_k: %db\n", FSZ(persistent.mm_qk_k)); + printf(" part1: %db\n", FSZ(part1)); + printf(" bias: %db\n", FSZ(part1.bias)); + printf(" attn_shared_storage: %db\n", FSZ(part1.attn_shared_storage)); + printf(" zij: %db\n", FSZ(part1.zij)); + printf(" mm_gradV: 
%db\n", FSZ(part1.mm_gradV)); + printf(" gradV_epilogue: %db\n", FSZ(part1.gradV_epilogue)); + printf(" mm_doivj: %db\n", FSZ(part1.mm_doivj)); + printf(" part2: %db\n", FSZ(part2)); + printf(" tmpT_shared_storage: %db\n", FSZ(part2.tmpT_shared_storage)); + printf(" tmp_shared_storage: %db\n", FSZ(part2.tmp_shared_storage)); + printf(" mm_gradK: %db\n", FSZ(part2.mm_gradK)); + printf(" mm_gradQ: %db\n", FSZ(part2.mm_gradQ)); + printf(" gradB_epilogue: %db\n", FSZ(part2.gradB_epilogue)); + printf(" gradQ_epilogue: %db\n", FSZ(part2.gradQ_epilogue)); + printf(" part3: %db\n", FSZ(part3)); + printf(" tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage)); + printf(" part4: %db\n", FSZ(part4)); + printf(" mm_qk_q: %db\n", FSZ(part4.mm_qk_q)); + printf( + " gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final)); + printf( + " gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + FIELD(persistent, di) + FIELD(persistent, mm_qk_k) + FIELD(part1, bias) + FIELD(part1, attn_shared_storage) + FIELD(part1, zij) + FIELD(part1, mm_gradV) + FIELD(part1, gradV_epilogue) + FIELD(part1, mm_doivj) + FIELD(part2, mm_gradK) + FIELD(part2, mm_gradQ) + FIELD(part2, gradB_epilogue) + FIELD(part2, gradQ_epilogue) + FIELD(part2, tmp_shared_storage) + FIELD(part3, tmpT_shared_storage) + FIELD(part3, gradQ_epilogue_lastIter) + FIELD(part3, gradK_epilogue) + FIELD(part4, mm_qk_q) + FIELD(part4, gradK_epilogue_final) + FIELD(part4, gradV_epilogue_final) + }; + + struct SharedStorageNoPrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + } persistent; + union { + struct { + // part1 - Q.K matmul + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + } part1; + + struct { + // part2 - compute gradV + union { + // 1. 
efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 2. store Pij to shared memory. it is needed: + // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi + // - in next step where it is used in dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 3. store Zij. it is needed in this step, where it is used + // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are + // loaded for the computation of dVj. + ZijSharedStorage zij; + + union { + typename MatmulGradV::Mma::SharedStorage mm_gradV; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + } part2; + + struct { + // part3 - DO.V matmul + union { + // first compute dPij = (dOi @ Vj.T) * Zij + // and dSij = Pij * (dPij - Di) + struct { + // (from part2) - Pij for computing dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + // matmul to compute dOiVj + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + }; + // then store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + }; + } part3; + + struct { + // part4 - compute gradQ + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + }; + } part4; + + struct { + // part5 - compute gradK + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradK::Mma::SharedStorage mm_gradK; + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + }; + } part5; + + struct { + // part6 
- store RF accumulated into gmem + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part6; + }; + static void print_size() { +#define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f))) + printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue))); + printf(" persistent: %db\n", FIELD_SIZEOF(persistent)); + printf(" part1: %db\n", FIELD_SIZEOF(part1)); + printf(" part2: %db\n", FIELD_SIZEOF(part2)); + printf(" part3: %db\n", FIELD_SIZEOF(part3)); + printf(" part4: %db\n", FIELD_SIZEOF(part4)); + printf(" part5: %db\n", FIELD_SIZEOF(part5)); + printf(" part6: %db\n", FIELD_SIZEOF(part6)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + FIELD(persistent, di) + FIELD(part1, mm_qk_k) + FIELD(part1, mm_qk_q) + FIELD(part2, bias) + FIELD(part2, attn_shared_storage) + FIELD(part2, zij) + FIELD(part2, mm_gradV) + FIELD(part2, gradV_epilogue) + FIELD(part3, mm_doivj) + FIELD(part3, gradB_epilogue) + FIELD(part4, tmpT_shared_storage) + FIELD(part4, tmp_shared_storage) + FIELD(part4, mm_gradQ) + FIELD(part4, gradQ_epilogue) + FIELD(part4, gradQ_epilogue_lastIter) + FIELD(part5, mm_gradK) + FIELD(part5, gradK_epilogue) + FIELD(part6, gradK_epilogue_final) + FIELD(part6, gradV_epilogue_final) + }; + + using SharedStorage = typename cutlass::platform::conditional< + kPreload, + SharedStoragePrologue, + SharedStorageNoPrologue>::type; + + struct OutputFragments { + typename MatmulGradV::Mma::FragmentC gradV; + typename MatmulGradK::Mma::FragmentC gradK; + + CUTLASS_DEVICE void clear() { + gradV.clear(); + gradK.clear(); + } + }; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment); + 
CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment); + TORCH_CHECK( + p.num_heads <= 1 || p.lse_strideH % 8 == 0, + "LSE is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_batches <= 1 || p.lse_strideB % 8 == 0, + "LSE is not correctly aligned (strideB)"); + TORCH_CHECK( + p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, + "query is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0, + "key is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0, + "value is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, + "query is not correctly aligned (strideB)"); + TORCH_CHECK( + p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, + "key is not correctly aligned (strideB)"); + TORCH_CHECK( + p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0, + "value is not correctly aligned (strideB)"); + TORCH_CHECK( + p.q_strideM % kMinimumAlignment == 0, + "query is not correctly aligned (strideM)"); + TORCH_CHECK( + p.k_strideM % kMinimumAlignment == 0, + "key is not correctly aligned (strideM)"); + TORCH_CHECK( + p.v_strideM % kMinimumAlignment == 0, + "value is not correctly aligned (strideM)"); + if (p.bias_ptr) { + TORCH_CHECK( + p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideB). ", + "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a " + "multiple of ", kMinimumAlignment, "."); + TORCH_CHECK( + p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideH) ." 
+ "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a " + "multiple of ", kMinimumAlignment, "."); + TORCH_CHECK( + p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideM). " + "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ", + "multiple of ", kMinimumAlignment, "."); + } + if (p.grad_bias_ptr) { + TORCH_CHECK( + p.num_batches <= 1 || p.gB_strideB % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideB)"); + TORCH_CHECK( + p.num_heads <= 1 || p.gB_strideH % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideH)"); + TORCH_CHECK( + p.gB_strideM % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideM)"); + } + TORCH_CHECK( + !(p.cu_seqlens_q_ptr && p.bias_ptr), + "CuSeqlen + bias not implemented yet"); + TORCH_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "Invalid value for `custom_mask_type`"); + TORCH_CHECK( + p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f, + "Invalid value for `dropout_prob`"); + TORCH_CHECK( + kApplyDropout || p.dropout_prob == 0.0f, + "Set `kApplyDropout`=True to support `dropout_prob > 0`"); + TORCH_CHECK(p.head_dim > 0, "Invalid value for `head_dim`"); + TORCH_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`"); + TORCH_CHECK(p.num_queries > 0, "Invalid value for `num_queries`"); + TORCH_CHECK(p.num_keys > 0, "Invalid value for `num_keys`"); + TORCH_CHECK(p.num_heads > 0, "Invalid value for `num_heads`"); + TORCH_CHECK(p.num_batches > 0, "Invalid value for `num_batches`"); + TORCH_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); + TORCH_CHECK( + p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); + if (kKeysQueriesAlignedToBlockSize) { + TORCH_CHECK( + p.cu_seqlens_k_ptr == nullptr, + "This kernel does not support cu_seqlen"); + TORCH_CHECK( + p.cu_seqlens_q_ptr == nullptr, + "This kernel does not support cu_seqlen"); + TORCH_CHECK( + 
p.num_queries % kBlockSizeI == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + TORCH_CHECK( + p.num_keys % kBlockSizeJ == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + } + TORCH_CHECK( + kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); + TORCH_CHECK( + p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); + TORCH_CHECK( + p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), + "Invalid `num_splits_key` (", + p.num_splits_key, + ") - too large for `num_keys` = ", + p.num_keys); + if (p.window_size != 0) { + TORCH_CHECK( + p.custom_mask_type != NoCustomMask, + "LocalAttention only supported in causal mode"); + } + return true; + } + + static CUTLASS_DEVICE void attention_kernel(Params p) { + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + + uint16_t thread_id = threadIdx.x; + uint8_t warp_id = warp_uniform(thread_id / 32); + uint8_t lane_id = thread_id % 32; + + int64_t key_start = p.split_key_device() * kBlockSizeJ; + if (key_start >= p.num_keys) { + return; + } + if (kPrologueQK) { + int64_t query_start = getQueryStart(p, key_start); + prologueQkNextIteration( + shared_storage, p, query_start, key_start, warp_id, lane_id); + } + + // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr` + if (kKernelComputesDelta) { + constexpr int kOptimalElements = + 128 / cutlass::sizeof_bits::value; + if (p.head_dim_value % kOptimalElements == 0) { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta(p, query_start, warp_id, lane_id); + } + } else { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta<1>(p, query_start, warp_id, lane_id); + } + } + __syncthreads(); + } + + OutputFragments output_frags; + + curandStatePhilox4_32_10_t rng_state_init; + + if (kApplyDropout) { + // See Note [Seed and Offset Device] + auto const [seed, offset] = 
at::cuda::philox::unpack(p.rng_engine_inputs); + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. + curand_init( + seed, + 0, + offset + p.dropout_batch_head_rng_offset, + &rng_state_init); + } + + CUTLASS_PRAGMA_UNROLL + for (; key_start < p.num_keys; + key_start += p.num_splits_key_device() * kBlockSizeJ) { + output_frags.clear(); + + int64_t next_key = key_start; + int64_t query_start = getQueryStart(p, key_start); + while (next_key == key_start && query_start < p.num_queries) { + // This line here + // vvvvvvvvvvvvvv + warp_id = warp_uniform(warp_id); + // ^^^^^^^^^^^^^^ + // ... makes everything use less RF and be 10% faster. Why? + // I don't know. My theory is that it forces `nvcc` to + // re-compute indices, offsets etc... and not keep them + // from the previous iteration, which prevents MASSIVE + // register spilling. 
+ + processBlockIJ( + shared_storage, + output_frags, + p, + query_start, + key_start, + rng_state_init, + warp_id, + lane_id); + + int64_t next_query; + incrIteration(p, query_start, key_start, next_query, next_key); + query_start = next_query; + } + if (kOutputInRF) { + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); + } else if (getQueryStart(p, key_start) >= p.num_queries) { + zfillGradKV( + p, key_start, warp_id, lane_id); + } + __syncthreads(); + } + } + + template + static CUTLASS_DEVICE void zfillGradKV( + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + constexpr int kThreadsPerKey = 8; + constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; + static_assert(kBlockSizeJ % kParallelKeys == 0, ""); + // This function is not really optimized, but should rarely be used + // It's only used when some keys are "useless" and don't attend to + // any query, due to causal masking + + int thread_id = 32 * warp_id + lane_id; + int k_shift = lane_id % kThreadsPerKey; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) { + int key = key_start + j + (thread_id / kThreadsPerKey); + if (!skipBoundsChecks && key >= p.num_keys) { + continue; + } + auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM(); + auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM(); + + for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) { + gv_ptr[k] = scalar_t(0); + } + for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { + gk_ptr[k] = scalar_t(0); + } + } + } + + template + static CUTLASS_DEVICE void processBlockIJ( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params& p, + int64_t query_start, + int64_t key_start, + const curandStatePhilox4_32_10_t& curand_state_init, + uint8_t warp_id, + uint8_t lane_id) { + cutlass::Array + dropout_keep_mask_doivj; + dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1}); + const float dropout_scale = + kApplyDropout ? 
1.0 / (1.0 - p.dropout_prob) : 1.0f; + + cutlass::MatrixCoord no_offset{0, 0}; + accum_t scale = p.scale; + int16_t thread_id = 32 * warp_id + lane_id; + + auto rematerializeThreadIds = [&]() { + // Prevents `nvcc` from keeping values deduced from + // `thread_id`, `warp_id`, ... in RF - to reduce register pressure + warp_id = warp_uniform(thread_id / 32); + lane_id = thread_id % 32; + thread_id = 32 * warp_id + lane_id; + }; + + bool isFirstQuery = (query_start == getQueryStart(p, key_start)); + int64_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + bool isLastQuery = next_key != key_start; + + accum_t di_rf = accum_t(0); + if (thread_id < kBlockSizeI) { + if (query_start + thread_id < p.num_queries) { + di_rf = p.delta_ptr[query_start + thread_id]; + } + shared_storage.di()[thread_id] = di_rf; + } + + int32_t num_queries_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kN + : warp_uniform(cutlass::fast_min( + MatmulQK::Mma::Shape::kN, (int32_t)(p.num_queries - query_start))); + int32_t num_keys_in_block = skipBoundsChecks + ? 
MatmulQK::Mma::Shape::kM + : warp_uniform(cutlass::fast_min( + MatmulQK::Mma::Shape::kM, (int32_t)(p.num_keys - key_start))); + + auto prologueGradV = [&](int64_t col) { + typename MatmulGradV::Mma::IteratorB iterator_dO( + {int32_t(p.gO_strideM)}, + const_cast(p.grad_output_ptr + query_start * p.gO_strideM + col), + {num_queries_in_block, (int32_t)(p.head_dim_value - col)}, + thread_id, + no_offset); + MatmulGradV::Mma::prologue( + shared_storage.mm_gradV(), + iterator_dO, + thread_id, + num_queries_in_block); + }; + auto prologueGradQ = [&](int col) { + typename MatmulGradQ::Mma::IteratorB iterator_K( + {int32_t(p.k_strideM)}, + const_cast(p.key_ptr + key_start * p.k_strideM + col), + {num_keys_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradQ::Mma::prologue( + shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block); + }; + auto prologueGradK = [&](int col) { + typename MatmulGradK::Mma::IteratorB iterator_Q( + {int32_t(p.q_strideM)}, + const_cast(p.query_ptr + query_start * p.q_strideM + col), + {num_queries_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradK::Mma::prologue( + shared_storage.mm_gradK(), + iterator_Q, + thread_id, + num_queries_in_block); + }; + auto prologueDOV = [&]() { + typename MatmulDOIVJ::Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + const_cast(p.grad_output_ptr + query_start * p.gO_strideM), + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + typename MatmulDOIVJ::Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + const_cast(p.value_ptr + key_start * p.v_strideM), + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + MatmulDOIVJ::Mma::prologue( + shared_storage.mm_doivj(), + iterator_A, + iterator_B, + thread_id, + p.head_dim_value); + }; + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulQK + 
///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulQK::Mma; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + num_queries_in_block, + p.head_dim // k + ); + + // k_j + typename Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + const_cast(p.key_ptr + key_start * p.k_strideM), + {problem_size.m(), problem_size.k()}, + thread_id, + no_offset); + + // q_i.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + const_cast(p.query_ptr + query_start * p.q_strideM), + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + Mma mma( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma.set_prologue_done(kPrologueQK); + mma.set_zero_outside_bounds(!skipBoundsChecks); + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + accum = cutlass::multiplies()(scale, accum); + + // Epilogue: add LSE + exp and store that to our shared memory buffer + // shmem <- (matmul_result - + // logsumexp[i_start:i_end].unsqueeze(1)).exp() + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + + // apply bias if applicable + if (p.bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MatmulQK::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + const_cast(p.bias_ptr + query_start * p.bias_strideM + key_start), + {num_queries_in_block, num_keys_in_block}, + thread_id); + cutlass::TensorRef bias_tensor_ref( + shared_storage.bias().data(), + 
cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); + typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id); + MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, where Pij is in register fragment and Bij is in shmem + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_n) {}, + [&](int accum_m, int accum_n, int idx) { + // remember we are transposed + accum[idx] += bias_tensor_ref.at({accum_n, accum_m}); + }, + [&](int accum_n) {}); + } + + // Apply mask + if (p.custom_mask_type == CausalFromTopLeft || + p.custom_mask_type == CausalFromBottomRight) { + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + int shift = query_start - key_start; + if (p.custom_mask_type == CausalFromBottomRight) { + shift += p.num_keys - p.num_queries; + } + // current_key = key_start + accum_m + // current_query = query_start + accum_n + // mask if: `current_key > current_query` + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m > accum_n + shift) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + if (p.window_size > 0) { + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + int shift = query_start - key_start - p.window_size; + // current_key = key_start + accum_m + // current_query = query_start + accum_n + // mask if: `current_key < current_query - window_size` + // if accum_m < accum_n + query_start - window_size - key_start + + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m <= accum_n + shift) { + accum[idx] = + 
-cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + __syncthreads(); + if (kPrologueGV) { + prologueGradV(0); + } + if (kPrologueDOV) { + prologueDOV(); + } + + MatmulQK::B2bGemm::accumApplyLSEToSmem( + shared_storage.attn_shared_storage(), + accum, + p.logsumexp_ptr + query_start, + problem_size.n(), + thread_id, + warp_id, + lane_id, + output_tile_coords); +#if 0 + auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref(); + PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT); +#endif + + // if we are using dropout, compute Zij, writing it to shared memory. + // each element of Zij is: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + if (kApplyDropout) { + auto zij = shared_storage.zij().accum_ref(); + // each thread generates a contiguous sequence of elements in Zij, all + // in the same row. the reason they have to come from the same row is + // that sampling random numbers from a contiguous random number sequence + // is much more efficient than jumping around, and the linear offset of + // each element of Z (the global matrix) maps to an offset in a random + // number sequence. for Z, the end of a row and the beginning of the + // next have adjacent offsets, but for Zij (tile of global matrix), this + // is not necessarily the case. 
+ // We must fill the entire `zij` shmem with values (even out of bounds + // on the K-dimension) otherwise we can get NaNs during the GEMM + const int kQueriesPerBlock = kBlockSizeI; + const int threads_per_row = cutlass::fast_min( + kNumThreads / kQueriesPerBlock, (int64_t)num_keys_in_block); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(num_keys_in_block, threads_per_row), 4); + + const int thread_i = thread_id / threads_per_row; + const int thread_start_j = + (thread_id % threads_per_row) * elts_per_thread; + + if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + (query_start + thread_i) * p.num_keys + + (key_start + thread_start_j), + &curand_state); + + // generate elements of Zij, 4 elements at a time + for (int zij_start_col_idx = thread_start_j; zij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + num_keys_in_block); + zij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + // we'll write Zij transposed since attention is also transposed + // during the matmul to compute dV. + zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) = + (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob + ? 
scalar_t(dropout_scale) + : scalar_t(0); + } + } + } + __syncthreads(); +#if 0 + PRINT_TENSOR4x4_T0_L0("zij", zij); + PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4); +#endif + + // Save mask for later DOIVJ matmul + + int warp_idx_mn_0 = warp_id % + (MatmulDOIVJ::Mma::Base::WarpCount::kM * + MatmulDOIVJ::Mma::Base::WarpCount::kN); + auto output_tile_coords_doivj = cutlass::MatrixCoord{ + warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM}; + auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords_doivj); + MatmulDOIVJ::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m /*q*/, int accum_n /*k*/, int idx) { + if (zij.at({accum_n, accum_m}) == scalar_t(0)) { + dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0}; + } + }, + [&](int accum_m) {}); + } + __syncthreads(); + } + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradV matmul + // + // grad_v[j_start:j_end] += attn_T @ do_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + constexpr bool kSingleIterationGradV = + kMaxK <= MatmulGradV::ThreadblockShape::kN; + for (int32_t col = 0; col < (kSingleIterationGradV ? 
1 : p.head_dim_value); + col += MatmulGradV::ThreadblockShape::kN) { + using Mma = typename MatmulGradV::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, p.head_dim_value - col, num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradV::OutputTileIterator( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, + {num_keys_in_block, p.head_dim_value - col}, + thread_id); + }; + typename Mma::IteratorB iterator_B( + {int32_t(p.gO_strideM)}, + const_cast(p.grad_output_ptr + query_start * p.gO_strideM + col), + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + + // if dropout: dVj += (Pij.T * Zij) @ dOi + // otherwise: dVj += Pij.T @ dOi + Mma mma( + // operand A: Pij.T + shared_storage.attn_shared_storage().accum_ref(), + // operand A_scale Zij.T: + // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T + // which is computed on the fly as fragments of Pij.T are loaded in + shared_storage.zij().accum_ref(), + // operand B: dOi - which was loaded into shared memory previously + // when we computed dVj + shared_storage.mm_gradV().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradV::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace_gv + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradV) { + output_frags.gradV.clear(); + } else { + gmem_tile.load(output_frags.gradV, thread_id); + } + } + mma.set_prologue_done(kPrologueGV); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradV, + iterator_B, + output_frags.gradV); + __syncthreads(); + if (kPrologueGV && !kSingleIterationGradV && + col + 
MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { + prologueGradV(col + MatmulGradV::ThreadblockShape::kN); + } + + if (!kOutputInRF) { + if (kNeedsAccumGradV && !isLastQuery) { + gmem_tile.store(output_frags.gradV, thread_id); + } else { + accumulateInGmem( + shared_storage.gradV_epilogue(), + output_frags.gradV, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradV, + warp_id, + lane_id); + } + } + } + __syncthreads(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulDOIVJ + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulDOIVJ::Mma; + // do_i + typename Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + const_cast(p.grad_output_ptr + query_start * p.gO_strideM), + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + + // v_j.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + const_cast(p.value_ptr + key_start * p.v_strideM), + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + + Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id); + mma.set_prologue_done(kPrologueDOV); + mma.set_zero_outside_bounds(!skipBoundsChecks); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + if (kPrologueGQ) { + prologueGradQ(0); + } + if (kPrologueGK) { + prologueGradK(0); + } + + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + // TODO: This must be terribly inefficient. 
There must be a better way + // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem] + // attn_shared_storage [smem] <- tmp.T + // tmp_shared_storage [smem] <- tmp + { + using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator; + auto lane_offset = LambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + // if dropout was used, compute dPij = dPij_dropped * Zij + if (kApplyDropout) { + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (dropout_keep_mask_doivj[idx].get()) { + accum[idx] *= dropout_scale; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) {}); + } + + auto attn_T = shared_storage.attn_shared_storage().accum_ref(); +#if 0 + PRINT_B0_T0("doivj_dropped"); + print_warp_accum(accum, lane_offset, 4, 4); + PRINT_TENSOR4x4_T0_L0("attn_T", attn_T) +#endif + accum_t current_di; + // dSij = (dPij - Di) * Pij + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + // TODO: Otherwise we can get nans as we + // might have infs here (only seen on f16 tho) + if (skipBoundsChecks || + (accum_m < num_queries_in_block && + accum_n < num_keys_in_block)) { + accum_t attn = attn_T.at({accum_n, accum_m}); + accum[idx] = (accum[idx] - current_di) * attn; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) { + + }); + + // store bias gradient tile dBij to global memory, + // where dBij = dSij = Pij * (dPij - Di) + if (p.grad_bias_ptr != nullptr) { + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator + output_iter( + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator:: + Params{p.gB_strideM}, + // grad_bias_ptr is offset to point at beginning of + // matrix of shape (queries, keys) for a given + // (batch_id, head_id) the pointer arithmetic here produces + // a pointer to the start of the current tile within that + // matrix + p.grad_bias_ptr + query_start 
* p.gB_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + + // no-op epilogue operator - just casting and storing contents of + // accum to global memory + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1}); + typename MatmulDOIVJ::BiasGradEpilogue epilogue( + shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id); + epilogue(output_op, output_iter, accum, output_iter); + } + + accum = accum * scale; + +#if 0 + PRINT_B0_T0("(doivj - di) * attn * scale"); + print_warp_accum(accum, lane_offset, 4, 4); +#endif + + __syncthreads(); + if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { + auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); + // attn <- attn_T.T + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]); + }, + [&](int accum_m) {}); + } + } + + MatmulDOIVJ::B2bGemm::accumToSmem( + shared_storage.tmp_shared_storage(), + accum, + lane_id, + output_tile_coords); + __syncthreads(); + } + // Force `nvcc` to recompute values that depend on the variables just below + // to use less RF and prevent some spilling + p.head_dim = warp_uniform(p.head_dim); + p.k_strideM = warp_uniform(p.k_strideM); + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradQ matmul + // + // grad_q[i_start:i_end] += tmp @ k_j + ///////////////////////////////////////////////////////////////////////////////////////////////// + // Skip the loop & associated branches if we know at compile time the number + // of iterations + constexpr bool kSingleIterationGradQ = + kMaxK <= MatmulGradQ::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradQ ? 
1 : p.head_dim); + col += MatmulGradQ::ThreadblockShape::kN) { + using Mma = typename MatmulGradQ::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_queries_in_block, + false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col, + num_keys_in_block); + + // k_j + typename Mma::IteratorB iterator_B( + {int32_t(p.k_strideM)}, + const_cast(p.key_ptr + key_start * p.k_strideM + col), + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto a = shared_storage.tmp_shared_storage().accum_ref(); + Mma mma( + // operand A: dSij + shared_storage.tmp_shared_storage().accum_ref(), + // operand B: Kj + shared_storage.mm_gradQ().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + int col_id = col / MatmulGradQ::ThreadblockShape::kN; + int num_cols = kSingleIterationGradQ + ? 1 + : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); + int storage_id = (col_id + query_start / kBlockSizeI * num_cols); + + if (p.num_splits_key_device() > 1) { + AtomicLock::acquire( + &p.workspace_gq[storage_id].lock, + p.split_key_device() + 1, + thread_id); + // Make sure we can see other block's output + __threadfence(); + } + + AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]}; + if (!kNeedsAccumGradQ || + (p.num_splits_key_device() == 1 && key_start == 0)) { + // if we know we are the first to access it, we know it's only zeros. 
+ // Avoids a load from gmem (and gmem init as well) + accum.clear(); + } else { + gmem_tile.load(accum, thread_id); + } + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + mma.set_prologue_done(kPrologueGQ); + mma(gemm_k_iterations, accum, iterator_B, accum); + __syncthreads(); + bool isLastColumn = kSingleIterationGradQ || + (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); + if (kPrologueGQ && !isLastColumn) { + prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); + } + + bool isLast = [&]() { + int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + if (p.num_keys <= next_key) { + return true; + } + if (query_start < getSmallestQueryForKey(p, next_key)) { + return true; + } + return false; + }(); + // Output results + if (p.num_splits_key_device() > 1) { + int32_t numAddsSoFar = -1; + if (isLast && thread_id == 0) { + numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) + + 1; // `atomicAdd` returns the old value + } + isLast = __syncthreads_or( + numAddsSoFar == getNumParallelBlocksForQuery(p, query_start)); + assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start)); + } + if (kNeedsAccumGradQ && !isLast) { + gmem_tile.store(accum, thread_id); + if (p.num_splits_key_device() > 1) { + // Make sure everyone wrote before we release the lock + __threadfence(); + __syncthreads(); + AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id); + } + } else { + // NOTE: We're not releasing the lock because no one is expected + // to come after us (we're the last one to write) + typename MatmulGradQ::OutputTileIterator output_it( + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, + {problem_size.m(), problem_size.n()}, + thread_id); + // if `direct_store` is True, we store to gmem (`*gmem = accum`) + // otherwise, we accumulate in 
gmem (`*gmem = *gmem + accum`) + // If we know ahead of time when we will write for the first time + // we can: + // (1) Avoid an additional memory read + // (2) Avoid the cost of initializing memory to 0 + bool direct_store = kNeedsAccumGradQ || key_start == 0 || + (p.num_splits_key_device() > 1); + accumulateInGmem( + isLastColumn ? shared_storage.gradQ_epilogue_lastIter() + : shared_storage.gradQ_epilogue(), + accum, + output_it, + direct_store, + warp_id, + lane_id); + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradK matmul + // + // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + rematerializeThreadIds(); + + constexpr bool kSingleIterationGradK = + kMaxK <= MatmulGradK::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim); + col += MatmulGradK::ThreadblockShape::kN) { + using Mma = typename MatmulGradK::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col, + num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradK::OutputTileIterator( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, + {num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, + thread_id); + }; + + // q_i + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + const_cast(p.query_ptr + query_start * p.q_strideM + col), + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); }; + auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); }; + // this is basically: + // opA = kIsTransposedA ? 
getTmp() : getTmpT(); + bool constexpr kIsTransposedA = + MatmulGradK::DefaultMmaFromSmem::kIsTransposedA; + auto& opA = *call_conditional< + kIsTransposedA, + decltype(getTmp), + decltype(getTmpT)>::apply(getTmp, getTmpT, 0); + Mma mma( + // operand A: dSij.T + opA.accum_ref(), + // operand B: Qi + shared_storage.mm_gradK().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradK::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradK) { + output_frags.gradK.clear(); + } else { + gmem_tile.load(output_frags.gradK, thread_id); + } + } + mma.set_prologue_done(kPrologueGK); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradK, + iterator_B, + output_frags.gradK); + __syncthreads(); + bool isLastColumn = kSingleIterationGradK || + col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; + if (kPrologueGK && !isLastColumn) { + prologueGradK(col + MatmulGradK::ThreadblockShape::kN); + } + + if (kPrologueQK && isLastColumn) { + int64_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + DISPATCH_BOOL( + next_key != key_start, kForceReloadK, ([&]() { + prologueQkNextIteration( + shared_storage, p, next_query, next_key, warp_id, lane_id); + })); + } + + // Output results + if (!kOutputInRF) { + if (kNeedsAccumGradK && !isLastQuery) { + gmem_tile.store(output_frags.gradK, thread_id); + } else { + accumulateInGmem( + isLastColumn ? 
shared_storage.gradK_epilogue_final() + : shared_storage.gradK_epilogue(), + output_frags.gradK, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradK, + warp_id, + lane_id); + __syncthreads(); + } + } + } + } + + static CUTLASS_HOST_DEVICE int64_t getQueryStartShift(Params const& p) { + if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { + return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); + } + return 0; + } + + // Iteration order logic + static CUTLASS_HOST_DEVICE int64_t + getQueryStart(Params const& p, int64_t key_start) { + return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); + }; + static CUTLASS_HOST_DEVICE int64_t getQueryEnd(Params const& p) { + return align_up(p.num_queries, kBlockSizeI); + }; + + static CUTLASS_HOST_DEVICE int64_t + getSmallestQueryForKey(Params const& p, int64_t key_start) { + if (p.custom_mask_type == NoCustomMask) { + return 0; + } + int64_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + int64_t window_size = + p.window_size == 0 ? 
p.num_queries + p.num_keys : p.window_size; + + auto last_key_for_block = + cutlass::fast_min(key_start + kBlockSizeJ, (int64_t)p.num_keys) - 1; + int first_query = key_start - shift; + int last_query = last_key_for_block - shift + window_size - 1; + if (last_query < 0 || first_query >= p.num_queries) { + return getQueryEnd(p); // nothing to compute in this column + } + first_query = cutlass::fast_max(0, first_query); + return (first_query / kBlockSizeI) * kBlockSizeI; + }; + + // Returns how many kernel blocks will write to a given block in `grad_query` + // This is usually equal to the number of key splits, but can be different + // for instance in the causal case, or varying seqlen + static CUTLASS_HOST_DEVICE int32_t + getNumParallelBlocksForQuery(Params const& p, int32_t query_start) { + int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ); + if (p.custom_mask_type != NoCustomMask) { + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + int32_t last_query_for_block = + cutlass::fast_min(query_start + kBlockSizeI, p.num_queries) - 1; + int32_t last_key_for_block = + cutlass::fast_min(last_query_for_block + shift, p.num_keys - 1); + int32_t first_key_for_block = p.window_size == 0 + ? 
0 + : cutlass::fast_max(query_start - p.window_size + 1 + shift, 0); + + if (p.window_size == 0) { + num_key_blocks = last_key_for_block / kBlockSizeJ + 1; + } else { + num_key_blocks = (last_key_for_block / kBlockSizeJ) - + (first_key_for_block / kBlockSizeJ) + 1; + } + + if (last_key_for_block < 0 || first_key_for_block >= p.num_keys) { + num_key_blocks = 0; + } + } + return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks); + }; + + // Returns the next block to process + static CUTLASS_HOST_DEVICE void incrIteration( + Params const& p, + int64_t query_start, + int64_t key_start, + int64_t& next_query, + int64_t& next_key) { + next_query = query_start + kBlockSizeI; + next_key = key_start; + auto query_shift = getQueryStartShift(p); + // Wrap around + if (query_shift) { + if (next_query >= p.num_queries) { + next_query = getSmallestQueryForKey(p, key_start); + return; + } else if (query_start < query_shift && query_shift <= next_query) { + // jump to next key + } else { + return; + } + } else { + if (p.window_size > 0) { + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? 
p.num_keys - p.num_queries + : 0; + // last key that is not masked out + int last_key_for_block = + cutlass::fast_min(key_start + kBlockSizeJ, (int64_t)p.num_keys) - 1; + int last_query = last_key_for_block - shift + p.window_size - 1; + if (next_query <= last_query && next_query < p.num_queries) { + return; + } + } else if (next_query < p.num_queries) { + return; + } + // jump to next key + } + // Next key + next_key = key_start + p.num_splits_key_device() * (int64_t)kBlockSizeJ; + next_query = getQueryStart(p, next_key); + } + + template + static CUTLASS_DEVICE void prologueQkNextIteration( + SharedStorage& shared_storage, + Params const& p, + int32_t query_start, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + if (query_start >= p.num_queries || key_start >= p.num_keys) { + return; + } + + static constexpr bool kReloadK = + kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; + int thread_id = 32 * warp_id + lane_id; + typename MatmulQK::Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + const_cast(p.key_ptr + key_start * p.k_strideM), + {p.num_keys - key_start, p.head_dim}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + typename MatmulQK::Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + const_cast(p.query_ptr + query_start * p.q_strideM), + {p.head_dim, p.num_queries - query_start}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + MatmulQK::Mma::prologue( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + iterator_A, + iterator_B, + thread_id, + p.head_dim); + } + + template + static CUTLASS_DEVICE void writeFragsToGmem( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + uint16_t thread_id = 32 * warp_id + lane_id; + int32_t num_keys_in_block = skipBoundsChecks + ? 
MatmulQK::Mma::Shape::kM + : cutlass::fast_min( + MatmulQK::Mma::Shape::kM, p.num_keys - key_start); + typename MatmulGradV::OutputTileIterator outputV_it( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), + {num_keys_in_block, p.head_dim_value}, + thread_id); + accumulateInGmem( + shared_storage.gradV_epilogue_final(), + output_frags.gradV, + outputV_it, + true, + warp_id, + lane_id); + + typename MatmulGradK::OutputTileIterator outputK_it( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), + {num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, + thread_id); + accumulateInGmem( + shared_storage.gradK_epilogue_final(), + output_frags.gradK, + outputK_it, + true, + warp_id, + lane_id); + } + + template + static CUTLASS_DEVICE void accumulateInGmem( + typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, + typename MatmulT::Mma::FragmentC const& accum, + typename MatmulT::OutputTileIterator output_it, + bool first, + uint8_t warp_id, + uint8_t lane_id) { + using DefaultEpilogue = typename MatmulT::DefaultEpilogue; + using DefaultOutputOp = typename MatmulT::DefaultOutputOp; + using Mma = typename MatmulT::Mma; + int thread_id = 32 * warp_id + lane_id; + DISPATCH_BOOL( + first, kIsFirst, ([&]() { + static constexpr auto ScaleType = kIsFirst + ? 
cutlass::epilogue::thread::ScaleType::Nothing + : cutlass::epilogue::thread::ScaleType::NoBetaScaling; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + typename DefaultOutputOp::ElementOutput, + DefaultOutputOp::kCount, + typename DefaultOutputOp::ElementAccumulator, + typename DefaultOutputOp::ElementCompute, + ScaleType>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MatmulT::OutputTileIterator, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true // IterationsUnroll + >; + EpilogueOutputOp rescale({1, 1}); + Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); + epilogue(rescale, output_it, accum, output_it); + })); + } + + template + static CUTLASS_DEVICE void computeDelta( + Params const& p, + int32_t query_start, + uint8_t warp_id, + uint8_t lane_id) { + // Each thread computes one value for Delta + // Depending on warp configuration, we might have multiple + // threads of the same warp working on the same row + using AccessType = cutlass::Array; + static_assert(kNumThreads >= kBlockSizeI, ""); + static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; + int16_t thread_id = 32 * warp_id + lane_id; + + int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); + int16_t laneRow = thread_id / kNumThreadsPerLine; + bool rowPred = (query_start + laneRow) < p.num_queries; + bool pred = rowPred; + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + + laneFirstCol); 
+ const AccessType* __restrict__ output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + + laneFirstCol); + + static constexpr int64_t kMaxIters = + kMaxK / (kElementsPerAccess * kNumThreadsPerLine); + constexpr int kPipelineStages = 2; + accum_t delta_value = accum_t(0); + using GlobalLoad = + cutlass::arch::global_load; + AccessType frag_grad_output[kPipelineStages]; + AccessType frag_output[kPipelineStages]; + + auto loadAndIncrement = [&](int ld_pos, bool is_valid) { + frag_grad_output[ld_pos].clear(); + frag_output[ld_pos].clear(); + GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid); + GlobalLoad(frag_output[ld_pos], output_ptr, is_valid); + grad_output_ptr += kNumThreadsPerLine; + output_ptr += kNumThreadsPerLine; + }; + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kPipelineStages - 1; ++iter) { + int ld_pos = iter % kPipelineStages; + pred = pred && + (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) < + p.head_dim_value; + loadAndIncrement(ld_pos, pred); + } + auto columnIteration = [&](int iter) { + // Load for next iter + int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages; + pred = pred && + (laneFirstCol + + (iter + kPipelineStages - 1) * kElementsPerAccess * + kNumThreadsPerLine) < p.head_dim_value; + loadAndIncrement(ld_pos, pred); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < AccessType::kElements; ++i) { + delta_value += accum_t(frag_output[iter % kPipelineStages][i]) * + accum_t(frag_grad_output[iter % kPipelineStages][i]); + } + }; + + // If we have a small lower-bound for K, we can unroll the loop + if (kMaxK <= 256) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kMaxIters; ++iter) { + columnIteration(iter); + } + } else { + int num_iters = + ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) * + (kElementsPerAccess * kNumThreadsPerLine); + for (int iter = 0; iter < num_iters; ++iter) { + columnIteration(iter); + } + } + + // Reduce between 
workers + static_assert( + kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || + kNumThreadsPerLine == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kNumThreadsPerLine; i *= 2) { + delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i); + } + + // Store in gmem + if (rowPred) { + p.delta_ptr[query_start + laneRow] = delta_value; + } + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched(typename AK::Params params); + +} // namespace PyTorchMemEffAttention diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..d2e53e9dfadbf9103052c07e74b1fcda7be2692f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -0,0 +1,1353 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +using namespace gemm_kernel_utils; + +namespace PyTorchMemEffAttention { +namespace { +template +constexpr int getWarpsPerSmFw() { + return ( + Arch::kMinComputeCapability >= 80 && + !cutlass::platform::is_same::value + ? 16 + : 12); +} +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return !signbit(value) + ? __int_as_float(atomicMax((int *)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int *)addr, __float_as_uint(value))); +} +} // namespace + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock_, + int kKeysPerBlock_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // This is quite slower on V100 for some reason + // Set to false if you know at compile-time you will never need dropout + bool kSupportsDropout_ = true, + bool kSupportsBias_ = true> +struct AttentionKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + + using scalar_t = scalar_t_; + using accum_t = float; + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool 
kSupportsDropout = kSupportsDropout_; + static constexpr bool kSupportsBias = kSupportsBias_; + static constexpr int kKeysPerBlock = kKeysPerBlock_; + static constexpr int kQueriesPerBlock = kQueriesPerBlock_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kIsAligned = isAligned_; + static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock; + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; + static constexpr bool kPreloadV = + ArchTag::kMinComputeCapability >= 80 && kIsHalf; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = + getWarpsPerSmFw() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + const scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim] + const scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim] + const scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value] + const scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] + const int32_t* seqstart_q_ptr = nullptr; + const int32_t* seqstart_k_ptr = nullptr; + + const int32_t* seqlen_k_ptr = nullptr; + uint32_t causal_diagonal_offset = 0; + + // Output tensors + output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value] + // [num_queries, num_heads, head_dim_value] + output_accum_t* output_accum_ptr = nullptr; + // [num_heads, num_queries] - can be null + lse_scalar_t* logsumexp_ptr = nullptr; + + // Sliding window. 
ignored if == 0 + int32_t window_size = 0; + + // Scale + accum_t scale = 0.0; + + // Dimensions/strides + int32_t head_dim = 0; + int32_t head_dim_value = 0; + int32_t num_queries = 0; + int32_t num_keys = 0; + int32_t num_keys_absolute = 0; + + uint8_t custom_mask_type = NoCustomMask; + + int32_t q_strideM = 0; + int32_t k_strideM = 0; + int32_t v_strideM = 0; + int32_t bias_strideM = 0; + + int32_t o_strideM = 0; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH = 0; + int32_t k_strideH = 0; + int32_t v_strideH = 0; + int64_t bias_strideH = 0; + + int64_t q_strideB = 0; + int64_t k_strideB = 0; + int64_t v_strideB = 0; + int64_t bias_strideB = 0; + + int32_t num_batches = 0; + int32_t num_heads = 0; + + // dropout + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); + int64_t* extragraph_offset = nullptr; + int64_t* seed = nullptr; + + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + if (kSupportsDropout) { + dropout_batch_head_rng_offset = + batch_id * num_heads * num_queries * num_keys + + head_id * num_queries * num_keys; + } + + int64_t q_start = 0, k_start = 0; + // Advance to current batch - in case of different sequence lengths + if (seqstart_q_ptr != nullptr) { + assert(seqstart_k_ptr != nullptr); + seqstart_q_ptr += batch_id; + + q_start = seqstart_q_ptr[0]; + int64_t q_next_start = seqstart_q_ptr[1]; + int64_t k_end; + seqstart_k_ptr += batch_id; + + if (seqlen_k_ptr) { + k_start = seqstart_k_ptr[0]; + k_end = k_start + seqlen_k_ptr[batch_id]; + } else { + k_start = seqstart_k_ptr[0]; + k_end = 
seqstart_k_ptr[1]; + } + + num_queries = q_next_start - q_start; + num_keys = k_end - k_start; + + if (query_start >= num_queries) { + return false; + } + } else { + query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + output_ptr += int64_t(batch_id * num_queries) * o_strideM; + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(batch_id * num_queries) * (head_dim_value * num_heads); + } + q_start = 0; + k_start = 0; + } + + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + + value_ptr += k_start * v_strideM + head_id * v_strideH; + output_ptr += + int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; + + if (kSupportsBias && attn_bias_ptr != nullptr) { + attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); + } + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(q_start + query_start) * (head_dim_value * num_heads) + + head_id * head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += + batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (custom_mask_type == CausalFromBottomRight) { + causal_diagonal_offset = num_keys - num_queries; + } + // We use num_keys_absolute to index into the rng_state + // We need this index to match between forward and backwards + num_keys_absolute = num_keys; + if (custom_mask_type == CausalFromTopLeft || + custom_mask_type == CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys 
and + // this to avoid extra computations + num_keys = cutlass::fast_min( + int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock), + num_keys); + } + + num_queries -= query_start; + num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (num_queries == 1 && k_strideH == 0 && v_strideH == 0 && + logsumexp_ptr == nullptr && window_size == 0) { + if (head_id % kQueriesPerBlock != 0) { + return false; + } + q_strideM = q_strideH; + bias_strideM = bias_strideH; + num_queries = num_heads; + num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + custom_mask_type = NoCustomMask; + o_strideM = head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + // Only worth doing if they could have been modified above. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + if (kSupportsBias) { + attn_bias_ptr = warp_uniform(attn_bias_ptr); + } + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + o_strideM = warp_uniform(o_strideM); + custom_mask_type = warp_uniform(custom_mask_type); + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3( + ceil_div(num_queries, (int32_t)kQueriesPerBlock), + num_heads, + num_batches); + } + + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize, kNumWarpsPerBlock, 1); + } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. 
+ While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 
4 + : DefaultConfig::kStages, + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + static_assert( + MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * + MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + output_accum_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + LayoutB, // LayoutB, + kAlignmentB, + output_accum_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert( + WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, + ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename 
cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + }; + + using SharedStorage = typename cutlass::platform::conditional< + kSingleValueIteration || kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + 
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + if (kSupportsBias) { + CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); + TORCH_CHECK( + p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideB). ", + "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a " + "multiple of ", kAlignmentQ, "."); + TORCH_CHECK( + p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideH). " + "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a " + "multiple of ", kAlignmentQ, "."); + TORCH_CHECK( + p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideM). " + "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a " + "multiple of ", kAlignmentQ, "."); + } + TORCH_CHECK( + p.q_strideM % kAlignmentQ == 0, + "query is not correctly aligned (strideM)"); + TORCH_CHECK( + p.k_strideM % kAlignmentK == 0, + "key is not correctly aligned (strideM)"); + TORCH_CHECK( + p.v_strideM % kAlignmentV == 0, + "value is not correctly aligned (strideM)"); + TORCH_CHECK( + p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0, + "query is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0, + "key is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, + "value is not correctly aligned (strideH)"); + TORCH_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "invalid value for `custom_mask_type`"); + if (p.window_size > 0) { + TORCH_CHECK( + p.custom_mask_type == CausalFromTopLeft || + p.custom_mask_type == CausalFromBottomRight, + "custom_mask_type not supported"); + } + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = 
*((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + out_rescale[thread_id()] = accum_t(1.0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{ + (int32_t)(p.head_dim_value * p.num_heads)}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + curandStatePhilox4_32_10_t curand_state_init; + if (kSupportsDropout && p.use_dropout) { + const auto [seed, offset] = at::cuda::philox::unpack(p.rng_engine_inputs); + if (p.rng_engine_inputs.captured_) { + // See Note [Seed and Offset Device] + // When we are in cuda graph capture mode the seed and offset are stored + // on device We pass in int64_t* seed, and int64_t* offset to act as + // scratch space for storing the rng state during the forward pass and + // saving for backwards. 
+ *p.seed = seed; + *p.extragraph_offset = offset; + } + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. + curand_init( + seed, + 0, + offset + p.dropout_batch_head_rng_offset, + &curand_state_init); + } + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + if (p.window_size > 0) { + // don't compute anything if below attention band + if (iter_key_start + kKeysPerBlock < + int32_t(query_start + p.causal_diagonal_offset) - p.window_size) { + continue; + } + } + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + const_cast(p.value_ptr + iter_key_start * p.v_strideM), + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue( + shared_storage.after_mm0.mm1, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + 
// with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{ + tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{ + tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(p.q_strideM)), + const_cast(p.query_ptr), + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(p.k_strideM)), + const_cast(p.key_ptr + iter_key_start * p.k_strideM), + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_uniform(warp_id()); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // multiply by scaling factor + if (kSupportsBias) { + accum = + cutlass::multiplies()(p.scale, accum); + } + + // apply attention bias if applicable + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + // load bias tile Bij into shared memory 
+ typename MM0::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + // attn_bias_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + const_cast(p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start), + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id()); + MM0::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + // Mask out last if causal + // This is only needed if upper-right corner of current query / key block + // intersects the mask Coordinates of upper-right corner of current block + // is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The + // first masked element is x = y + offset -> query_start + offset There is + // intersection (and we need to mask) if min(iter_key_start + + // kKeysPerBlock, num_keys)) >= query_start + offset + if (p.custom_mask_type && + cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >= + (query_start + p.causal_diagonal_offset)) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + // last absolute col is (last absolute 
query + offset) + // last local col is (last absolute query + offset - + // iter_key_start) + last_col = query_start + accum_m + p.causal_diagonal_offset - + iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + + // Mask out lower left corner of block if window_size > 0 + // only required if current block intersects with the lower left corner + // block starts at x_lowerleft = iter_key_start // y = query_start + + // kQueriesPerBlock first non masked value at this y is : x_first = + // query_start + kQueriesPerBlock - window_size mask if x_fist > + // x_lowerleft + + if (p.window_size > 0 && + (query_start + p.causal_diagonal_offset + + cutlass::fast_min( + int32_t(kQueriesPerBlock), int32_t(p.num_queries)) - + p.window_size >= + iter_key_start)) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + int32_t first_col; + const int32_t offset = query_start + p.causal_diagonal_offset - + p.window_size - iter_key_start; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { first_col = accum_m + offset; }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n <= first_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + // print_warp_accum(accum, lane_offset, 12, + // 12); + } + + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + my_lane_id, + thread_id(), + my_warp_id, + p.num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 
1.0f : p.scale); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + + // apply dropout (if applicable) after we've written Pij to smem. + // dropout is applied by multiplying each element of Pij by: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + // + // for backward purposes we want to be able to map each element of the + // attention matrix to the same random uniform number as the one we used + // in forward, without needing to use the same iteration order or having + // to store the dropout matrix. its possible to do this in registers but + // it ends up being very slow because each thread having noncontiguous + // strips of the Pij tile means we have to skip around a lot, and also + // have to generate a single random number at a time + if (kSupportsDropout && p.use_dropout) { + auto si = shared_storage.after_mm0.si.accum_ref(); + // each thread handles a contiguous sequence of elements from Sij, all + // coming from the same row. the reason they have to come from the same + // row is that the sampling random numbers from a contiguous random + // number sequence is much more efficient than jumping around, and the + // linear offset of each element of S (the global matrix) maps to an + // offset in a random number sequence. for S, the end of a row and the + // beginning of the next have adjacent offsets, but for Sij, this is not + // necessarily the case. 
+ const int num_threads = blockDim.x * blockDim.y * blockDim.z; + const int threads_per_row = + cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(problem_size_0_n, threads_per_row), 4); + + const int thread_i = thread_id() / threads_per_row; + const int thread_start_j = + (thread_id() % threads_per_row) * elts_per_thread; + + if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + static_cast( + (query_start + thread_i) * p.num_keys_absolute + + (iter_key_start + thread_start_j)), + &curand_state); + const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); + + // apply dropout scaling to elements this thread is responsible for, + // in chunks of 4 + for (int sij_start_col_idx = thread_start_j; sij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + problem_size_0_n); + sij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + si.at({thread_i, sij_start_col_idx + quad_idx}) *= + static_cast( + dropout_scale * + ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + } + } + } + __syncthreads(); // p.use_dropout should have same value kernel-wide + } + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kSingleValueIteration + ? 
1 + : ceil_div( + (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + const_cast(p.value_ptr + iter_key_start * p.v_strideM), + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv( + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), + (int)thread_id(), + (int)my_warp_id, + (int)my_lane_id); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + int first_key = 0; + if (p.window_size > 0) { + first_key = (cutlass::fast_max( + int(query_start + p.causal_diagonal_offset) - + p.window_size + 1, + 0) / + kKeysPerBlock) * + kKeysPerBlock; + } + + // int first_key_block = 0; + // MM1::Mma::drain_cp_asyncs(); # TODO figure out if this is needed for correctness + DISPATCH_BOOL( + iter_key_start == first_key, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = + typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename 
cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = call_conditional< + kIsLast, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + my_warp_id, + my_lane_id); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source 
+ DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (thread_id() < p.num_queries) { + // We set fully masked out rows to 0, the sumexp for masked out rows will be 0 + // We update it to be 1 prior to calling log so that log(1) = 0 + s_prime[thread_id()] = (s_prime[thread_id()] == 0) ? 1: s_prime[thread_id()]; + mi[thread_id()] = (mi[thread_id()] == -cutlass::platform::numeric_limits::infinity()) ? 
0: mi[thread_id()]; + p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) + + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = + cutlass::platform::numeric_limits::infinity(); + } + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. 
+ */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kWarpSize>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + // Doing this `exp` is quite expensive. Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. 
In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, + [&](int accum_m) {}); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; + } + }); + } + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } + + static CUTLASS_DEVICE int8_t lane_id() { + return 
threadIdx.x; + } + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.y; + } + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x + threadIdx.y * blockDim.x; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); + +} // namespace PyTorchMemEffAttention diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h new file mode 100644 index 0000000000000000000000000000000000000000..384fc9130ef21535500c6639164eb1aa75bba493 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h @@ -0,0 +1,914 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +// This file is auto-generated. 
See "generate_kernels.py" +#pragma once +#include +using namespace PyTorchMemEffAttention; +// ======== f16 / sm70 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) 
+fmha_cutlassB_f16_aligned_128x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_sm70(typename 
AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void 
__launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm70(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_sm70); + 
cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70); +} + +// ======== bf16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k32_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k32_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k64_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x128_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + 
AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x128_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x128_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + 
AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_bf16_sm80(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_sm80); + if (cc == 86 || cc == 89) cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k96_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k65536_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80); +} + +// ======== f16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + 
AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x128_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x128_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) 
+fmha_cutlassB_f16_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x128_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm80(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm80); + if (cc == 86 
|| cc == 89) cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k96_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80); +} + +// ======== f16 / sm50 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + 
AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + 
AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm50(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50); +} + +// ======== f32 / sm50 ======== +__global__ void 
__launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM) +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +#else +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +#endif +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void 
__launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm50(T cb, int cc) { + 
cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm50); +#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM) + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50); +#else + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50); +#endif + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50); +} + +// ======== f32 / sm70 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm70(typename 
AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + 
AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm70(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_sm70); + 
cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70); +} + +// ======== f16 / sm75 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + 
AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + 
AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm75(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_sm75); + cb(AttentionBackwardKernel(), 
fmha_cutlassB_f16_aligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75); +} + +// ======== f32 / sm75 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + 
AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) 
+fmha_cutlassB_f32_notaligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm75(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), 
fmha_cutlassB_f32_notaligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75); +} + +// ======== f32 / sm80 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) 
+fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm80(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), 
fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k65536_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80); +} + + +template +void dispatch_cutlassB(T cb, int cc = 0) { + + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassB_f16_sm70(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassB_bf16_sm80(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassB_f16_sm80(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassB_f16_sm50(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassB_f32_sm50(cb, cc); + } + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassB_f32_sm70(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassB_f16_sm75(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassB_f32_sm75(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassB_f32_sm80(cb, cc); + } +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h new file mode 100644 index 0000000000000000000000000000000000000000..1d2191990fd7467974ab88514b4e1b9d3ecfca67 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h @@ -0,0 +1,313 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +// This file is auto-generated. 
See "generate_kernels.py" +#pragma once +#include +using namespace PyTorchMemEffAttention; +// ======== bf16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_bf16_aligned_64x64_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_bf16_aligned_64x128_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_bf16_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_bf16_sm80(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_64x64_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_64x128_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_32x128_gmem_sm80); +} + +// ======== f16 / sm50 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + 
AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_gmem_sm50(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm50(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm50); +} + +// ======== f16 / sm70 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_gmem_sm70(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm70(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm70); 
+ cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm70); +} + +// ======== f16 / sm75 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_gmem_sm75(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm75(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm75); +} + +// ======== f16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) 
+fmha_cutlassF_f16_aligned_64x64_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x128_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm80(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x128_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm80); +} + +// ======== f32 / sm50 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_gmem_sm50(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm50(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm50); + 
cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm50); +} + +// ======== f32 / sm70 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_gmem_sm70(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm70(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm70); +} + +// ======== f32 / sm75 
======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_gmem_sm75(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm75(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm75); +} + +// ======== f32 / sm80 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x128_rf_sm80(typename AttentionKernel::Params p); 
+__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm80(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x128_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm80); +} + + +template +void dispatch_cutlassF(T cb, int cc = 0) { + + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassF_bf16_sm80(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassF_f16_sm50(cb, cc); + } + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassF_f16_sm70(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassF_f16_sm75(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassF_f16_sm80(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassF_f32_sm50(cb, cc); + } + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassF_f32_sm70(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassF_f32_sm75(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassF_f32_sm80(cb, cc); + } +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..6c494919ee8258f57954be76bf2e08b68de37dc5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +#include +#include + + +template +struct CutlassToAtenDtype; + +template <> +struct CutlassToAtenDtype { + using scalar_t = cutlass::half_t; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::Half; + } +}; + +template <> +struct CutlassToAtenDtype { + using scalar_t = cutlass::bfloat16_t; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::BFloat16; + } +}; + +template <> +struct CutlassToAtenDtype { + using scalar_t = float; + + static constexpr __host__ at::ScalarType atScalarType() { + return at::ScalarType::Float; + } +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h new file mode 100644 index 0000000000000000000000000000000000000000..5ebd65460c6f832d18814aa540c7b31dba151115 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template < + typename scalar_t, // scalar type + typename ThreadblockTileShape, // size of tile to load + int Threads, // number of participating threads + int ElementsPerAccess> // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load( + GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/sdp_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/sdp_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..d1c6fd707bb8dc58d41ae17554131fb218e6c58a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/sdp_utils.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include + +namespace sdp { + 
+bool check_for_seq_len_1_nested_tensor(sdp_params const& params, bool debug); +SDPBackend select_sdp_backend(sdp_params const& kernel_params); +C10_EXPORT bool is_flash_attention_available(); +C10_EXPORT bool can_use_flash_attention(sdp_params const& params, bool debug); +C10_EXPORT bool can_use_mem_efficient_attention(sdp_params const& params, bool debug); +C10_EXPORT bool can_use_cudnn_attention(sdp_params const& params, bool debug); + +} // namespace sdp diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/aotriton_adapter.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/aotriton_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..b38122248db80e38df92c2b5c4a1cab2725cc8c9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/aotriton_adapter.h @@ -0,0 +1,126 @@ +#pragma once + +#ifdef USE_ROCM + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h +//////////////////////////////////////////////////////////////////////////////// + +namespace sdp { + +namespace aotriton_adapter { + +inline aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype) +{ +#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname + CAST_TYPE(kByte, kUInt8); + CAST_TYPE(kUInt16, kUInt16); + CAST_TYPE(kUInt32, kUInt32); + CAST_TYPE(kUInt64, kUInt64); + CAST_TYPE(kChar, kInt8); + CAST_TYPE(kShort, kInt16); + CAST_TYPE(kInt, kInt32); + CAST_TYPE(kLong, kInt64); + CAST_TYPE(kHalf, kFloat16); + CAST_TYPE(kFloat, kFloat32); + CAST_TYPE(kBFloat16, kBFloat16); + return aotriton::DType::kUnknown; +#undef CAST_TYPE +} + +template +struct IntArrayRefCaster { + // std::array cast(IntArrayRef); +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ 
static_cast(ref.at(0)) }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)) + }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)) + }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)), + static_cast(ref.at(3)) + }}; + } +}; + + +template +aotriton::TensorView mk_aotensor(const at::Tensor& q, std::string_view tensor_name) +{ + const auto strides = q.strides(); + int real_rank = strides.size(); + if (real_rank != Rank) { // Lazy convertion of tensor_name + TORCH_CHECK(false, + std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) + + " but is " + std::to_string(real_rank)); + } + return aotriton::TensorView(reinterpret_cast(q.data_ptr()), + IntArrayRefCaster::cast(q.sizes()), + IntArrayRefCaster::cast(strides), + cast_dtype(q.dtype())); +} + +inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q) +{ + return aotriton::TensorView<0>(reinterpret_cast(q.data_ptr()), + cast_dtype(q.dtype())); +} + +inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr) +{ + return aotriton::TensorView<0>(reinterpret_cast(ptr), + aotriton::DType::kUInt64); // AOTriton accepts unsigned int64 +} + +inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr) +{ + return aotriton::TensorView<0>(reinterpret_cast(ptr), + aotriton::DType::kInt32); +} + +} // namespace aotriton_adapter + +} // namespace sdp + +namespace at::native { + +inline int64_t ceil_div(int64_t numerator, int64_t denominator) { + return (numerator + (denominator - 1)) / denominator; +} + +} + +#endif // USE_ROCM diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h new file mode 100644 index 0000000000000000000000000000000000000000..6fd46467bc0761bcfedcc880ecbd9b1783ffb310 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h @@ -0,0 +1,67 @@ +#pragma once +#include + +#include + +#if defined(USE_CK_FLASH_ATTENTION) +namespace pytorch_flash { + +std::tuple< + at::Tensor, // output + at::Tensor, // q + at::Tensor, // k + at::Tensor, // v + at::Tensor, // lse + at::Tensor, // seed + at::Tensor, // offset + at::Tensor> // dropout randval +mem_eff_forward_ck( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + float p_dropout, + bool return_dropout_randval, + std::optional is_causal, + std::optional scale, + const std::optional& attn_bias_, + std::optional& out_, + const std::optional& cu_seqlens_q, + const std::optional& cu_seqlens_k, + const std::optional& seqstart_q, + const std::optional& seqstart_k, + std::optional gen_, + std::optional& seqused_k_ +); + +std::tuple< + at::Tensor, // dQ + at::Tensor, // dK + at::Tensor, // dV + at::Tensor> // dBias +mem_eff_backward_ck( + const at::Tensor &dout, + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &out, + const at::Tensor &softmax_lse, + const at::Tensor &dq_, + const at::Tensor &dk_, + const at::Tensor &dv_, + std::optional &attn_bias, + bool bias_requires_grad, + std::optional &grad_bias, + std::optional &cu_seqlens_q, + std::optional &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float p_dropout, + float scale, + bool is_causal, + bool deterministic, + bool zero_tensors, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +} // namespace pytorch_flash +#endif // USE_CK_FLASH_ATTENTION diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h new file mode 100644 index 0000000000000000000000000000000000000000..17298aae9485daf370c8b16ece25ba542b36bb81 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -0,0 +1,599 @@ +#pragma once +#include + +#include +#include +#include + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace pytorch_flash { + +// AOTriton Implementation +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd_aot( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional 
window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd_aot( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + std::optional& block_table_, + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple mha_bwd_aot( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool 
is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +std::tuple mha_varlen_bwd_aot( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +#if defined(USE_CK_FLASH_ATTENTION) +// CK implementation +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd_ck( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + const float 
p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_, + const std::optional& attn_bias_); // batch_size x nheads x seqlen_q x seqlen_k + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd_ck( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_, + const std::optional& attn_bias_); + +std::tuple mha_bwd_ck( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + attn_bias_, // batch_size x num_heads x seqlen_q x seqlen_k + bool 
bias_requires_grad, + std::optional& grad_bias, + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple mha_varlen_bwd_ck( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& attn_bias_, // num_heads or b x num_heads + bool bias_requires_grad, + std::optional& grad_bias, + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); +#endif + +TORCH_API +inline std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x 
head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + std::optional dummy_attn_bias = std::nullopt; + return mha_fwd_ck( + q, + k, + v, + out_, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} + +inline std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + std::optional& + block_table_, // Not used on ROCm. 
Keeping for parity with CUDA + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional dummy_attn_bias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + return mha_varlen_fwd_ck( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_varlen_fwd_aot( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} + +inline std::tuple mha_bwd( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + 
alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { +#if defined(USE_CK_FLASH_ATTENTION) + std::optional non_null_dbias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + auto[dQuery, + dKey, + dValue, + dSoftmax, + dBias] = mha_bwd_ck( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + false, // bias_requires_grad + non_null_dbias, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + deterministic, + philox_seed, + philox_offset); + // for FA return [dQ, dV, dK, dSoftmax] + return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); +#else + TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. 
Defaulting to use aotriton backend..."); +#endif + } + return mha_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +} + +inline std::tuple mha_varlen_bwd( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { +#if defined(USE_CK_FLASH_ATTENTION) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional non_null_dbias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + auto[dQuery, + dKey, + dValue, + dSoftmax, + dBias] = 
mha_varlen_bwd_ck( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + false, // bias_requires_grad + non_null_dbias, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + non_null_window_left, + non_null_window_right, + deterministic, + philox_seed, + philox_offset); + // for FA return [dQ, dV, dK, dSoftmax] + return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); + } +#endif + return mha_varlen_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +} + +} // namespace pytorch_flash diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/sdp_utils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/sdp_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b19ed00ecc88ee9b787c8a80dde0bac87cc26ccb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/sdp_utils.h @@ -0,0 +1,88 @@ +#pragma once +#include +#include + +namespace at::native { + +void alloc_with_matching_layout( + const Tensor& q, + Tensor& output, + const std::vector& shape) { + TORCH_INTERNAL_ASSERT( + shape.size() == q.sizes().size(), + "SDPA alloc_with_matching_layout got requested shape ndim != q ndim"); + + if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) { + output = at::empty_like(q); + return; + } + + // get the "fill order," which is just an argsort on the strides + std::vector fill_order(shape.size()); + std::iota(fill_order.begin(), fill_order.end(), 0); + const auto q_strides = q.strides(); + std::stable_sort( + fill_order.begin(), fill_order.end(), 
[&q_strides](int idx1, int idx2) { + return q_strides[idx1] < q_strides[idx2]; + }); + std::vector ordered_strides(shape.size()); + int64_t current_stride = 1; + for (const int dim_idx : fill_order) { + ordered_strides[dim_idx] = current_stride; + current_stride *= shape[dim_idx]; + } + output = at::empty(at::IntArrayRef(shape), q.options()) + .as_strided( + at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0); +} + +void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) { + const int dims = output.sizes().size(); + std::vector outer_to_inner(dims); + std::iota(outer_to_inner.begin(), outer_to_inner.end(), 0); + const auto o_strides = output.strides(); + std::stable_sort( + outer_to_inner.begin(), + outer_to_inner.end(), + [&o_strides](int idx1, int idx2) { + return o_strides[idx1] > o_strides[idx2]; + }); + std::vector inverse(dims); + for (int d = 0; d < dims; d++) { + inverse[d] = std::find(outer_to_inner.begin(), outer_to_inner.end(), d) - + outer_to_inner.begin(); + } + grad_output = grad_output.permute(at::IntArrayRef(outer_to_inner)) + .contiguous() + .permute(at::IntArrayRef(inverse)); +} + +bool same_strides(const Tensor& t1, const Tensor& t2) { + std::vector t1_strides_no_ones; + std::vector t2_strides_no_ones; + const auto t1strides = t1.strides(); + const auto t2strides = t2.strides(); + const int dim = t1strides.size(); + if (dim != (int)t2strides.size()) { + return false; + } + const auto t1sizes = t1.sizes(); + const auto t2sizes = t2.sizes(); + + // we are going through strides backward here, but if both are backward it's + // comparable + for (int i = 0; i < dim; i++) { + if (t1sizes[i] > 1) { + t1_strides_no_ones.push_back(t1strides[i]); + } + if (t2sizes[i] > 1) { + t2_strides_no_ones.push_back(t2strides[i]); + } + } + return std::equal( + t1_strides_no_ones.begin(), + t1_strides_no_ones.end(), + t2_strides_no_ones.begin(), + t2_strides_no_ones.end()); +} +} // namespace at::native diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h new file mode 100644 index 0000000000000000000000000000000000000000..c63ca928613e67fb795fedc12e05d1b64ee3e558 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h @@ -0,0 +1,560 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace sdp { + +constexpr int32_t num_backends = at::num_sdp_backends; +using SDPBackend = at::SDPBackend; + +// Note that if this changed make sure to update +// the templated enum in mem_eff/kernel_forward.h and mem_eff/kernel_backward.h +enum class CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, +}; + +struct sdp_params { + at::Tensor query; + at::Tensor key; + at::Tensor value; + std::optional attn_mask; + double dropout; + bool is_causal; + bool enable_gqa; +}; + +SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params); + +inline c10::SymFloat calculate_scale( + const at::Tensor& query, + std::optional scale) { + const auto softmax_scale = scale.has_value() + ? 
scale.value() + : (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt())); + return c10::SymFloat(softmax_scale); +} + +inline bool input_requires_grad(sdp_params const& params) { + const bool any_inputs_require_grad = params.query.requires_grad() || + params.key.requires_grad() || params.value.requires_grad(); + const bool gradmode_enabled = at::GradMode::is_enabled(); + return any_inputs_require_grad && gradmode_enabled; +} + +inline bool has_for_nested_inputs(sdp_params const& params) { + return + (params.query.is_nested() && params.query.layout() == c10::kStrided) || + (params.key.is_nested() && params.key.layout() == c10::kStrided) || + (params.value.is_nested() && params.value.layout() == c10::kStrided); +} + +inline bool has_for_dense_inputs(sdp_params const& params) { + return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested(); +} + +inline bool has_only_dense_inputs(sdp_params const& params) { + return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested(); +} + +template +inline bool check_tensor_dtype( + sdp_params const& params, + dtype_vector allowed_dtypes, + bool debug) { + auto query_dtype = params.query.dtype(); + if (!(query_dtype == params.key.dtype() && + query_dtype == params.value.dtype() && + (std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) != + allowed_dtypes.end()))) { + if (debug) { + TORCH_WARN( + "Expected query, key and value to all be of dtype: {", + c10::Join(", ", allowed_dtypes), + "}. 
Got ", + "Query dtype: ", + params.query.dtype(), + ", Key dtype: ", + params.key.dtype(), + ", and Value dtype: ", + params.value.dtype(), + " instead."); + } + return false; + } + return true; +} + + +inline bool try_broadcast_param_size( + const c10::SymInt q_size, + const c10::SymInt k_size, + const c10::SymInt v_size, + std::string_view param_name, + bool debug) { + auto max_size = std::max({q_size, k_size, v_size}); + if ((q_size != max_size && q_size != 1) || + (k_size != max_size && k_size != 1) || + (v_size != max_size && v_size != 1)) { + if (debug) { + TORCH_WARN( + "Both fused kernels require query, key and value to have broadcastable ", + param_name, + "got Query ", + param_name, + q_size, + ", Key ", + param_name, + k_size, + ", Value ", + param_name, + v_size, + " instead."); + } + return false; + } + return true; +} + +inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( + at::Tensor const& param, + std::string_view param_name, + bool debug) { + const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param); + const at::Tensor& sizes = nt_tensor_impl->get_nested_sizes(); + auto num_head_dims = nt_tensor_impl->opt_size(1); + if (!num_head_dims.has_value()) { + // num_head_dims is ragged + if (debug) { + TORCH_WARN( + "Fused kernels do not support ragged num_head_dims, ", + param_name, + "has a ragged num_heads."); + } + return false; + } + + auto* sizes_ptr = sizes.data_ptr(); + const int64_t n_tensors = param.size(0); + const int64_t size_tensor_stride = sizes.stride(0); + + // This is being called inside sdp with shape [batch, heads, {seq_len}, dim] + for (const auto i : c10::irange(n_tensors)) { + if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) { + if (debug) { + TORCH_WARN( + "Fused kernels do not support seq_len == 0, ", + param_name, + "has a seq len of 0."); + } + return false; + } + } + return true; +} + +inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) { + // When this 
function is called we are assured that the nt is dim==4 + bool q_is_safe = params.query.is_nested() + ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( + params.query, "query ", debug) + : true; + // short circuit if any is unsafe + if (!q_is_safe) { + return false; + } + + bool k_is_safe = params.key.is_nested() + ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( + params.key, "key ", debug) + : true; + if (!k_is_safe) { + return false; + } + + bool v_is_safe = params.value.is_nested() + ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper( + params.value, "value ", debug) + : true; + if (!v_is_safe) { + return false; + } + + // We now know none of the inputs have ragged num_heads, so we can safely + // access .size(1) + auto q_num_heads = params.query.size(1); + auto k_num_heads = params.key.size(1); + auto v_num_heads = params.value.size(1); + bool same_num_heads = + q_num_heads == k_num_heads && q_num_heads == v_num_heads; + + if (!same_num_heads) { + if (input_requires_grad(params)){ + if (debug) { + TORCH_WARN( + "Both fused kernels do not support training with broadcasted NT inputs."); + } + return false; + } + return try_broadcast_param_size( + q_num_heads, k_num_heads, v_num_heads, "num heads ", debug); + } + + return true; +} + +inline bool check_nested_tensor(sdp_params const& params, bool debug) { + // Return false if have nested tensor + if (!has_only_dense_inputs(params)) { + if (debug) { + TORCH_WARN( + "Both fused kernels of cpp version currently do not support Nested Tensor inputs."); + } + return false; + } + return true; +} + +inline bool check_for_dropout(sdp_params const& params, bool debug) { + if (params.dropout > 0.0) { + if (debug) { + TORCH_WARN("Both fused kernels do not support non-zero dropout."); + } + return false; + } + return true; +} + +inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) { + if (input_requires_grad(params)) { + if (debug) { + TORCH_WARN( 
+ "Memory efficient attention currently doesn't support training with NT inputs."); + } + return false; + } + return true; +} + +inline bool check_for_attn_mask(sdp_params const& params, bool debug) { + if (params.attn_mask.has_value()) { + if (debug) { + TORCH_WARN("Flash Attention does not support non-null attn_mask."); + } + return false; + } + return true; +} + +inline bool check_attn_mask_shape(sdp_params const& params, bool debug) { + auto attn_mask = params.attn_mask; + if (!attn_mask.has_value()) { + return true; + } + if (attn_mask.value().requires_grad()) { + return false; + } + auto batchSize = params.query.sym_size(0); + auto qSize = params.query.sym_size(2); + auto kvSize = params.key.sym_size(2); + auto num_head = params.query.sym_size(1); + if (attn_mask.value().sym_size(-2) != qSize && attn_mask.value().sym_size(-2) != 1) { + return false; + } + if (attn_mask.value().sym_size(-1) != kvSize && attn_mask.value().sym_size(-1) != 1) { + return false; + } + if (attn_mask.value().dim() == 2) { + return true; + } else if (attn_mask.value().dim() == 4) { + if ((attn_mask.value().sym_size(0) == 1 || attn_mask.value().sym_size(0) == batchSize) + && (attn_mask.value().sym_size(1) == 1 || attn_mask.value().sym_size(1) == num_head)) { + return true; + } + } + if (debug) { + TORCH_WARN("Please use the following attn mask shapes: ", + "2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ", + "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})"); + } + return false; +} + +inline bool check_tensor_shapes(sdp_params const& params, bool debug) { + auto query_dim = params.query.dim(); + if (!(query_dim == params.key.dim() && query_dim == params.value.dim() && + (query_dim == 4))) { + if (debug) { + TORCH_WARN( + "All fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ", + query_dim, + ", Key dim: ", + params.key.dim(), + ", Value dim: ", + params.value.dim(), + " instead."); + } + return false; + } + return true; +} + +inline 
bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) { + const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param); + auto seq_len = nt_tensor_impl->opt_size(2); + if (!seq_len.has_value()) { + if (debug) { + TORCH_WARN( + "For both fused kernels, if one of key/value batch_size requires " + "broadcasting and the other does not, then the other must have a ", + "consistent seq_len dim.") + } + return false; + } + return true; +} + +template +inline bool check_grouped_query_attention(sdp_params const& params, bool debug) { + const auto q_num_heads = params.query.sym_size(-3); + const auto k_num_heads = params.key.sym_size(-3); + const auto v_num_heads = params.value.sym_size(-3); + const bool same_kv_heads = k_num_heads == v_num_heads; + + if (requires_same_num_heads && !(same_kv_heads)){ + if (debug) { + TORCH_WARN( + "Both fused kernels require key and value to have the same num_heads and batch_size but got: ", + "Key sizes: ", + params.key.sizes(), + ", Value sizes: ", + params.value.sizes(), + ", Query sizes: ", + params.query.sizes(), + " instead."); + } + return false; + } + // Check if grouped query attention is supported and validate the number of + // heads + if (q_num_heads % k_num_heads != 0 || (!requires_same_num_heads && (q_num_heads % v_num_heads != 0))) { + if (debug) { + TORCH_WARN( + "The number of heads in key/value must divide number of heads in query.", + "Got input Key sizes(): ", + params.key.sym_size(-3), + ", Value sizes(): ", + params.value.sym_size(-3), + ", Query sizes(): ", + params.query.sym_size(-3), + " instead."); + } + return false; + } + return true; +} + +template +inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) { + // This is expected to be called after check_tensor_shapes ensuring that the + // size() calls won't error since the inputs are all 4 dimensional + + auto q_batch_size = params.query.sym_size(0); + auto k_batch_size = params.key.sym_size(0); + auto 
v_batch_size = params.value.sym_size(0); + + bool same_batch_size = + q_batch_size == k_batch_size && q_batch_size == v_batch_size; + + auto q_num_heads = params.query.sym_size(-3); + auto k_num_heads = params.key.sym_size(-3); + auto v_num_heads = params.value.sym_size(-3); + + bool same_num_heads = + q_num_heads == k_num_heads && q_num_heads == v_num_heads; + + if (!same_batch_size){ + if(debug) { + TORCH_WARN( + "For dense inputs, both fused kernels require query, key and value to have the same batch_size. ", + "Query.sizes(): ", + params.query.sizes(), + ", Key.sizes(): ", + params.key.sizes(), + ", Value.sizes(): ", + params.value.sizes(), + " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel."); + } + return false; + } + + if(params.enable_gqa && supports_gqa){ + return check_grouped_query_attention(params, debug); + } + + // same num heads condition for non-gqa case + if (!same_num_heads){ + if (debug) { + TORCH_WARN( + "For dense input, both fused kernels require query, key and value to have the same num_heads. ", + "Query.sizes(): ", + params.query.sizes(), + ", Key sizes(): ", + params.key.sizes(), + ", Value sizes(): ", + params.value.sizes(), + " instead. 
To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel."); + } + return false; + } + // If all checks pass, return true + return true; +} + +inline bool check_batch_size_nested(sdp_params const& params, bool debug) { + // This is expected to be called after check_tensor_shapes ensuring that the + // size() calls won't error since the inputs are all 4 dimensional + auto q_batch_size = params.query.sym_size(0); + auto k_batch_size = params.key.sym_size(0); + auto v_batch_size = params.value.sym_size(0); + + bool same_batch_size = + q_batch_size == k_batch_size && q_batch_size == v_batch_size; + + // num_heads logic for nested input is checked in + // check_for_seq_len_0_nested_tensor as there is handling there to make sure + // num_heads is not ragged + bool broadcastable_batch_size = true; + if (!same_batch_size) { + if (input_requires_grad(params)){ + if (debug) { + TORCH_WARN( + "Both fused kernels do not support training with broadcasted NT inputs."); + } + return false; + } + // try to broadcast batchsize + broadcastable_batch_size = try_broadcast_param_size( + q_batch_size, k_batch_size, v_batch_size, "batch size ", debug); + + // if only one of k or v require broadcasting of batch size, the other + // must have a consistent seq_len dim + if (broadcastable_batch_size) { + if (k_batch_size == 1 && v_batch_size != 1 && + !check_safe_kv_broadcast(params.value, debug)) { + return false; + } + if (v_batch_size == 1 && k_batch_size != 1 && + !check_safe_kv_broadcast(params.key, debug)) { + return false; + } + } + } + return broadcastable_batch_size; +} + +inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) { + // In some cases people will pass in 0 sized tensors, this will + // cause the fused path to error with unaligned mask + bool zero_seq_len_q = params.query.sym_size(-2) == 0; + bool zero_seq_len_k = params.key.sym_size(-2) == 0; + if (zero_seq_len_q || zero_seq_len_k) { + if (debug) { 
+ TORCH_WARN( + "All fused kernels do not support zero seq_len_q or seq_len_kv."); + } + return false; + } + return true; +} + +template +inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) { + // The stride checking for NestedTensors is done within the kernel + // And .contiguous will be called if needed + + // This function checks that the last dimension of the inputs to + // fused_attention have stride 1 + bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 && + params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1; + + // https://github.com/pytorch/pytorch/issues/116333 + // If the head_dim is size 1 the stride won't matter, but we + // check this condition before padding the head_dim to 1 + if (ignore_singleton_dim){ + qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1; + } + bool is_cpu = params.query.device().type() == c10::DeviceType::CPU; + bool mask_stride_equal_1 = params.attn_mask.has_value() + ? params.attn_mask.value().sym_stride(-1) == 1 + : true; + bool mask_stride_valid = is_cpu ? true : mask_stride_equal_1; + if (!(qkv_strides_equal_1 && mask_stride_valid)) { + if (debug) { + std::ostringstream message; + message + << "All fused kernels require the last dimension of the input to have stride 1. 
"; + message << "Got Query.stride(-1): " << params.query.sym_stride(-1) + << ", Key.stride(-1): " << params.key.sym_stride(-1) + << ", Value.stride(-1): " << params.value.sym_stride(-1); + + if (params.attn_mask.has_value()) { + message + << ", Attn_mask.stride(-1): " + << params.attn_mask.value().sym_stride(-1) + << " (GPU backends require attn_mask's last dimension to have stride 1 while the CPU does not)."; + } + TORCH_WARN(message.str()); + } + + return false; + } + return true; +} + +inline bool check_runtime_disabled_flash(sdp_params const& params, bool debug) { + // We check the global context to see if user has explicitly turned of flash + // sdp kernels + if (!at::globalContext().userEnabledFlashSDP()) { + if (debug) { + TORCH_WARN("Flash attention has been runtime disabled."); + } + return false; + } + return true; +} + +inline bool check_runtime_disabled_mem_efficient(sdp_params const& params, bool debug) { + // We check the global context to see if user has explicitly turned of + // mem_efficient sdp kernels + if (!at::globalContext().userEnabledMemEfficientSDP()) { + if (debug) { + TORCH_WARN("Memory Efficient attention has been runtime disabled."); + } + return false; + } + return true; +} + + +} // namespace sdp diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/utils/Factory.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/utils/Factory.h new file mode 100644 index 0000000000000000000000000000000000000000..8c70435bd313c4d46a2a7248101fd4b5c3240be0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/utils/Factory.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at::native::mobile { + +Tensor allocate_padded_contiguous_if_needed( + const Tensor& input, + c10::MemoryFormat memory_format); + +// TODO: Remove this function when at::native::empty() is modified to accept a +// custom memory allocator. 
+ +at::Tensor empty_with_tail_padding( + IntArrayRef size, + const caffe2::TypeMeta dtype, + c10::MemoryFormat memory_format, + std::optional maybe_names); + +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/utils/ParamUtils.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/utils/ParamUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..c9088c03d81c18bd80eda81e88a172491415a7c6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/utils/ParamUtils.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace at { +namespace native { + +template +inline std::vector _expand_param_if_needed( + ArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + if (list_param.size() == 1) { + return std::vector(expected_dim, list_param[0]); + } else if ((int64_t)list_param.size() != expected_dim) { + std::ostringstream ss; + ss << "expected " << param_name << " to be a single integer value or a " + << "list of " << expected_dim << " values to match the convolution " + << "dimensions, but got " << param_name << "=" << list_param; + TORCH_CHECK(false, ss.str()); + } else { + return list_param.vec(); + } +} + +inline std::vector expand_param_if_needed( + IntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +inline std::vector expand_param_if_needed( + SymIntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/native/utils/ParamsHash.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/utils/ParamsHash.h new file mode 100644 index 0000000000000000000000000000000000000000..6b7894cb8549f59d34b9b52d660780b729ada575 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/native/utils/ParamsHash.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +// Hashing machinery for Params +// Fowler–Noll–Vo hash function +// see +// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function +template +struct ParamsHash { + // Params must be a POD because we read out its memory + // contents as char* when hashing + static_assert(std::is_standard_layout_v, "Params is not POD"); + + size_t operator()(const Params& params) const { + auto ptr = reinterpret_cast(¶ms); + uint32_t value = 0x811C9DC5; + for (const auto i : c10::irange(sizeof(Params))) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +template +struct ParamsEqual { + // Params must be a POD because we read out its memory + // contents as char* when comparing + static_assert(std::is_standard_layout_v, "Params is not POD"); + + bool operator()(const Params& a, const Params& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(Params)) == 0; + } +}; + +// Provide explicit byte-for-byte constructors to avoid uwittingly leaving +// padding bytes unitialized (e.g., when passing Params by value) +template +struct ParamsWrapper { + T pod; + static_assert( + std::is_standard_layout_v, + "ParamsWrapper cannot wrap non-POD data"); + + ParamsWrapper() { + memset(&(this->pod), 0, sizeof(this->pod)); + } + + ParamsWrapper(const ParamsWrapper& other) { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + } + + ParamsWrapper(ParamsWrapper&& other) noexcept { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + } + + ParamsWrapper& operator=(const ParamsWrapper& other) { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + return *this; + } + + ParamsWrapper& operator=(ParamsWrapper&& other) noexcept { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + return *this; + 
} + + inline friend bool operator==( + const ParamsWrapper& lhs, + const ParamsWrapper& rhs) noexcept { + auto ptr1 = reinterpret_cast(&(lhs.pod)); + auto ptr2 = reinterpret_cast(&(rhs.pod)); + return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0; + } +}; + +// Wrapped version: this allows the outer struct to have custom copy and move +// constructors for additional safety +template +struct ParamsWrapperHash { + // Params must be a POD because we read out its memory + // contents as char* when hashing + static_assert( + std::is_standard_layout_v, + "ParamsWrapper cannot wrap non-POD data"); + + size_t operator()(const ParamsWrapper& params_wrapper) const { + auto ptr = reinterpret_cast(&(params_wrapper.pod)); + uint32_t value = 0x811C9DC5; + for (const auto i : c10::irange(sizeof(params_wrapper.pod))) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +} // namespace at::native diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_assert_tensor_metadata.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_assert_tensor_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..2636362bf581dce814e5509f04d534dec10cd69f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_assert_tensor_metadata.h @@ -0,0 +1,48 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? 
layout=None) -> () +inline void _assert_tensor_metadata(const at::Tensor & a, at::OptionalIntArrayRef size=::std::nullopt, at::OptionalIntArrayRef stride=::std::nullopt, ::std::optional dtype=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional layout=::std::nullopt) { + return at::_ops::_assert_tensor_metadata::call(a, size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*size)) : ::std::nullopt, stride.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*stride)) : ::std::nullopt, dtype, device, layout); +} +namespace symint { + template >> + void _assert_tensor_metadata(const at::Tensor & a, at::OptionalIntArrayRef size=::std::nullopt, at::OptionalIntArrayRef stride=::std::nullopt, ::std::optional dtype=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional layout=::std::nullopt) { + return at::_ops::_assert_tensor_metadata::call(a, size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*size)) : ::std::nullopt, stride.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*stride)) : ::std::nullopt, dtype, device, layout); + } +} + +// aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? 
layout=None) -> () +inline void _assert_tensor_metadata_symint(const at::Tensor & a, at::OptionalSymIntArrayRef size=::std::nullopt, at::OptionalSymIntArrayRef stride=::std::nullopt, ::std::optional dtype=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional layout=::std::nullopt) { + return at::_ops::_assert_tensor_metadata::call(a, size, stride, dtype, device, layout); +} +namespace symint { + template >> + void _assert_tensor_metadata(const at::Tensor & a, at::OptionalSymIntArrayRef size=::std::nullopt, at::OptionalSymIntArrayRef stride=::std::nullopt, ::std::optional dtype=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional layout=::std::nullopt) { + return at::_ops::_assert_tensor_metadata::call(a, size, stride, dtype, device, layout); + } +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_cdist_forward_ops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_cdist_forward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..4a25f54f48fc128045478969b5d0f82aa9569193 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_cdist_forward_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API _cdist_forward { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, double, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::_cdist_forward"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "_cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor"; + static at::Tensor call(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode); +}; + +struct TORCH_API _cdist_forward_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, double, ::std::optional, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::_cdist_forward"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_chunk_cat_compositeexplicitautograd_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_chunk_cat_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..5076f49f7a7e14d8f819fa9907700abc520b9110 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_chunk_cat_compositeexplicitautograd_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor _chunk_cat(at::TensorList tensors, int64_t dim, int64_t num_chunks); +TORCH_API at::Tensor & _chunk_cat_out(at::Tensor & out, at::TensorList tensors, int64_t dim, int64_t num_chunks); +TORCH_API at::Tensor & _chunk_cat_outf(at::TensorList tensors, int64_t dim, int64_t num_chunks, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_efficient_attention_forward_native.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_efficient_attention_forward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..7522d8baa78c9b734fcd1b71a9d21ab89e8424f9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_efficient_attention_forward_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple _efficient_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, ::std::optional max_seqlen_q, ::std::optional max_seqlen_k, double dropout_p, int64_t custom_mask_type, bool compute_log_sumexp=false, ::std::optional scale=::std::nullopt, const ::std::optional & seqlen_k={}, ::std::optional window_size=::std::nullopt); +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_fused_sdp_choice_ops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_fused_sdp_choice_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..deb57d7480b15d3b82fd79582eec1c42519a7f56 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_fused_sdp_choice_ops.h @@ -0,0 +1,29 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API _fused_sdp_choice { + using schema = int64_t (const at::Tensor &, const at::Tensor &, const at::Tensor &, const ::std::optional &, double, bool, ::std::optional, bool); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::_fused_sdp_choice"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "_fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None, bool enable_gqa=False) -> int"; + static int64_t call(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, ::std::optional scale, bool enable_gqa); + static int64_t redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, ::std::optional scale, bool enable_gqa); +}; + +}} // namespace at::_ops diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_is_all_true.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_is_all_true.h new file mode 100644 index 0000000000000000000000000000000000000000..dabb720c1cbb930504dd6efbf85fe80e2788cc52 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_is_all_true.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::_is_all_true(Tensor self) -> Tensor +inline at::Tensor _is_all_true(const at::Tensor & self) { + return at::_ops::_is_all_true::call(self); +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_safe_softmax_native.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_safe_softmax_native.h new file mode 100644 index 0000000000000000000000000000000000000000..47de636aa39e01096251d7d30b535a255b62dabc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_safe_softmax_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor _safe_softmax(const at::Tensor & self, int64_t 
dim, ::std::optional dtype=::std::nullopt); +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_sparse_bsc_tensor_unsafe_compositeimplicitautograd_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_sparse_bsc_tensor_unsafe_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..98da0bbcf379d819f884e34ef483207ef8613eda --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_sparse_bsc_tensor_unsafe_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor _sparse_bsc_tensor_unsafe(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}); +TORCH_API at::Tensor _sparse_bsc_tensor_unsafe(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_test_optional_floatlist_compositeexplicitautograd_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_test_optional_floatlist_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..a7a24a1f42837a425011a6f93d1d073b0afc54fa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_test_optional_floatlist_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor & _test_optional_floatlist_out(at::Tensor & out, const at::Tensor & values, ::std::optional> addends); +TORCH_API at::Tensor & _test_optional_floatlist_outf(const at::Tensor & values, ::std::optional> addends, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_to_sparse_csc_ops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_to_sparse_csc_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a6b5dfa6b4d9912cd80ffd78c6d1049383da240f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_to_sparse_csc_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API _to_sparse_csc { + using schema = at::Tensor (const at::Tensor &, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::_to_sparse_csc"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "_to_sparse_csc(Tensor self, int? 
dense_dim=None) -> Tensor"; + static at::Tensor call(const at::Tensor & self, ::std::optional dense_dim); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim); +}; + +struct TORCH_API _to_sparse_csc_out { + using schema = at::Tensor & (const at::Tensor &, ::std::optional, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::_to_sparse_csc"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_to_sparse_semi_structured_cuda_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_to_sparse_semi_structured_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d9d0139be0eb2077b1ac6ff83720b2f8093aed6d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_to_sparse_semi_structured_cuda_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple _to_sparse_semi_structured(const at::Tensor & dense); + +} // namespace cuda +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_transform_bias_rescale_qkv_ops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_transform_bias_rescale_qkv_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..1ada580cfe88e53e7d0df98689eb0c943272844b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_transform_bias_rescale_qkv_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API _transform_bias_rescale_qkv { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, int64_t); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::_transform_bias_rescale_qkv"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "_transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)"; + static ::std::tuple call(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads); +}; + +struct TORCH_API _transform_bias_rescale_qkv_out { + using schema = ::std::tuple (const at::Tensor &, const at::Tensor &, int64_t, at::Tensor &, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members 
for windows NVCC] + static constexpr const char* name = "aten::_transform_bias_rescale_qkv"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))"; + static ::std::tuple call(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); +}; + +}} // namespace at::_ops diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_transformer_encoder_layer_fwd_cuda_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_transformer_encoder_layer_fwd_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f57c06bfd59ac248fa8ba3e178cb5e7a8b6afeea --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_transformer_encoder_layer_fwd_cuda_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor _transformer_encoder_layer_fwd(const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask={}, ::std::optional mask_type=::std::nullopt); + +} // namespace cuda +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/align_as.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/align_as.h new file mode 100644 index 0000000000000000000000000000000000000000..a043a05e831d857701821bbab9a289bec77ed2b9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/align_as.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/arcsin_ops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/arcsin_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0606c9ad3dba75daf1431f69583bed0d06310678 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/arcsin_ops.h @@ -0,0 +1,51 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API arcsin { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arcsin"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arcsin(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API arcsin_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arcsin_"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "arcsin_(Tensor(a!) self) -> Tensor(a!)"; + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API arcsin_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::arcsin"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "arcsin.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/block_diag_ops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/block_diag_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0e0083d39d0b5977554d96c7ff3f4ef4705e9564 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/block_diag_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API block_diag { + using schema = at::Tensor (at::TensorList); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::block_diag"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "block_diag(Tensor[] tensors) -> Tensor"; + static at::Tensor call(at::TensorList tensors); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors); +}; + +struct TORCH_API block_diag_out { + using schema = at::Tensor & (at::TensorList, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::block_diag"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "block_diag.out(Tensor[] tensors, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(at::TensorList tensors, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/ceil_meta.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/ceil_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..4578107262e3abca18d8bf2f5653ca4e28f26631 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/ceil_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_ceil : public TensorIteratorBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/fft_rfftn.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/fft_rfftn.h new file mode 100644 index 0000000000000000000000000000000000000000..43b855c315ad4e8e6648269be62d56a37f403df3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/fft_rfftn.h @@ -0,0 +1,92 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor +inline at::Tensor fft_rfftn(const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn::call(self, s.has_value() ? 
::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); +} +namespace symint { + template >> + at::Tensor fft_rfftn(const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn::call(self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } +} + +// aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor +inline at::Tensor fft_rfftn_symint(const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn::call(self, s, dim, norm); +} +namespace symint { + template >> + at::Tensor fft_rfftn(const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn::call(self, s, dim, norm); + } +} + +// aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & fft_rfftn_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn_out::call(self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); +} +namespace symint { + template >> + at::Tensor & fft_rfftn_out(at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn_out::call(self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } +} + +// aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? 
norm=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & fft_rfftn_outf(const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfftn_out::call(self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); +} +namespace symint { + template >> + at::Tensor & fft_rfftn_outf(const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfftn_out::call(self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } +} + +// aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & fft_rfftn_symint_out(at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn_out::call(self, s, dim, norm, out); +} +namespace symint { + template >> + at::Tensor & fft_rfftn_out(at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn_out::call(self, s, dim, norm, out); + } +} + +// aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & fft_rfftn_symint_outf(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfftn_out::call(self, s, dim, norm, out); +} +namespace symint { + template >> + at::Tensor & fft_rfftn_outf(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfftn_out::call(self, s, dim, norm, out); + } +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/hardshrink.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/hardshrink.h new file mode 100644 index 0000000000000000000000000000000000000000..74d2b9b15fccd9f3831744cc34bb3e421c89dfef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/hardshrink.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & hardshrink_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::hardshrink_out::call(self, lambd, out); +} +// aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & hardshrink_outf(const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out) { + return at::_ops::hardshrink_out::call(self, lambd, out); +} + +// aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor +inline at::Tensor hardshrink(const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::hardshrink::call(self, lambd); +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_inv_ex_cuda_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_inv_ex_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ca902f8f5b869ea54fc1e7a13324af3b4d779f7a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_inv_ex_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple linalg_inv_ex(const at::Tensor & A, bool check_errors=false); +TORCH_API ::std::tuple linalg_inv_ex_out(at::Tensor & inverse, at::Tensor & info, const at::Tensor & A, bool check_errors=false); +TORCH_API ::std::tuple linalg_inv_ex_outf(const at::Tensor & A, bool check_errors, at::Tensor & inverse, at::Tensor & info); + +} // namespace cuda +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_ldl_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_ldl_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ba81a25fbc914fa1837ccfa424aa4be48f54258d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_ldl_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API ::std::tuple linalg_ldl_factor_ex(const at::Tensor & self, bool hermitian=false, bool check_errors=false); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_vector_norm_meta_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_vector_norm_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3b0e937590115b9f71a33f40f55a56f765d18349 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_vector_norm_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor linalg_vector_norm(const at::Tensor & self, const at::Scalar & ord=2, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt); +TORCH_API at::Tensor & linalg_vector_norm_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & ord=2, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt); +TORCH_API at::Tensor & linalg_vector_norm_outf(const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); + +} // namespace meta +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/log2_cpu_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/log2_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f885a233fbed2c46b8d73c615b32dc327be5cae5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/log2_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor log2(const at::Tensor & self); +TORCH_API at::Tensor & log2_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & log2_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & log2_(at::Tensor & self); + +} // namespace cpu +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/quantize_per_channel_ops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/quantize_per_channel_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a68179301e398f253d253978cdc70526177d4919 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/quantize_per_channel_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API quantize_per_channel { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, int64_t, at::ScalarType); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::quantize_per_channel"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype); +}; + +struct TORCH_API quantize_per_channel_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, int64_t, at::ScalarType, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::quantize_per_channel"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/scatter_reduce_cuda_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/scatter_reduce_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..2add648340654ab66e33c3e5775ee3e6a080966b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/scatter_reduce_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor scatter_reduce(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true); +TORCH_API at::Tensor & scatter_reduce_out(at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true); +TORCH_API at::Tensor & scatter_reduce_outf(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self, at::Tensor & out); +TORCH_API at::Tensor & scatter_reduce_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true); + +} // namespace cuda +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/smooth_l1_loss_ops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/smooth_l1_loss_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..d14eafcc11b21108cf1ee88196083944ae80d8a9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/smooth_l1_loss_ops.h @@ -0,0 +1,40 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API smooth_l1_loss_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, int64_t, double, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::smooth_l1_loss"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & out); +}; + +struct TORCH_API smooth_l1_loss { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, int64_t, double); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::smooth_l1_loss"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor"; + static at::Tensor call(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta); +}; + +}} // namespace at::_ops diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/softplus_cpu_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/softplus_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d10ef583e6134628a5b69069c79c9c6a2c4401a5 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/softplus_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor softplus(const at::Tensor & self, const at::Scalar & beta=1, const at::Scalar & threshold=20); +TORCH_API at::Tensor & softplus_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & beta=1, const at::Scalar & threshold=20); +TORCH_API at::Tensor & softplus_outf(const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/special_i0e_cuda_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/special_i0e_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..eed4e1cff11d87ac33d0b6eb36bddfa644c45f31 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/special_i0e_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor special_i0e(const at::Tensor & self); +TORCH_API at::Tensor & special_i0e_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & special_i0e_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/swapdims_compositeimplicitautograd_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/swapdims_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f28095cff8865335712fdbfcdd4c06dfc57e5543 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/swapdims_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor swapdims(const at::Tensor & self, int64_t dim0, int64_t dim1); +TORCH_API at::Tensor & swapdims_(at::Tensor & self, int64_t dim0, int64_t dim1); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/topk.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/topk.h new file mode 100644 index 0000000000000000000000000000000000000000..fa4058ca61c31bcbffc4ec982b94cb1e334fa036 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/topk.h @@ -0,0 +1,92 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) +inline ::std::tuple topk_out(at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk_values::call(self, k, dim, largest, sorted, values, indices); +} +namespace symint { + template >> + ::std::tuple topk_out(at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk_values::call(self, k, dim, largest, sorted, values, indices); + } +} + +// aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) +inline ::std::tuple topk_outf(const at::Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) { + return at::_ops::topk_values::call(self, k, dim, largest, sorted, values, indices); +} +namespace symint { + template >> + ::std::tuple topk_outf(const at::Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) { + return at::_ops::topk_values::call(self, k, dim, largest, sorted, values, indices); + } +} + +// aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) +inline ::std::tuple topk_symint_out(at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk_values::call(self, k, dim, largest, sorted, values, indices); +} +namespace symint { + template >> + ::std::tuple topk_out(at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk_values::call(self, k, dim, largest, sorted, values, indices); + } +} + +// aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) +inline ::std::tuple topk_symint_outf(const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) { + return at::_ops::topk_values::call(self, k, dim, largest, sorted, values, indices); +} +namespace symint { + template >> + ::std::tuple topk_outf(const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) { + return at::_ops::topk_values::call(self, k, dim, largest, sorted, values, indices); + } +} + +// aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) +inline ::std::tuple topk(const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk::call(self, k, dim, largest, sorted); +} +namespace symint { + template >> + ::std::tuple topk(const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk::call(self, k, dim, largest, sorted); + } +} + +// aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) +inline ::std::tuple topk_symint(const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk::call(self, k, dim, largest, sorted); +} +namespace symint { + template >> + ::std::tuple topk(const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk::call(self, k, dim, largest, sorted); + } +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/unfold_copy_compositeexplicitautograd_dispatch.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/unfold_copy_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3d2b2d72d1002bb2d6cbc98e11f3d9e0c87eb29c --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/unfold_copy_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor & unfold_copy_out(at::Tensor & out, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step); +TORCH_API at::Tensor & unfold_copy_outf(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step, at::Tensor & out); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/upsample_nearest2d_ops.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/upsample_nearest2d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..e2b4b627acaf6cc0b1fdf0b7415c2d3d4d91d603 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/upsample_nearest2d_ops.h @@ -0,0 +1,62 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API upsample_nearest2d_vec { + using schema = at::Tensor (const at::Tensor &, at::OptionalSymIntArrayRef, ::std::optional>); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::upsample_nearest2d"; + static constexpr const char* overload_name = "vec"; + static constexpr const char* schema_str = "upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor"; + static at::Tensor call(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); +}; + +struct TORCH_API upsample_nearest2d_out { + using schema = at::Tensor & (const at::Tensor &, c10::SymIntArrayRef, ::std::optional, ::std::optional, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::upsample_nearest2d"; + static constexpr const char* overload_name = "out"; + static constexpr const char* schema_str = "upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); +}; + +struct TORCH_API upsample_nearest2d { + using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef, ::std::optional, ::std::optional); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::upsample_nearest2d"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor"; + static at::Tensor call(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w); +}; + +struct TORCH_API upsample_nearest2d_vec_out { + using schema = at::Tensor & (const at::Tensor &, at::OptionalSymIntArrayRef, ::std::optional>, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::upsample_nearest2d"; + static constexpr const char* overload_name = "vec_out"; + static constexpr const char* schema_str = "upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) 
out) -> Tensor(a!)"; + static at::Tensor & call(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/upsample_nearest3d.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/upsample_nearest3d.h new file mode 100644 index 0000000000000000000000000000000000000000..2927b119aeaff83aa4cd2225b9f7eaf1f7a72000 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/upsample_nearest3d.h @@ -0,0 +1,114 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor +inline at::Tensor upsample_nearest3d(const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest3d_vec::call(input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); +} +namespace symint { + template >> + at::Tensor upsample_nearest3d(const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest3d_vec::call(input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } +} + +// aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? 
scale_factors) -> Tensor +inline at::Tensor upsample_nearest3d_symint(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest3d_vec::call(input, output_size, scale_factors); +} +namespace symint { + template >> + at::Tensor upsample_nearest3d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest3d_vec::call(input, output_size, scale_factors); + } +} + +// aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & upsample_nearest3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_out::call(self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); +} +namespace symint { + template >> + at::Tensor & upsample_nearest3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_out::call(self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } +} + +// aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & upsample_nearest3d_outf(const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest3d_out::call(self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); +} +namespace symint { + template >> + at::Tensor & upsample_nearest3d_outf(const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest3d_out::call(self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } +} + +// aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & upsample_nearest3d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_out::call(self, output_size, scales_d, scales_h, scales_w, out); +} +namespace symint { + template >> + at::Tensor & upsample_nearest3d_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_out::call(self, output_size, scales_d, scales_h, scales_w, out); + } +} + +// aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & upsample_nearest3d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest3d_out::call(self, output_size, scales_d, scales_h, scales_w, out); +} +namespace symint { + template >> + at::Tensor & upsample_nearest3d_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest3d_out::call(self, output_size, scales_d, scales_h, scales_w, out); + } +} + +// aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor +inline at::Tensor upsample_nearest3d(const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d::call(self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w); +} +namespace symint { + template >> + at::Tensor upsample_nearest3d(const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d::call(self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w); + } +} + +// aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? 
scales_w=None) -> Tensor +inline at::Tensor upsample_nearest3d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d::call(self, output_size, scales_d, scales_h, scales_w); +} +namespace symint { + template >> + at::Tensor upsample_nearest3d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d::call(self, output_size, scales_d, scales_h, scales_w); + } +} + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/CachingHostAllocator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/CachingHostAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..ac99e6eef4f44b3aafa1b79300974626aff4e681 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/CachingHostAllocator.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at::xpu { + +C10_DEPRECATED_MESSAGE( + "at::xpu::getCachingHostAllocator() is deprecated. Please use at::getHostAllocator(at::kXPU) instead.") +inline TORCH_XPU_API at::HostAllocator* getCachingHostAllocator() { + return at::getHostAllocator(at::kXPU); +} + +C10_DEPRECATED_MESSAGE( + "at::xpu::CachingHostAllocator_recordEvent(...) is deprecated. Please use at::getHostAllocator(at::kXPU)->record_event(...) instead.") +inline TORCH_XPU_API bool CachingHostAllocator_recordEvent( + void* ptr, + void* ctx, + c10::xpu::XPUStream stream) { + return getHostAllocator(at::kXPU)->record_event(ptr, ctx, stream.unwrap()); +} + +C10_DEPRECATED_MESSAGE( + "at::xpu::CachingHostAllocator_emptyCache() is deprecated. 
Please use at::getHostAllocator(at::kXPU)->empty_cache() instead.") +inline TORCH_XPU_API void CachingHostAllocator_emptyCache() { + getHostAllocator(at::kXPU)->empty_cache(); +} + +C10_DEPRECATED_MESSAGE( + "at::xpu::HostAlloc(...) is deprecated. Please use at::getHostAllocator(at::kXPU)->allocate(...) instead.") +inline TORCH_XPU_API at::DataPtr HostAlloc(size_t size) { + return getHostAllocator(at::kXPU)->allocate(size); +} + +} // namespace at::xpu diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PinnedMemoryAllocator.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PinnedMemoryAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..25f1ccf05a8e1566cd9ee3f40bf995c02877ee1b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PinnedMemoryAllocator.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include + +namespace at::xpu { + +inline TORCH_XPU_API at::HostAllocator* getPinnedMemoryAllocator() { + return at::getHostAllocator(at::kXPU); +} +} // namespace at::xpu diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUContext.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUContext.h new file mode 100644 index 0000000000000000000000000000000000000000..fb8fbe9c0aa4221a3384d6eb7c457d8dad54d0f0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUContext.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include + +namespace at::xpu { + +// XPU is available if we compiled with XPU. 
+inline bool is_available() { + return c10::xpu::device_count() > 0; +} + +TORCH_XPU_API DeviceProp* getCurrentDeviceProperties(); + +TORCH_XPU_API DeviceProp* getDeviceProperties(DeviceIndex device); + +TORCH_XPU_API int32_t getGlobalIdxFromDevice(DeviceIndex device); + +} // namespace at::xpu diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUDevice.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUDevice.h new file mode 100644 index 0000000000000000000000000000000000000000..d4ab7187513c15ad0bbc8cff610ec3831f3a51fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUDevice.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include + +namespace at::xpu { + +inline Device getDeviceFromPtr(void* ptr) { + auto device = c10::xpu::get_device_idx_from_pointer(ptr); + return {c10::DeviceType::XPU, device}; +} + +} // namespace at::xpu diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUEvent.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUEvent.h new file mode 100644 index 0000000000000000000000000000000000000000..ededd6ebf4f15b0c132743b0125cfd0c7c95e421 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUEvent.h @@ -0,0 +1,191 @@ +#pragma once +#include + +#include + +namespace at::xpu { + +/* + * XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are + * constructed lazily when first recorded. It has a device, and this device is + * acquired from the first recording stream. Later streams that record the event + * must match the same device. + * + * Currently, XPUEvent does NOT support to export an inter-process event from + * another process via inter-process comunication(IPC). So it means that + * inter-process communication for event handles between different processes is + * not available. This could impact some applications that rely on cross-process + * synchronization and communication. 
+ */ +struct TORCH_XPU_API XPUEvent { + // Constructors + XPUEvent(bool enable_timing = false) noexcept + : enable_timing_{enable_timing} {} + + ~XPUEvent() { + if (isCreated()) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion( + at::kXPU, reinterpret_cast(event_.get())); + } + } + } + + XPUEvent(const XPUEvent&) = delete; + XPUEvent& operator=(const XPUEvent&) = delete; + + XPUEvent(XPUEvent&& other) = default; + XPUEvent& operator=(XPUEvent&& other) = default; + + operator sycl::event&() const { + return event(); + } + + std::optional device() const { + if (isCreated()) { + return at::Device(at::kXPU, device_index_); + } else { + return std::nullopt; + } + } + + inline bool isCreated() const { + return (event_.get() != nullptr); + } + + DeviceIndex device_index() const { + return device_index_; + } + + sycl::event& event() const { + return *event_; + } + + bool query() const { + using namespace sycl::info; + if (!isCreated()) { + return true; + } + + return event().get_info() == + event_command_status::complete; + } + + void record() { + record(getCurrentXPUStream()); + } + + void recordOnce(const XPUStream& stream) { + if (!isCreated()) { + record(stream); + } + } + + void record(const XPUStream& stream) { + if (!isCreated()) { + device_index_ = stream.device_index(); + assignEvent(stream.queue()); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation( + at::kXPU, reinterpret_cast(event_.get())); + } + } else { + TORCH_CHECK( + device_index_ == stream.device_index(), + "Event device ", + device_index_, + " does not match recording stream's device ", + stream.device_index(), + "."); + reassignEvent(stream.queue()); + } + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record( + at::kXPU, + 
reinterpret_cast(event_.get()), + reinterpret_cast(&stream.queue())); + } + } + + void block(const XPUStream& stream) { + if (isCreated()) { + std::vector event_list{event()}; + // Make this stream wait until event_ is completed. + stream.queue().ext_oneapi_submit_barrier(event_list); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait( + at::kXPU, + reinterpret_cast(event_.get()), + reinterpret_cast(&stream.queue())); + } + } + } + + double elapsed_time(const XPUEvent& other) const { + TORCH_CHECK( + isCreated() && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + TORCH_CHECK( + enable_timing_ && other.enable_timing_, + "Both events must be created with argument 'enable_timing=True'."); + +#if SYCL_COMPILER_VERSION < 20250000 + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer."); +#endif + + using namespace sycl::info::event_profiling; + // Block until both of the recorded events are completed. + uint64_t end_time_ns = other.event().get_profiling_info(); + uint64_t start_time_ns = event().get_profiling_info(); + // Return the eplased time in milliseconds. 
+ return 1e-6 * + (static_cast(end_time_ns) - static_cast(start_time_ns)); + } + + void synchronize() const { + if (isCreated()) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization( + at::kXPU, reinterpret_cast(event_.get())); + } + event().wait_and_throw(); + } + } + + private: + void assignEvent(sycl::queue& queue) { +#if SYCL_COMPILER_VERSION >= 20250000 + if (enable_timing_) { + event_ = std::make_unique( + sycl::ext::oneapi::experimental::submit_profiling_tag(queue)); + } else { + event_ = std::make_unique(queue.ext_oneapi_submit_barrier()); + } +#else + event_ = std::make_unique(queue.ext_oneapi_submit_barrier()); +#endif + } + + void reassignEvent(sycl::queue& queue) { + event_.reset(); + assignEvent(queue); + } + + bool enable_timing_ = false; + DeviceIndex device_index_ = -1; + // Only need to track the last event, as events in an in-order queue are + // executed sequentially. + std::unique_ptr event_; +}; + +} // namespace at::xpu diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUGeneratorImpl.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUGeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..a1f264382a366d70410adc7c98280694b63d6ba5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUGeneratorImpl.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +namespace at { + +struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl { + // Constructors + XPUGeneratorImpl(DeviceIndex device_index = -1); + ~XPUGeneratorImpl() override = default; + + // XPUGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) 
override; + c10::intrusive_ptr get_state() const override; + void set_philox_offset_per_thread(uint64_t offset); + uint64_t philox_offset_per_thread() const; + std::pair philox_engine_inputs(uint64_t increment); + static c10::DeviceType device_type(); + + private: + XPUGeneratorImpl* clone_impl() const override; + uint64_t seed_ = default_rng_seed_val; + uint64_t philox_offset_per_thread_ = 0; +}; + +namespace xpu::detail { + +TORCH_XPU_API const Generator& getDefaultXPUGenerator(DeviceIndex device = -1); + +TORCH_XPU_API Generator createXPUGenerator(DeviceIndex device = -1); + +} // namespace xpu::detail +} // namespace at diff --git a/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/detail/XPUHooks.h b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/detail/XPUHooks.h new file mode 100644 index 0000000000000000000000000000000000000000..1103b5b945662c963f8fa7d373df2ede4c25f74d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/detail/XPUHooks.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +namespace at::xpu::detail { + +// The real implementation of XPUHooksInterface +struct XPUHooks : public at::XPUHooksInterface { + XPUHooks(at::XPUHooksArgs) {} + void init() const override; + bool hasXPU() const override; + std::string showConfig() const override; + int32_t getGlobalIdxFromDevice(const at::Device& device) const override; + const Generator& getDefaultGenerator( + DeviceIndex device_index = -1) const override; + Generator getNewGenerator(DeviceIndex device_index = -1) const override; + Device getDeviceFromPtr(void* data) const override; + c10::DeviceIndex getNumGPUs() const override; + DeviceIndex current_device() const override; + void deviceSynchronize(DeviceIndex device_index) const override; + Allocator* getPinnedMemoryAllocator() const override; + + bool isBuilt() const override { + return true; + } + bool isAvailable() const override; + bool isPinnedPtr(const void* data) const override; + bool 
hasPrimaryContext(DeviceIndex device_index) const override; + DeviceIndex deviceCount() const override; + DeviceIndex getCurrentDevice() const override; +}; + +} // namespace at::xpu::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64assembler.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64assembler.h new file mode 100644 index 0000000000000000000000000000000000000000..319321576d85dd8e95f72b98a93fc1bffa225ad3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64assembler.h @@ -0,0 +1,61 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64ASSEMBLER_H_INCLUDED +#define ASMJIT_ARM_A64ASSEMBLER_H_INCLUDED + +#include "../core/assembler.h" +#include "../arm/a64emitter.h" +#include "../arm/a64operand.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! AArch64 assembler implementation. +class ASMJIT_VIRTAPI Assembler + : public BaseAssembler, + public EmitterExplicitT { + +public: + typedef BaseAssembler Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API Assembler(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Assembler() noexcept override; + + //! \} + + //! \name Emit + //! \{ + + ASMJIT_API Error _emit(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_* opExt) override; + + //! \} + + //! \name Align + //! \{ + + ASMJIT_API Error align(AlignMode alignMode, uint32_t alignment) override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! 
\} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_A64ASSEMBLER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64builder.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64builder.h new file mode 100644 index 0000000000000000000000000000000000000000..cab1083172b3d8e7ab23eb7fa992ece23c2f2539 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64builder.h @@ -0,0 +1,57 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64BUILDER_H_INCLUDED +#define ASMJIT_ARM_A64BUILDER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_BUILDER + +#include "../core/builder.h" +#include "../arm/a64emitter.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! AArch64 builder implementation. +class ASMJIT_VIRTAPI Builder + : public BaseBuilder, + public EmitterExplicitT { +public: + ASMJIT_NONCOPYABLE(Builder) + typedef BaseBuilder Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Builder(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Builder() noexcept override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} + + //! \name Finalize + //! \{ + + ASMJIT_API Error finalize() override; + + //! \} +}; + +//! 
\} + +ASMJIT_END_SUB_NAMESPACE + +#endif // !ASMJIT_NO_BUILDER +#endif // ASMJIT_ARM_A64BUILDER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64compiler.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..64f82f62538ec1538e57d096a49065ba5424b50a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64compiler.h @@ -0,0 +1,254 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64COMPILER_H_INCLUDED +#define ASMJIT_ARM_A64COMPILER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_COMPILER + +#include "../core/compiler.h" +#include "../core/type.h" +#include "../arm/a64emitter.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! AArch64 compiler implementation. +class ASMJIT_VIRTAPI Compiler + : public BaseCompiler, + public EmitterExplicitT { +public: + ASMJIT_NONCOPYABLE(Compiler) + typedef BaseCompiler Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Compiler(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Compiler() noexcept override; + + //! \} + + //! \name Virtual Registers + //! \{ + + //! \cond INTERNAL + template + ASMJIT_INLINE_NODEBUG RegT _newRegInternal(const Type& type) { + RegT reg(Globals::NoInit); + _newReg(®, type, nullptr); + return reg; + } + + template + ASMJIT_INLINE_NODEBUG RegT _newRegInternal(const Type& type, const char* s) { +#ifndef ASMJIT_NO_LOGGING + RegT reg(Globals::NoInit); + _newReg(®, type, s); + return reg; +#else + DebugUtils::unused(s); + return _newRegInternal(type); +#endif + } + + template + ASMJIT_INLINE_NODEBUG RegT _newRegInternal(const Type& type, const char* s, Args&&... 
args) { +#ifndef ASMJIT_NO_LOGGING + RegT reg(Globals::NoInit); + _newRegFmt(®, type, s, std::forward(args)...); + return reg; +#else + DebugUtils::unused(s, std::forward(args)...); + return _newRegInternal(type); +#endif + } + //! \endcond + + template + ASMJIT_INLINE_NODEBUG RegT newSimilarReg(const RegT& ref, Args&&... args) { + return _newRegInternal(ref, std::forward(args)...); + } + + template + ASMJIT_INLINE_NODEBUG Reg newReg(TypeId typeId, Args&&... args) { return _newRegInternal(typeId, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newGp(TypeId typeId, Args&&... args) { return _newRegInternal(typeId, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Vec newVec(TypeId typeId, Args&&... args) { return _newRegInternal(typeId, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newInt32(Args&&... args) { return _newRegInternal(TypeId::kInt32, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newUInt32(Args&&... args) { return _newRegInternal(TypeId::kUInt32, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newInt64(Args&&... args) { return _newRegInternal(TypeId::kInt64, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newUInt64(Args&&... args) { return _newRegInternal(TypeId::kUInt64, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newIntPtr(Args&&... args) { return _newRegInternal(TypeId::kIntPtr, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newUIntPtr(Args&&... args) { return _newRegInternal(TypeId::kUIntPtr, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newGpw(Args&&... args) { return _newRegInternal(TypeId::kUInt32, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newGpx(Args&&... args) { return _newRegInternal(TypeId::kUInt64, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newGpz(Args&&... 
args) { return _newRegInternal(TypeId::kUIntPtr, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Vec newVecS(Args&&... args) { return _newRegInternal(TypeId::kFloat32, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Vec newVecD(Args&&... args) { return _newRegInternal(TypeId::kFloat64, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Vec newVecQ(Args&&... args) { return _newRegInternal(TypeId::kUInt8x16, std::forward(args)...); } + + //! \} + + //! \name Stack + //! \{ + + //! Creates a new memory chunk allocated on the current function's stack. + ASMJIT_INLINE_NODEBUG Mem newStack(uint32_t size, uint32_t alignment, const char* name = nullptr) { + Mem m(Globals::NoInit); + _newStack(&m, size, alignment, name); + return m; + } + + //! \} + + //! \name Constants + //! \{ + + //! Put data to a constant-pool and get a memory reference to it. + ASMJIT_INLINE_NODEBUG Mem newConst(ConstPoolScope scope, const void* data, size_t size) { + Mem m(Globals::NoInit); + _newConst(&m, scope, data, size); + return m; + } + + //! Put a BYTE `val` to a constant-pool (8 bits). + ASMJIT_INLINE_NODEBUG Mem newByteConst(ConstPoolScope scope, uint8_t val) noexcept { return newConst(scope, &val, 1); } + //! Put a HWORD `val` to a constant-pool (16 bits). + ASMJIT_INLINE_NODEBUG Mem newHWordConst(ConstPoolScope scope, uint16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a WORD `val` to a constant-pool (32 bits). + ASMJIT_INLINE_NODEBUG Mem newWordConst(ConstPoolScope scope, uint32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a DWORD `val` to a constant-pool (64 bits). + ASMJIT_INLINE_NODEBUG Mem newDWordConst(ConstPoolScope scope, uint64_t val) noexcept { return newConst(scope, &val, 8); } + + //! Put a WORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt16Const(ConstPoolScope scope, int16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a WORD `val` to a constant-pool. 
+ ASMJIT_INLINE_NODEBUG Mem newUInt16Const(ConstPoolScope scope, uint16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a DWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt32Const(ConstPoolScope scope, int32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a DWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt32Const(ConstPoolScope scope, uint32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt64Const(ConstPoolScope scope, int64_t val) noexcept { return newConst(scope, &val, 8); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt64Const(ConstPoolScope scope, uint64_t val) noexcept { return newConst(scope, &val, 8); } + + //! Put a SP-FP `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newFloatConst(ConstPoolScope scope, float val) noexcept { return newConst(scope, &val, 4); } + //! Put a DP-FP `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newDoubleConst(ConstPoolScope scope, double val) noexcept { return newConst(scope, &val, 8); } + + //! \} + + //! \name Instruction Options + //! \{ + + //! Force the compiler to not follow the conditional or unconditional jump. + ASMJIT_INLINE_NODEBUG Compiler& unfollow() noexcept { _instOptions |= InstOptions::kUnfollow; return *this; } + + //! \} + + //! \name Compiler specific + //! \{ + + //! Special pseudo-instruction that can be used to load a memory address into `o0` GP register. + //! + //! \note At the moment this instruction is only useful to load a stack allocated address into a GP register + //! for further use. It makes very little sense to use it for anything else. The semantics of this instruction + //! is the same as X86 `LEA` (load effective address) instruction. + ASMJIT_INLINE_NODEBUG Error loadAddressOf(const Gp& o0, const Mem& o1) { return _emitter()->_emitI(Inst::kIdAdr, o0, o1); } + + //! \} + + //! 
\name Function Call & Ret Intrinsics + //! \{ + + //! Invoke a function call without `target` type enforcement. + ASMJIT_INLINE_NODEBUG Error invoke_(InvokeNode** out, const Operand_& target, const FuncSignature& signature) { + return addInvokeNode(out, Inst::kIdBlr, target, signature); + } + + //! Invoke a function call of the given `target` and `signature` and store the added node to `out`. + //! + //! Creates a new \ref InvokeNode, initializes all the necessary members to match the given function `signature`, + //! adds the node to the compiler, and stores its pointer to `out`. The operation is atomic, if anything fails + //! nullptr is stored in `out` and error code is returned. + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Gp& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Mem& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Label& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Imm& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, uint64_t target, const FuncSignature& signature) { return invoke_(out, Imm(int64_t(target)), signature); } + + //! Return. + ASMJIT_INLINE_NODEBUG Error ret() { return addRet(Operand(), Operand()); } + //! \overload + ASMJIT_INLINE_NODEBUG Error ret(const BaseReg& o0) { return addRet(o0, Operand()); } + //! \overload + ASMJIT_INLINE_NODEBUG Error ret(const BaseReg& o0, const BaseReg& o1) { return addRet(o0, o1); } + + //! \} + + //! \name Jump Tables Support + //! \{ + + using EmitterExplicitT::br; + + //! Adds a jump to the given `target` with the provided jump `annotation`. 
+ ASMJIT_INLINE_NODEBUG Error br(const BaseReg& target, JumpAnnotation* annotation) { return emitAnnotatedJump(Inst::kIdBr, target, annotation); } + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} + + //! \name Finalize + //! \{ + + ASMJIT_API Error finalize() override; + + //! \} +}; + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // !ASMJIT_NO_COMPILER +#endif // ASMJIT_ARM_A64COMPILER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64emitter.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..43484344526a8169c12bdc9e9c494d3a57a4b5a5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64emitter.h @@ -0,0 +1,1232 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64EMITTER_H_INCLUDED +#define ASMJIT_ARM_A64EMITTER_H_INCLUDED + +#include "../core/emitter.h" +#include "../core/support.h" +#include "../arm/a64instdb.h" +#include "../arm/a64operand.h" + +// MSVC targeting AArch64 defines a lot of macros without underscores clashing +// with AArch64 instruction names. We have to workaround until it's fixed in SDK. 
+#if defined(_MSC_VER) && defined(mvn) + #define ASMJIT_RESTORE_MSVC_AARCH64_MACROS + #pragma push_macro("mvn") + #undef mvn +#endif + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +#define ASMJIT_INST_0x(NAME, ID) \ + inline Error NAME() { return _emitter()->_emitI(Inst::kId##ID); } + +#define ASMJIT_INST_1x(NAME, ID, T0) \ + inline Error NAME(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID, o0); } + +#define ASMJIT_INST_2x(NAME, ID, T0, T1) \ + inline Error NAME(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID, o0, o1); } + +#define ASMJIT_INST_3x(NAME, ID, T0, T1, T2) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2); } + +#define ASMJIT_INST_4x(NAME, ID, T0, T1, T2, T3) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3); } + +#define ASMJIT_INST_5x(NAME, ID, T0, T1, T2, T3, T4) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3, const T4& o4) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3, o4); } + +#define ASMJIT_INST_6x(NAME, ID, T0, T1, T2, T3, T4, T5) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3, const T4& o4, const T5& o5) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3, o4, o5); } + +#define ASMJIT_INST_1cc(NAME, ID, T0) \ + inline Error NAME(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID, o0); } \ + \ + inline Error NAME(CondCode cc, const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, cc), o0); } \ + \ + inline Error NAME##_eq(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kEQ), o0); } \ + inline Error NAME##_ne(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kNE), o0); } \ + inline Error NAME##_cs(const T0& o0) { return 
_emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kCS), o0); } \ + inline Error NAME##_hs(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kHS), o0); } \ + inline Error NAME##_cc(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kCC), o0); } \ + inline Error NAME##_lo(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kLO), o0); } \ + inline Error NAME##_mi(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kMI), o0); } \ + inline Error NAME##_pl(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kPL), o0); } \ + inline Error NAME##_vs(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kVS), o0); } \ + inline Error NAME##_vc(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kVC), o0); } \ + inline Error NAME##_hi(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kHI), o0); } \ + inline Error NAME##_ls(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kLS), o0); } \ + inline Error NAME##_ge(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kGE), o0); } \ + inline Error NAME##_lt(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kLT), o0); } \ + inline Error NAME##_gt(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kGT), o0); } \ + inline Error NAME##_le(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kLE), o0); } \ + inline Error NAME##_al(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kAL), o0); } + +//! \addtogroup asmjit_a64 +//! 
\{ + +//! ARM emitter. +//! +//! NOTE: This class cannot be instantiated, you can only cast to it and use it as emitter that emits to either +//! \ref Assembler, \ref Builder, or \ref Compiler (use with caution with \ref Compiler as it expects virtual +//! registers to be used). +template +struct EmitterExplicitT { + //! \cond + + // These two are unfortunately reported by the sanitizer. We know what we do, however, the sanitizer doesn't. + // I have tried to use reinterpret_cast instead, but that would generate bad code when compiled by MSC. + ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF ASMJIT_INLINE_NODEBUG This* _emitter() noexcept { return static_cast(this); } + ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF ASMJIT_INLINE_NODEBUG const This* _emitter() const noexcept { return static_cast(this); } + + //! \endcond + + //! \name General Purpose Instructions + //! \{ + + ASMJIT_INST_3x(adc, Adc, Gp, Gp, Gp) + ASMJIT_INST_3x(adcs, Adcs, Gp, Gp, Gp) + + ASMJIT_INST_3x(add, Add, Gp, Gp, Gp) + ASMJIT_INST_4x(add, Add, Gp, Gp, Gp, Imm) + ASMJIT_INST_3x(add, Add, Gp, Gp, Imm) + ASMJIT_INST_4x(add, Add, Gp, Gp, Imm, Imm) + ASMJIT_INST_3x(adds, Adds, Gp, Gp, Gp) + ASMJIT_INST_3x(adds, Adds, Gp, Gp, Imm) + ASMJIT_INST_4x(adds, Adds, Gp, Gp, Gp, Imm) + ASMJIT_INST_4x(adds, Adds, Gp, Gp, Imm, Imm) + + ASMJIT_INST_2x(adr, Adr, Gp, Imm) + ASMJIT_INST_2x(adr, Adr, Gp, Label) + ASMJIT_INST_2x(adrp, Adrp, Gp, Imm) + ASMJIT_INST_2x(adrp, Adrp, Gp, Label) + + ASMJIT_INST_3x(and_, And, Gp, Gp, Imm) + ASMJIT_INST_3x(and_, And, Gp, Gp, Gp) + ASMJIT_INST_4x(and_, And, Gp, Gp, Gp, Imm) + ASMJIT_INST_3x(ands, Ands, Gp, Gp, Imm) + ASMJIT_INST_3x(ands, Ands, Gp, Gp, Gp) + ASMJIT_INST_4x(ands, Ands, Gp, Gp, Gp, Imm) + + ASMJIT_INST_3x(asr, Asr, Gp, Gp, Imm) + ASMJIT_INST_3x(asr, Asr, Gp, Gp, Gp) + ASMJIT_INST_3x(asrv, Asrv, Gp, Gp, Gp) + + ASMJIT_INST_2x(at, At, Imm, Gp) + + ASMJIT_INST_3x(bfc, Bfc, Gp, Imm, Imm) + ASMJIT_INST_4x(bfi, Bfi, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(bfm, Bfm, Gp, Gp, Imm, Imm) + 
ASMJIT_INST_4x(bfxil, Bfxil, Gp, Gp, Imm, Imm) + + ASMJIT_INST_3x(bic, Bic, Gp, Gp, Imm); + ASMJIT_INST_3x(bic, Bic, Gp, Gp, Gp); + ASMJIT_INST_4x(bic, Bic, Gp, Gp, Gp, Imm); + ASMJIT_INST_3x(bics, Bics, Gp, Gp, Imm); + ASMJIT_INST_3x(bics, Bics, Gp, Gp, Gp); + ASMJIT_INST_4x(bics, Bics, Gp, Gp, Gp, Imm); + + ASMJIT_INST_1x(brk, Brk, Imm) + + ASMJIT_INST_4x(ccmn, Ccmn, Gp, Gp, Imm, Imm); + ASMJIT_INST_4x(ccmn, Ccmn, Gp, Imm, Imm, Imm); + ASMJIT_INST_4x(ccmp, Ccmp, Gp, Gp, Imm, Imm); + ASMJIT_INST_4x(ccmp, Ccmp, Gp, Imm, Imm, Imm); + + ASMJIT_INST_3x(cinc, Cinc, Gp, Gp, Imm); + ASMJIT_INST_3x(cinv, Cinv, Gp, Gp, Imm); + + ASMJIT_INST_1x(clrex, Clrex, Imm) + + ASMJIT_INST_2x(cls, Cls, Gp, Gp) + ASMJIT_INST_2x(clz, Clz, Gp, Gp) + + ASMJIT_INST_2x(cmn, Cmn, Gp, Gp) + ASMJIT_INST_3x(cmn, Cmn, Gp, Gp, Imm) + ASMJIT_INST_2x(cmn, Cmn, Gp, Imm) + ASMJIT_INST_3x(cmn, Cmn, Gp, Imm, Imm) + ASMJIT_INST_2x(cmp, Cmp, Gp, Gp) + ASMJIT_INST_3x(cmp, Cmp, Gp, Gp, Imm) + ASMJIT_INST_2x(cmp, Cmp, Gp, Imm) + ASMJIT_INST_3x(cmp, Cmp, Gp, Imm, Imm) + + ASMJIT_INST_3x(cneg, Cneg, Gp, Gp, Imm); + + ASMJIT_INST_4x(csel, Csel, Gp, Gp, Gp, Imm); + ASMJIT_INST_2x(cset, Cset, Gp, Imm); + ASMJIT_INST_2x(csetm, Csetm, Gp, Imm); + + ASMJIT_INST_4x(csinc, Csinc, Gp, Gp, Gp, Imm); + ASMJIT_INST_4x(csinv, Csinv, Gp, Gp, Gp, Imm); + ASMJIT_INST_4x(csneg, Csneg, Gp, Gp, Gp, Imm); + + ASMJIT_INST_2x(dc, Dc, Imm, Gp) + ASMJIT_INST_1x(dmb, Dmb, Imm) + ASMJIT_INST_1x(dsb, Dsb, Imm) + ASMJIT_INST_0x(drps, Drps) + + ASMJIT_INST_3x(eon, Eon, Gp, Gp, Gp) + ASMJIT_INST_4x(eon, Eon, Gp, Gp, Gp, Imm) + + ASMJIT_INST_3x(eor, Eor, Gp, Gp, Imm) + ASMJIT_INST_3x(eor, Eor, Gp, Gp, Gp) + ASMJIT_INST_4x(eor, Eor, Gp, Gp, Gp, Imm) + + ASMJIT_INST_0x(eret, Eret) + ASMJIT_INST_0x(esb, Esb) + + ASMJIT_INST_4x(extr, Extr, Gp, Gp, Gp, Imm) + + ASMJIT_INST_1x(hlt, Hlt, Imm) + ASMJIT_INST_1x(hvc, Hvc, Imm) + ASMJIT_INST_2x(ic, Ic, Imm, Gp) + ASMJIT_INST_1x(isb, Isb, Imm) + + ASMJIT_INST_3x(lsl, Lsl, Gp, Gp, Imm) + 
ASMJIT_INST_3x(lsl, Lsl, Gp, Gp, Gp) + ASMJIT_INST_3x(lslv, Lslv, Gp, Gp, Gp) + + ASMJIT_INST_3x(lsr, Lsr, Gp, Gp, Imm) + ASMJIT_INST_3x(lsr, Lsr, Gp, Gp, Gp) + ASMJIT_INST_3x(lsrv, Lsrv, Gp, Gp, Gp) + + ASMJIT_INST_4x(madd, Madd, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(mneg, Mneg, Gp, Gp, Gp) + + ASMJIT_INST_2x(mov, Mov, Gp, Gp) + ASMJIT_INST_2x(mov, Mov, Gp, Imm) + ASMJIT_INST_2x(movk, Movk, Gp, Imm) + ASMJIT_INST_3x(movk, Movk, Gp, Imm, Imm) + ASMJIT_INST_2x(movn, Movn, Gp, Imm) + ASMJIT_INST_3x(movn, Movn, Gp, Imm, Imm) + ASMJIT_INST_2x(movz, Movz, Gp, Imm) + ASMJIT_INST_3x(movz, Movz, Gp, Imm, Imm) + + ASMJIT_INST_2x(mrs, Mrs, Gp, Imm) + ASMJIT_INST_2x(msr, Msr, Imm, Gp) + ASMJIT_INST_2x(msr, Msr, Imm, Imm) + + ASMJIT_INST_4x(msub, Msub, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(mul, Mul, Gp, Gp, Gp) + + ASMJIT_INST_2x(mvn, Mvn, Gp, Gp) + ASMJIT_INST_3x(mvn, Mvn, Gp, Gp, Imm) + + ASMJIT_INST_2x(neg, Neg, Gp, Gp) + ASMJIT_INST_3x(neg, Neg, Gp, Gp, Imm) + ASMJIT_INST_2x(negs, Negs, Gp, Gp) + ASMJIT_INST_3x(negs, Negs, Gp, Gp, Imm) + + ASMJIT_INST_2x(ngc, Ngc, Gp, Gp) + ASMJIT_INST_2x(ngcs, Ngcs, Gp, Gp) + + ASMJIT_INST_3x(orn, Orn, Gp, Gp, Gp) + ASMJIT_INST_4x(orn, Orn, Gp, Gp, Gp, Imm) + + ASMJIT_INST_3x(orr, Orr, Gp, Gp, Imm) + ASMJIT_INST_3x(orr, Orr, Gp, Gp, Gp) + ASMJIT_INST_4x(orr, Orr, Gp, Gp, Gp, Imm) + + ASMJIT_INST_2x(rbit, Rbit, Gp, Gp) + ASMJIT_INST_1x(ret, Ret, Gp) + + ASMJIT_INST_2x(rev, Rev, Gp, Gp) + ASMJIT_INST_2x(rev16, Rev16, Gp, Gp) + ASMJIT_INST_2x(rev32, Rev32, Gp, Gp) + ASMJIT_INST_2x(rev64, Rev64, Gp, Gp) + + ASMJIT_INST_3x(ror, Ror, Gp, Gp, Imm) + ASMJIT_INST_3x(ror, Ror, Gp, Gp, Gp) + ASMJIT_INST_3x(rorv, Rorv, Gp, Gp, Gp) + + ASMJIT_INST_3x(sbc, Sbc, Gp, Gp, Gp) + ASMJIT_INST_3x(sbcs, Sbcs, Gp, Gp, Gp) + + ASMJIT_INST_4x(sbfiz, Sbfiz, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(sbfm, Sbfm, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(sbfx, Sbfx, Gp, Gp, Imm, Imm) + + ASMJIT_INST_3x(sdiv, Sdiv, Gp, Gp, Gp) + + ASMJIT_INST_4x(smaddl, Smaddl, Gp, Gp, Gp, Gp) + 
ASMJIT_INST_1x(smc, Smc, Imm) + ASMJIT_INST_3x(smnegl, Smnegl, Gp, Gp, Gp) + ASMJIT_INST_4x(smsubl, Smsubl, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(smulh, Smulh, Gp, Gp, Gp) + ASMJIT_INST_3x(smull, Smull, Gp, Gp, Gp) + + ASMJIT_INST_3x(sub, Sub, Gp, Gp, Gp) + ASMJIT_INST_4x(sub, Sub, Gp, Gp, Gp, Imm) + ASMJIT_INST_3x(sub, Sub, Gp, Gp, Imm) + ASMJIT_INST_4x(sub, Sub, Gp, Gp, Imm, Imm) + ASMJIT_INST_3x(subs, Subs, Gp, Gp, Gp) + ASMJIT_INST_4x(subs, Subs, Gp, Gp, Gp, Imm) + ASMJIT_INST_3x(subs, Subs, Gp, Gp, Imm) + ASMJIT_INST_4x(subs, Subs, Gp, Gp, Imm, Imm) + + ASMJIT_INST_1x(svc, Svc, Imm) + + ASMJIT_INST_2x(sxtb, Sxtb, Gp, Gp) + ASMJIT_INST_2x(sxth, Sxth, Gp, Gp) + ASMJIT_INST_2x(sxtw, Sxtw, Gp, Gp) + + ASMJIT_INST_4x(sys, Sys, Imm, Imm, Imm, Imm) + ASMJIT_INST_5x(sys, Sys, Imm, Imm, Imm, Imm, Gp) + + ASMJIT_INST_2x(tlbi, Tlbi, Imm, Gp) + ASMJIT_INST_2x(tst, Tst, Gp, Imm) + ASMJIT_INST_2x(tst, Tst, Gp, Gp) + ASMJIT_INST_3x(tst, Tst, Gp, Gp, Imm) + + ASMJIT_INST_3x(udiv, Udiv, Gp, Gp, Gp) + + ASMJIT_INST_4x(ubfiz, Ubfiz, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(ubfm, Ubfm, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(ubfx, Ubfx, Gp, Gp, Imm, Imm) + + ASMJIT_INST_4x(umaddl, Umaddl, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(umnegl, Umnegl, Gp, Gp, Gp) + ASMJIT_INST_4x(umsubl, Umsubl, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(umull, Umull, Gp, Gp, Gp) + ASMJIT_INST_3x(umulh, Umulh, Gp, Gp, Gp) + + ASMJIT_INST_2x(uxtb, Uxtb, Gp, Gp) + ASMJIT_INST_2x(uxth, Uxth, Gp, Gp) + + ASMJIT_INST_0x(csdb, Csdb) + ASMJIT_INST_1x(dcps1, Dcps1, Imm) + ASMJIT_INST_1x(dcps2, Dcps2, Imm) + ASMJIT_INST_1x(dcps3, Dcps3, Imm) + ASMJIT_INST_0x(dgh, Dgh) + ASMJIT_INST_0x(pssbb, Pssbb) + ASMJIT_INST_0x(ssbb, Ssbb) + ASMJIT_INST_1x(udf, Udf, Imm) + ASMJIT_INST_1x(setf8, Setf8, Gp) + ASMJIT_INST_1x(setf16, Setf16, Gp) + + //! \} + + //! \name ARMv8.4 Instructions + //! \{ + + ASMJIT_INST_0x(cfinv, Cfinv) + + //! \} + + //! \name ARMv8.5 Instructions + //! \{ + + ASMJIT_INST_0x(axflag, Axflag) + ASMJIT_INST_0x(xaflag, Xaflag) + + //! 
\} + + //! \name Branch Instructions + //! \{ + + ASMJIT_INST_1cc(b, B, Imm) + ASMJIT_INST_1cc(b, B, Label) + ASMJIT_INST_1x(bl, Bl, Imm) + ASMJIT_INST_1x(bl, Bl, Label) + ASMJIT_INST_1x(blr, Blr, Gp) + ASMJIT_INST_1x(br, Br, Gp) + ASMJIT_INST_2x(cbz, Cbz, Gp, Imm) + ASMJIT_INST_2x(cbz, Cbz, Gp, Label) + ASMJIT_INST_2x(cbnz, Cbnz, Gp, Imm) + ASMJIT_INST_2x(cbnz, Cbnz, Gp, Label) + ASMJIT_INST_3x(tbnz, Tbnz, Gp, Imm, Imm) + ASMJIT_INST_3x(tbnz, Tbnz, Gp, Imm, Label) + ASMJIT_INST_3x(tbz, Tbz, Gp, Imm, Imm) + ASMJIT_INST_3x(tbz, Tbz, Gp, Imm, Label) + + //! \} + + //! \name Load & Store Instructions + //! \{ + + ASMJIT_INST_3x(cas, Cas, Gp, Gp, Mem) + ASMJIT_INST_3x(casa, Casa, Gp, Gp, Mem) + ASMJIT_INST_3x(casab, Casab, Gp, Gp, Mem) + ASMJIT_INST_3x(casah, Casah, Gp, Gp, Mem) + ASMJIT_INST_3x(casal, Casal, Gp, Gp, Mem) + ASMJIT_INST_3x(casalb, Casalb, Gp, Gp, Mem) + ASMJIT_INST_3x(casalh, Casalh, Gp, Gp, Mem) + ASMJIT_INST_3x(casb, Casb, Gp, Gp, Mem) + ASMJIT_INST_3x(cash, Cash, Gp, Gp, Mem) + ASMJIT_INST_3x(casl, Casl, Gp, Gp, Mem) + ASMJIT_INST_3x(caslb, Caslb, Gp, Gp, Mem) + ASMJIT_INST_3x(caslh, Caslh, Gp, Gp, Mem) + + ASMJIT_INST_5x(casp, Casp, Gp, Gp, Gp, Gp, Mem) + ASMJIT_INST_5x(caspa, Caspa, Gp, Gp, Gp, Gp, Mem) + ASMJIT_INST_5x(caspal, Caspal, Gp, Gp, Gp, Gp, Mem) + ASMJIT_INST_5x(caspl, Caspl, Gp, Gp, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldadd, Ldadd, Gp, Gp, Mem) + ASMJIT_INST_3x(ldadda, Ldadda, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddab, Ldaddab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddah, Ldaddah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddal, Ldaddal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddalb, Ldaddalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddalh, Ldaddalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddb, Ldaddb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddh, Ldaddh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddl, Ldaddl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddlb, Ldaddlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddlh, Ldaddlh, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldar, Ldar, Gp, Mem) + ASMJIT_INST_2x(ldarb, Ldarb, Gp, Mem) + 
ASMJIT_INST_2x(ldarh, Ldarh, Gp, Mem) + + ASMJIT_INST_2x(ldaxr, Ldaxr, Gp, Mem) + ASMJIT_INST_2x(ldaxrb, Ldaxrb, Gp, Mem) + ASMJIT_INST_2x(ldaxrh, Ldaxrh, Gp, Mem) + + ASMJIT_INST_3x(ldclr, Ldclr, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclra, Ldclra, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrab, Ldclrab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrah, Ldclrah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclral, Ldclral, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclralb, Ldclralb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclralh, Ldclralh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrb, Ldclrb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrh, Ldclrh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrl, Ldclrl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrlb, Ldclrlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrlh, Ldclrlh, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldeor, Ldeor, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeora, Ldeora, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorab, Ldeorab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorah, Ldeorah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeoral, Ldeoral, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeoralb, Ldeoralb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeoralh, Ldeoralh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorb, Ldeorb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorh, Ldeorh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorl, Ldeorl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorlb, Ldeorlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorlh, Ldeorlh, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldlar, Ldlar, Gp, Mem) + ASMJIT_INST_2x(ldlarb, Ldlarb, Gp, Mem) + ASMJIT_INST_2x(ldlarh, Ldlarh, Gp, Mem) + + ASMJIT_INST_3x(ldnp, Ldnp, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldp, Ldp, Gp, Gp, Mem) + ASMJIT_INST_3x(ldpsw, Ldpsw, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldr, Ldr, Gp, Mem) + ASMJIT_INST_2x(ldrb, Ldrb, Gp, Mem) + ASMJIT_INST_2x(ldrh, Ldrh, Gp, Mem) + ASMJIT_INST_2x(ldrsb, Ldrsb, Gp, Mem) + ASMJIT_INST_2x(ldrsh, Ldrsh, Gp, Mem) + ASMJIT_INST_2x(ldrsw, Ldrsw, Gp, Mem) + + ASMJIT_INST_3x(ldset, Ldset, Gp, Gp, Mem) + ASMJIT_INST_3x(ldseta, Ldseta, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetab, Ldsetab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetah, Ldsetah, Gp, Gp, Mem) + 
ASMJIT_INST_3x(ldsetal, Ldsetal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetalb, Ldsetalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetalh, Ldsetalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetb, Ldsetb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldseth, Ldseth, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetl, Ldsetl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetlb, Ldsetlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetlh, Ldsetlh, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldsmax, Ldsmax, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxa, Ldsmaxa, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxab, Ldsmaxab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxah, Ldsmaxah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxal, Ldsmaxal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxalb, Ldsmaxalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxalh, Ldsmaxalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxb, Ldsmaxb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxh, Ldsmaxh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxl, Ldsmaxl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxlb, Ldsmaxlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxlh, Ldsmaxlh, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldsmin, Ldsmin, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmina, Ldsmina, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminab, Ldsminab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminah, Ldsminah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminal, Ldsminal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminalb, Ldsminalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminalh, Ldsminalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminb, Ldsminb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminh, Ldsminh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminl, Ldsminl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminlb, Ldsminlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminlh, Ldsminlh, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldtr, Ldtr, Gp, Mem) + ASMJIT_INST_2x(ldtrb, Ldtrb, Gp, Mem) + ASMJIT_INST_2x(ldtrh, Ldtrh, Gp, Mem) + ASMJIT_INST_2x(ldtrsb, Ldtrsb, Gp, Mem) + ASMJIT_INST_2x(ldtrsh, Ldtrsh, Gp, Mem) + ASMJIT_INST_2x(ldtrsw, Ldtrsw, Gp, Mem) + + ASMJIT_INST_3x(ldumax, Ldumax, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxa, Ldumaxa, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxab, Ldumaxab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxah, Ldumaxah, 
Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxal, Ldumaxal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxalb, Ldumaxalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxalh, Ldumaxalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxb, Ldumaxb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxh, Ldumaxh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxl, Ldumaxl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxlb, Ldumaxlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxlh, Ldumaxlh, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldumin, Ldumin, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumina, Ldumina, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminab, Lduminab, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminah, Lduminah, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminal, Lduminal, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminalb, Lduminalb, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminalh, Lduminalh, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminb, Lduminb, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminh, Lduminh, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminl, Lduminl, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminlb, Lduminlb, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminlh, Lduminlh, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldur, Ldur, Gp, Mem) + ASMJIT_INST_2x(ldurb, Ldurb, Gp, Mem) + ASMJIT_INST_2x(ldurh, Ldurh, Gp, Mem) + ASMJIT_INST_2x(ldursb, Ldursb, Gp, Mem) + ASMJIT_INST_2x(ldursh, Ldursh, Gp, Mem) + ASMJIT_INST_2x(ldursw, Ldursw, Gp, Mem) + + ASMJIT_INST_3x(ldxp, Ldxp, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaxp, Ldaxp, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldxr, Ldxr, Gp, Mem) + ASMJIT_INST_2x(ldxrb, Ldxrb, Gp, Mem) + ASMJIT_INST_2x(ldxrh, Ldxrh, Gp, Mem) + + ASMJIT_INST_2x(prfm, Prfm, Imm, Mem) + + ASMJIT_INST_2x(stadd, Stadd, Gp, Mem) + ASMJIT_INST_2x(staddb, Staddb, Gp, Mem) + ASMJIT_INST_2x(staddh, Staddh, Gp, Mem) + ASMJIT_INST_2x(staddl, Staddl, Gp, Mem) + ASMJIT_INST_2x(staddlb, Staddlb, Gp, Mem) + ASMJIT_INST_2x(staddlh, Staddlh, Gp, Mem) + + ASMJIT_INST_2x(stclr, Stclr, Gp, Mem) + ASMJIT_INST_2x(stclrb, Stclrb, Gp, Mem) + ASMJIT_INST_2x(stclrh, Stclrh, Gp, Mem) + ASMJIT_INST_2x(stclrl, Stclrl, Gp, Mem) + ASMJIT_INST_2x(stclrlb, Stclrlb, Gp, Mem) + ASMJIT_INST_2x(stclrlh, 
Stclrlh, Gp, Mem) + + ASMJIT_INST_2x(steor, Steor, Gp, Mem) + ASMJIT_INST_2x(steorb, Steorb, Gp, Mem) + ASMJIT_INST_2x(steorh, Steorh, Gp, Mem) + ASMJIT_INST_2x(steorl, Steorl, Gp, Mem) + ASMJIT_INST_2x(steorlb, Steorlb, Gp, Mem) + ASMJIT_INST_2x(steorlh, Steorlh, Gp, Mem) + + ASMJIT_INST_2x(stllr, Stllr, Gp, Mem) + ASMJIT_INST_2x(stllrb, Stllrb, Gp, Mem) + ASMJIT_INST_2x(stllrh, Stllrh, Gp, Mem) + + ASMJIT_INST_2x(stlr, Stllr, Gp, Mem) + ASMJIT_INST_2x(stlrb, Stllrb, Gp, Mem) + ASMJIT_INST_2x(stlrh, Stllrh, Gp, Mem) + + ASMJIT_INST_3x(stlxr, Stlxr, Gp, Gp, Mem) + ASMJIT_INST_3x(stlxrb, Stlxrb, Gp, Gp, Mem) + ASMJIT_INST_3x(stlxrh, Stlxrh, Gp, Gp, Mem) + + ASMJIT_INST_3x(stnp, Stnp, Gp, Gp, Mem) + ASMJIT_INST_3x(stp, Stp, Gp, Gp, Mem) + + ASMJIT_INST_2x(str, Str, Gp, Mem) + ASMJIT_INST_2x(strb, Strb, Gp, Mem) + ASMJIT_INST_2x(strh, Strh, Gp, Mem) + + ASMJIT_INST_2x(stset, Stset, Gp, Mem) + ASMJIT_INST_2x(stsetb, Stsetb, Gp, Mem) + ASMJIT_INST_2x(stseth, Stseth, Gp, Mem) + ASMJIT_INST_2x(stsetl, Stsetl, Gp, Mem) + ASMJIT_INST_2x(stsetlb, Stsetlb, Gp, Mem) + ASMJIT_INST_2x(stsetlh, Stsetlh, Gp, Mem) + + ASMJIT_INST_2x(stsmax, Stsmax, Gp, Mem) + ASMJIT_INST_2x(stsmaxb, Stsmaxb, Gp, Mem) + ASMJIT_INST_2x(stsmaxh, Stsmaxh, Gp, Mem) + ASMJIT_INST_2x(stsmaxl, Stsmaxl, Gp, Mem) + ASMJIT_INST_2x(stsmaxlb, Stsmaxlb, Gp, Mem) + ASMJIT_INST_2x(stsmaxlh, Stsmaxlh, Gp, Mem) + + ASMJIT_INST_2x(stsmin, Stsmin, Gp, Mem) + ASMJIT_INST_2x(stsminb, Stsminb, Gp, Mem) + ASMJIT_INST_2x(stsminh, Stsminh, Gp, Mem) + ASMJIT_INST_2x(stsminl, Stsminl, Gp, Mem) + ASMJIT_INST_2x(stsminlb, Stsminlb, Gp, Mem) + ASMJIT_INST_2x(stsminlh, Stsminlh, Gp, Mem) + + ASMJIT_INST_2x(sttr, Sttr, Gp, Mem) + ASMJIT_INST_2x(sttrb, Sttrb, Gp, Mem) + ASMJIT_INST_2x(sttrh, Sttrh, Gp, Mem) + + ASMJIT_INST_2x(stumax, Stumax, Gp, Mem) + ASMJIT_INST_2x(stumaxb, Stumaxb, Gp, Mem) + ASMJIT_INST_2x(stumaxh, Stumaxh, Gp, Mem) + ASMJIT_INST_2x(stumaxl, Stumaxl, Gp, Mem) + ASMJIT_INST_2x(stumaxlb, Stumaxlb, Gp, Mem) + 
ASMJIT_INST_2x(stumaxlh, Stumaxlh, Gp, Mem) + + ASMJIT_INST_2x(stumin, Stumin, Gp, Mem) + ASMJIT_INST_2x(stuminb, Stuminb, Gp, Mem) + ASMJIT_INST_2x(stuminh, Stuminh, Gp, Mem) + ASMJIT_INST_2x(stuminl, Stuminl, Gp, Mem) + ASMJIT_INST_2x(stuminlb, Stuminlb, Gp, Mem) + ASMJIT_INST_2x(stuminlh, Stuminlh, Gp, Mem) + + ASMJIT_INST_2x(stur, Stur, Gp, Mem) + ASMJIT_INST_2x(sturb, Sturb, Gp, Mem) + ASMJIT_INST_2x(sturh, Sturh, Gp, Mem) + + ASMJIT_INST_4x(stxp, Stxp, Gp, Gp, Gp, Mem) + ASMJIT_INST_4x(stlxp, Stlxp, Gp, Gp, Gp, Mem) + + ASMJIT_INST_3x(stxr, Stxr, Gp, Gp, Mem) + ASMJIT_INST_3x(stxrb, Stxrb, Gp, Gp, Mem) + ASMJIT_INST_3x(stxrh, Stxrh, Gp, Gp, Mem) + + ASMJIT_INST_3x(swp, Swp, Gp, Gp, Mem) + ASMJIT_INST_3x(swpa, Swpa, Gp, Gp, Mem) + ASMJIT_INST_3x(swpab, Swpab, Gp, Gp, Mem) + ASMJIT_INST_3x(swpah, Swpah, Gp, Gp, Mem) + ASMJIT_INST_3x(swpal, Swpal, Gp, Gp, Mem) + ASMJIT_INST_3x(swpalb, Swpalb, Gp, Gp, Mem) + ASMJIT_INST_3x(swpalh, Swpalh, Gp, Gp, Mem) + ASMJIT_INST_3x(swpb, Swpb, Gp, Gp, Mem) + ASMJIT_INST_3x(swph, Swph, Gp, Gp, Mem) + ASMJIT_INST_3x(swpl, Swpl, Gp, Gp, Mem) + ASMJIT_INST_3x(swplb, Swplb, Gp, Gp, Mem) + ASMJIT_INST_3x(swplh, Swplh, Gp, Gp, Mem) + //! \} + + //! \name CRC Instructions (ARMv8.1-A, optional in ARMv8.0-A) + //! \{ + + ASMJIT_INST_3x(crc32b, Crc32b, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32h, Crc32h, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32w, Crc32w, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32x, Crc32x, Gp, Gp, Gp); + + ASMJIT_INST_3x(crc32cb, Crc32cb, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32ch, Crc32ch, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32cw, Crc32cw, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32cx, Crc32cx, Gp, Gp, Gp); + + //! \} + + //! \name MTE Instructions + //! 
\{ + + ASMJIT_INST_2x(autda, Autda, Gp, Gp); + ASMJIT_INST_2x(autdb, Autdb, Gp, Gp); + ASMJIT_INST_1x(autdza, Autdza, Gp); + ASMJIT_INST_1x(autdzb, Autdzb, Gp); + ASMJIT_INST_2x(autia, Autia, Gp, Gp); + ASMJIT_INST_0x(autia1716, Autia1716); + ASMJIT_INST_0x(autiasp, Autiasp); + ASMJIT_INST_0x(autiaz, Autiaz); + ASMJIT_INST_2x(autib, Autib, Gp, Gp); + ASMJIT_INST_0x(autib1716, Autib1716); + ASMJIT_INST_0x(autibsp, Autibsp); + ASMJIT_INST_0x(autibz, Autibz); + ASMJIT_INST_1x(autiza, Autiza, Gp); + ASMJIT_INST_1x(autizb, Autizb, Gp); + + ASMJIT_INST_3x(gmi, Gmi, Gp, Gp, Gp); + + ASMJIT_INST_2x(cmpp, Cmpp, Gp, Gp); + ASMJIT_INST_4x(addg, Addg, Gp, Gp, Imm, Imm); + + ASMJIT_INST_2x(ldg, Ldg, Gp, Mem) + ASMJIT_INST_2x(ldgm, Ldgm, Gp, Mem) + ASMJIT_INST_2x(ldraa, Ldraa, Gp, Mem) + ASMJIT_INST_2x(ldrab, Ldrab, Gp, Mem) + + ASMJIT_INST_2x(pacda, Pacda, Gp, Gp); + ASMJIT_INST_2x(pacdb, Pacdb, Gp, Gp); + ASMJIT_INST_1x(pacdza, Pacdza, Gp); + ASMJIT_INST_1x(pacdzb, Pacdzb, Gp); + ASMJIT_INST_3x(pacga, Pacga, Gp, Gp, Gp); + + ASMJIT_INST_3x(subp, Subp, Gp, Gp, Gp); + ASMJIT_INST_3x(subps, Subps, Gp, Gp, Gp); + ASMJIT_INST_4x(subg, Subg, Gp, Gp, Imm, Imm); + + ASMJIT_INST_2x(st2g, St2g, Gp, Mem) + ASMJIT_INST_2x(stg, Stg, Gp, Mem) + ASMJIT_INST_3x(stgp, Stgp, Gp, Gp, Mem) + ASMJIT_INST_2x(stgm, Stgm, Gp, Mem) + ASMJIT_INST_2x(stzg, Stzg, Gp, Mem) + ASMJIT_INST_2x(stz2g, Stz2g, Gp, Mem) + ASMJIT_INST_2x(stzgm, Stzgm, Gp, Mem) + + ASMJIT_INST_1x(xpacd, Xpacd, Gp); + ASMJIT_INST_1x(xpaci, Xpaci, Gp); + ASMJIT_INST_0x(xpaclri, Xpaclri); + + //! \} + + //! \name Hint Instructions + //! \{ + + ASMJIT_INST_1x(hint, Hint, Imm) + ASMJIT_INST_0x(nop, Nop) + ASMJIT_INST_0x(sev, Sev) + ASMJIT_INST_0x(sevl, Sevl) + ASMJIT_INST_0x(wfe, Wfe) + ASMJIT_INST_0x(wfi, Wfi) + ASMJIT_INST_0x(yield, Yield) + + //! \} + + //! \name SIMD & FP Instructions + //! 
\{ + + ASMJIT_INST_2x(abs, Abs_v, Vec, Vec); + ASMJIT_INST_3x(add, Add_v, Vec, Vec, Vec); + ASMJIT_INST_3x(addhn, Addhn_v, Vec, Vec, Vec); + ASMJIT_INST_3x(addhn2, Addhn2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(addp, Addp_v, Vec, Vec); + ASMJIT_INST_3x(addp, Addp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(addv, Addv_v, Vec, Vec); + ASMJIT_INST_3x(and_, And_v, Vec, Vec, Vec); + ASMJIT_INST_2x(bic, Bic_v, Vec, Imm); + ASMJIT_INST_3x(bic, Bic_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bic, Bic_v, Vec, Imm, Imm); + ASMJIT_INST_3x(bif, Bif_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bit, Bit_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bsl, Bsl_v, Vec, Vec, Vec); + ASMJIT_INST_2x(cls, Cls_v, Vec, Vec); + ASMJIT_INST_2x(clz, Clz_v, Vec, Vec); + ASMJIT_INST_3x(cmeq, Cmeq_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmeq, Cmeq_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmge, Cmge_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmge, Cmge_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmgt, Cmgt_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmgt, Cmgt_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmhi, Cmhi_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmhs, Cmhs_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmle, Cmle_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmlt, Cmlt_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmtst, Cmtst_v, Vec, Vec, Vec); + ASMJIT_INST_2x(cnt, Cnt_v, Vec, Vec); + ASMJIT_INST_2x(dup, Dup_v, Vec, Gp); + ASMJIT_INST_2x(dup, Dup_v, Vec, Vec); + ASMJIT_INST_3x(eor, Eor_v, Vec, Vec, Vec); + ASMJIT_INST_4x(ext, Ext_v, Vec, Vec, Vec, Imm); + ASMJIT_INST_3x(fabd, Fabd_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fabs, Fabs_v, Vec, Vec); + ASMJIT_INST_3x(facge, Facge_v, Vec, Vec, Vec); + ASMJIT_INST_3x(facgt, Facgt_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fadd, Fadd_v, Vec, Vec, Vec); + ASMJIT_INST_2x(faddp, Faddp_v, Vec, Vec); + ASMJIT_INST_3x(faddp, Faddp_v, Vec, Vec, Vec); + ASMJIT_INST_4x(fccmp, Fccmp_v, Vec, Vec, Imm, Imm); + ASMJIT_INST_4x(fccmpe, Fccmpe_v, Vec, Vec, Imm, Imm); + ASMJIT_INST_3x(fcmeq, Fcmeq_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fcmeq, Fcmeq_v, Vec, Vec, Imm); + 
ASMJIT_INST_3x(fcmge, Fcmge_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fcmge, Fcmge_v, Vec, Vec, Imm); + ASMJIT_INST_3x(fcmgt, Fcmgt_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fcmgt, Fcmgt_v, Vec, Vec, Imm); + ASMJIT_INST_3x(fcmle, Fcmle_v, Vec, Vec, Imm); + ASMJIT_INST_3x(fcmlt, Fcmlt_v, Vec, Vec, Imm); + ASMJIT_INST_2x(fcmp, Fcmp_v, Vec, Vec); + ASMJIT_INST_2x(fcmp, Fcmp_v, Vec, Imm); + ASMJIT_INST_2x(fcmpe, Fcmpe_v, Vec, Vec); + ASMJIT_INST_2x(fcmpe, Fcmpe_v, Vec, Imm); + ASMJIT_INST_4x(fcsel, Fcsel_v, Vec, Vec, Vec, Imm); + ASMJIT_INST_2x(fcvt, Fcvt_v, Vec, Vec); + ASMJIT_INST_2x(fcvtas, Fcvtas_v, Gp, Vec); + ASMJIT_INST_2x(fcvtas, Fcvtas_v, Vec, Vec); + ASMJIT_INST_2x(fcvtau, Fcvtau_v, Gp, Vec); + ASMJIT_INST_2x(fcvtau, Fcvtau_v, Vec, Vec); + ASMJIT_INST_2x(fcvtl, Fcvtl_v, Vec, Vec); + ASMJIT_INST_2x(fcvtl2, Fcvtl2_v, Vec, Vec); + ASMJIT_INST_2x(fcvtms, Fcvtms_v, Gp, Vec); + ASMJIT_INST_2x(fcvtms, Fcvtms_v, Vec, Vec); + ASMJIT_INST_2x(fcvtmu, Fcvtmu_v, Gp, Vec); + ASMJIT_INST_2x(fcvtmu, Fcvtmu_v, Vec, Vec); + ASMJIT_INST_2x(fcvtn, Fcvtn_v, Vec, Vec); + ASMJIT_INST_2x(fcvtn2, Fcvtn2_v, Vec, Vec); + ASMJIT_INST_2x(fcvtns, Fcvtns_v, Gp, Vec); + ASMJIT_INST_2x(fcvtns, Fcvtns_v, Vec, Vec); + ASMJIT_INST_2x(fcvtnu, Fcvtnu_v, Gp, Vec); + ASMJIT_INST_2x(fcvtnu, Fcvtnu_v, Vec, Vec); + ASMJIT_INST_2x(fcvtps, Fcvtps_v, Gp, Vec); + ASMJIT_INST_2x(fcvtps, Fcvtps_v, Vec, Vec); + ASMJIT_INST_2x(fcvtpu, Fcvtpu_v, Gp, Vec); + ASMJIT_INST_2x(fcvtpu, Fcvtpu_v, Vec, Vec); + ASMJIT_INST_2x(fcvtxn, Fcvtxn_v, Vec, Vec); + ASMJIT_INST_2x(fcvtxn2, Fcvtxn2_v, Vec, Vec); + ASMJIT_INST_2x(fcvtzs, Fcvtzs_v, Gp, Vec); + ASMJIT_INST_3x(fcvtzs, Fcvtzs_v, Gp, Vec, Imm); + ASMJIT_INST_2x(fcvtzs, Fcvtzs_v, Vec, Vec); + ASMJIT_INST_3x(fcvtzs, Fcvtzs_v, Vec, Vec, Imm); + ASMJIT_INST_2x(fcvtzu, Fcvtzu_v, Gp, Vec); + ASMJIT_INST_3x(fcvtzu, Fcvtzu_v, Gp, Vec, Imm); + ASMJIT_INST_2x(fcvtzu, Fcvtzu_v, Vec, Vec); + ASMJIT_INST_3x(fcvtzu, Fcvtzu_v, Vec, Vec, Imm); + ASMJIT_INST_3x(fdiv, Fdiv_v, Vec, Vec, Vec); + 
ASMJIT_INST_4x(fmadd, Fmadd_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(fmax, Fmax_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmaxnm, Fmaxnm_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmaxnmp, Fmaxnmp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fmaxnmp, Fmaxnmp_v, Vec, Vec); + ASMJIT_INST_2x(fmaxnmv, Fmaxnmv_v, Vec, Vec); + ASMJIT_INST_3x(fmaxp, Fmaxp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fmaxp, Fmaxp_v, Vec, Vec); + ASMJIT_INST_2x(fmaxv, Fmaxv_v, Vec, Vec); + ASMJIT_INST_3x(fmin, Fmin_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fminnm, Fminnm_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fminnmv, Fminnmv_v, Vec, Vec); + ASMJIT_INST_3x(fminnmp, Fminnmp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fminnmp, Fminnmp_v, Vec, Vec); + ASMJIT_INST_2x(fminp, Fminp_v, Vec, Vec); + ASMJIT_INST_3x(fminp, Fminp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fminv, Fminv_v, Vec, Vec); + ASMJIT_INST_3x(fmla, Fmla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmls, Fmls_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fmov, Fmov_v, Gp, Vec); + ASMJIT_INST_2x(fmov, Fmov_v, Vec, Gp); + ASMJIT_INST_2x(fmov, Fmov_v, Vec, Vec); + ASMJIT_INST_2x(fmov, Fmov_v, Vec, Imm); + ASMJIT_INST_4x(fmsub, Fmsub_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(fmul, Fmul_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmulx, Fmulx_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fneg, Fneg_v, Vec, Vec); + ASMJIT_INST_4x(fnmadd, Fnmadd_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_4x(fnmsub, Fnmsub_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(fnmul, Fnmul_v, Vec, Vec, Vec); + ASMJIT_INST_2x(frecpe, Frecpe_v, Vec, Vec); + ASMJIT_INST_3x(frecps, Frecps_v, Vec, Vec, Vec); + ASMJIT_INST_2x(frecpx, Frecpx_v, Vec, Vec); + ASMJIT_INST_2x(frint32x, Frint32x_v, Vec, Vec); + ASMJIT_INST_2x(frint32z, Frint32z_v, Vec, Vec); + ASMJIT_INST_2x(frint64x, Frint64x_v, Vec, Vec); + ASMJIT_INST_2x(frint64z, Frint64z_v, Vec, Vec); + ASMJIT_INST_2x(frinta, Frinta_v, Vec, Vec); + ASMJIT_INST_2x(frinti, Frinti_v, Vec, Vec); + ASMJIT_INST_2x(frintm, Frintm_v, Vec, Vec); + ASMJIT_INST_2x(frintn, Frintn_v, Vec, Vec); + ASMJIT_INST_2x(frintp, Frintp_v, Vec, 
Vec); + ASMJIT_INST_2x(frintx, Frintx_v, Vec, Vec); + ASMJIT_INST_2x(frintz, Frintz_v, Vec, Vec); + ASMJIT_INST_2x(frsqrte, Frsqrte_v, Vec, Vec); + ASMJIT_INST_3x(frsqrts, Frsqrts_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fsqrt, Fsqrt_v, Vec, Vec); + ASMJIT_INST_3x(fsub, Fsub_v, Vec, Vec, Vec); + ASMJIT_INST_2x(ins, Ins_v, Vec, Gp); + ASMJIT_INST_2x(ins, Ins_v, Vec, Vec); + ASMJIT_INST_2x(ld1, Ld1_v, Vec, Mem); + ASMJIT_INST_3x(ld1, Ld1_v, Vec, Vec, Mem); + ASMJIT_INST_4x(ld1, Ld1_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(ld1, Ld1_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_2x(ld1r, Ld1r_v, Vec, Mem); + ASMJIT_INST_3x(ld2, Ld2_v, Vec, Vec, Mem); + ASMJIT_INST_3x(ld2r, Ld2r_v, Vec, Vec, Mem); + ASMJIT_INST_4x(ld3, Ld3_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_4x(ld3r, Ld3r_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(ld4, Ld4_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(ld4r, Ld4r_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_3x(ldnp, Ldnp_v, Vec, Vec, Mem); + ASMJIT_INST_3x(ldp, Ldp_v, Vec, Vec, Mem); + ASMJIT_INST_2x(ldr, Ldr_v, Vec, Mem); + ASMJIT_INST_2x(ldur, Ldur_v, Vec, Mem); + ASMJIT_INST_3x(mla, Mla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(mls, Mls_v, Vec, Vec, Vec); + ASMJIT_INST_2x(mov, Mov_v, Vec, Vec); + ASMJIT_INST_2x(mov, Mov_v, Gp, Vec); + ASMJIT_INST_2x(mov, Mov_v, Vec, Gp); + ASMJIT_INST_2x(movi, Movi_v, Vec, Imm); + ASMJIT_INST_3x(movi, Movi_v, Vec, Imm, Imm); + ASMJIT_INST_3x(mul, Mul_v, Vec, Vec, Vec); + ASMJIT_INST_2x(mvn, Mvn_v, Vec, Vec); + ASMJIT_INST_2x(mvni, Mvni_v, Vec, Imm); + ASMJIT_INST_3x(mvni, Mvni_v, Vec, Imm, Imm); + ASMJIT_INST_2x(neg, Neg_v, Vec, Vec); + ASMJIT_INST_2x(not_, Not_v, Vec, Vec); + ASMJIT_INST_3x(orn, Orn_v, Vec, Vec, Vec); + ASMJIT_INST_2x(orr, Orr_v, Vec, Imm); + ASMJIT_INST_3x(orr, Orr_v, Vec, Vec, Vec); + ASMJIT_INST_3x(orr, Orr_v, Vec, Imm, Imm); + ASMJIT_INST_3x(pmul, Pmul_v, Vec, Vec, Vec); + ASMJIT_INST_3x(pmull, Pmull_v, Vec, Vec, Vec); + ASMJIT_INST_3x(pmull2, Pmull2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(raddhn, Raddhn_v, Vec, Vec, 
Vec); + ASMJIT_INST_3x(raddhn2, Raddhn2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(rbit, Rbit_v, Vec, Vec); + ASMJIT_INST_2x(rev16, Rev16_v, Vec, Vec); + ASMJIT_INST_2x(rev32, Rev32_v, Vec, Vec); + ASMJIT_INST_2x(rev64, Rev64_v, Vec, Vec); + ASMJIT_INST_3x(rshrn, Rshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(rshrn2, Rshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(rsubhn, Rsubhn_v, Vec, Vec, Vec); + ASMJIT_INST_3x(rsubhn2, Rsubhn2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(saba, Saba_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabal, Sabal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabal2, Sabal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabd, Sabd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabdl, Sabdl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabdl2, Sabdl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sadalp, Sadalp_v, Vec, Vec); + ASMJIT_INST_3x(saddl, Saddl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(saddl2, Saddl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(saddlp, Saddlp_v, Vec, Vec); + ASMJIT_INST_2x(saddlv, Saddlv_v, Vec, Vec); + ASMJIT_INST_3x(saddw, Saddw_v, Vec, Vec, Vec); + ASMJIT_INST_3x(saddw2, Saddw2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(scvtf, Scvtf_v, Vec, Gp); + ASMJIT_INST_3x(scvtf, Scvtf_v, Vec, Gp, Imm); + ASMJIT_INST_2x(scvtf, Scvtf_v, Vec, Vec); + ASMJIT_INST_3x(scvtf, Scvtf_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shadd, Shadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(shl, Shl_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shll, Shll_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shll2, Shll2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shrn, Shrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shrn2, Shrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shsub, Shsub_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sli, Sli_v, Vec, Vec, Imm); + ASMJIT_INST_3x(smax, Smax_v, Vec, Vec, Vec); + ASMJIT_INST_3x(smaxp, Smaxp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(smaxv, Smaxv_v, Vec, Vec); + ASMJIT_INST_3x(smin, Smin_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sminp, Sminp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sminv, Sminv_v, Vec, Vec); + ASMJIT_INST_3x(smlal, Smlal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(smlal2, Smlal2_v, 
Vec, Vec, Vec); + ASMJIT_INST_3x(smlsl, Smlsl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(smlsl2, Smlsl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(smov, Smov_v, Gp, Vec); + ASMJIT_INST_3x(smull, Smull_v, Vec, Vec, Vec); + ASMJIT_INST_3x(smull2, Smull2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sqabs, Sqabs_v, Vec, Vec); + ASMJIT_INST_3x(sqadd, Sqadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmlal, Sqdmlal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmlal2, Sqdmlal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmlsl, Sqdmlsl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmlsl2, Sqdmlsl2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmulh, Sqdmulh_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmull, Sqdmull_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmull2, Sqdmull2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sqneg, Sqneg_v, Vec, Vec); + ASMJIT_INST_3x(sqrdmulh, Sqrdmulh_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqrshl, Sqrshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqrshrn, Sqrshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqrshrn2, Sqrshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqrshrun, Sqrshrun_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqrshrun2, Sqrshrun2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshl, Sqshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqshl, Sqshl_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshlu, Sqshlu_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshrn, Sqshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshrn2, Sqshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshrun, Sqshrun_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshrun2, Sqshrun2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqsub, Sqsub_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sqxtn, Sqxtn_v, Vec, Vec); + ASMJIT_INST_2x(sqxtn2, Sqxtn2_v, Vec, Vec); + ASMJIT_INST_2x(sqxtun, Sqxtun_v, Vec, Vec); + ASMJIT_INST_2x(sqxtun2, Sqxtun2_v, Vec, Vec); + ASMJIT_INST_3x(srhadd, Srhadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sri, Sri_v, Vec, Vec, Imm); + ASMJIT_INST_3x(srshl, Srshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(srshr, Srshr_v, Vec, Vec, Imm); + ASMJIT_INST_3x(srsra, Srsra_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sshl, Sshl_v, Vec, Vec, Vec); + 
ASMJIT_INST_3x(sshll, Sshll_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sshll2, Sshll2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sshr, Sshr_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ssra, Ssra_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ssubl, Ssubl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ssubl2, Ssubl2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ssubw, Ssubw_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ssubw2, Ssubw2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(st1, St1_v, Vec, Mem); + ASMJIT_INST_3x(st1, St1_v, Vec, Vec, Mem); + ASMJIT_INST_4x(st1, St1_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(st1, St1_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_3x(st2, St2_v, Vec, Vec, Mem); + ASMJIT_INST_4x(st3, St3_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(st4, St4_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_3x(stnp, Stnp_v, Vec, Vec, Mem); + ASMJIT_INST_3x(stp, Stp_v, Vec, Vec, Mem); + ASMJIT_INST_2x(str, Str_v, Vec, Mem); + ASMJIT_INST_2x(stur, Stur_v, Vec, Mem); + ASMJIT_INST_3x(sub, Sub_v, Vec, Vec, Vec); + ASMJIT_INST_3x(subhn, Subhn_v, Vec, Vec, Vec); + ASMJIT_INST_3x(subhn2, Subhn2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(suqadd, Suqadd_v, Vec, Vec); + ASMJIT_INST_2x(sxtl, Sxtl_v, Vec, Vec); + ASMJIT_INST_2x(sxtl2, Sxtl2_v, Vec, Vec); + ASMJIT_INST_3x(tbl, Tbl_v, Vec, Vec, Vec); + ASMJIT_INST_4x(tbl, Tbl_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_5x(tbl, Tbl_v, Vec, Vec, Vec, Vec, Vec); + ASMJIT_INST_6x(tbl, Tbl_v, Vec, Vec, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(tbx, Tbx_v, Vec, Vec, Vec); + ASMJIT_INST_4x(tbx, Tbx_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_5x(tbx, Tbx_v, Vec, Vec, Vec, Vec, Vec); + ASMJIT_INST_6x(tbx, Tbx_v, Vec, Vec, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(trn1, Trn1_v, Vec, Vec, Vec); + ASMJIT_INST_3x(trn2, Trn2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uaba, Uaba_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabal, Uabal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabal2, Uabal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabd, Uabd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabdl, Uabdl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabdl2, Uabdl2_v, Vec, Vec, Vec); + 
ASMJIT_INST_2x(uadalp, Uadalp_v, Vec, Vec); + ASMJIT_INST_3x(uaddl, Uaddl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uaddl2, Uaddl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(uaddlp, Uaddlp_v, Vec, Vec); + ASMJIT_INST_2x(uaddlv, Uaddlv_v, Vec, Vec); + ASMJIT_INST_3x(uaddw, Uaddw_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uaddw2, Uaddw2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(ucvtf, Ucvtf_v, Vec, Gp); + ASMJIT_INST_3x(ucvtf, Ucvtf_v, Vec, Gp, Imm); + ASMJIT_INST_2x(ucvtf, Ucvtf_v, Vec, Vec); + ASMJIT_INST_3x(ucvtf, Ucvtf_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uhadd, Uhadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uhsub, Uhsub_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umax, Umax_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umaxp, Umaxp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(umaxv, Umaxv_v, Vec, Vec); + ASMJIT_INST_3x(umin, Umin_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uminp, Uminp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(uminv, Uminv_v, Vec, Vec); + ASMJIT_INST_3x(umlal, Umlal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umlal2, Umlal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umlsl, Umlsl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umlsl2, Umlsl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(umov, Umov_v, Gp, Vec); + ASMJIT_INST_3x(umull, Umull_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umull2, Umull2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uqadd, Uqadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uqrshl, Uqrshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uqrshl, Uqrshl_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqrshrn, Uqrshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqrshrn2, Uqrshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqshl, Uqshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uqshl, Uqshl_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqshrn, Uqshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqshrn2, Uqshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqsub, Uqsub_v, Vec, Vec, Vec); + ASMJIT_INST_2x(uqxtn, Uqxtn_v, Vec, Vec); + ASMJIT_INST_2x(uqxtn2, Uqxtn2_v, Vec, Vec); + ASMJIT_INST_2x(urecpe, Urecpe_v, Vec, Vec); + ASMJIT_INST_3x(urhadd, Urhadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(urshl, Urshl_v, Vec, Vec, Vec); + 
ASMJIT_INST_3x(urshr, Urshr_v, Vec, Vec, Imm); + ASMJIT_INST_2x(ursqrte, Ursqrte_v, Vec, Vec); + ASMJIT_INST_3x(ursra, Ursra_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ushl, Ushl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ushll, Ushll_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ushll2, Ushll2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ushr, Ushr_v, Vec, Vec, Imm); + ASMJIT_INST_2x(usqadd, Usqadd_v, Vec, Vec); + ASMJIT_INST_3x(usra, Usra_v, Vec, Vec, Imm); + ASMJIT_INST_3x(usubl, Usubl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usubl2, Usubl2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usubw, Usubw_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usubw2, Usubw2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(uxtl, Uxtl_v, Vec, Vec); + ASMJIT_INST_2x(uxtl2, Uxtl2_v, Vec, Vec); + ASMJIT_INST_3x(uzp1, Uzp1_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uzp2, Uzp2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(xtn, Xtn_v, Vec, Vec); + ASMJIT_INST_2x(xtn2, Xtn2_v, Vec, Vec); + ASMJIT_INST_3x(zip1, Zip1_v, Vec, Vec, Vec); + ASMJIT_INST_3x(zip2, Zip2_v, Vec, Vec, Vec); + + //! \} + + //! \name AES Instructions + //! \{ + + ASMJIT_INST_2x(aesd, Aesd_v, Vec, Vec); + ASMJIT_INST_2x(aese, Aese_v, Vec, Vec); + ASMJIT_INST_2x(aesimc, Aesimc_v, Vec, Vec); + ASMJIT_INST_2x(aesmc, Aesmc_v, Vec, Vec); + + //! \} + + //! \name SHA1 Instructions + //! \{ + + ASMJIT_INST_3x(sha1c, Sha1c_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sha1h, Sha1h_v, Vec, Vec); + ASMJIT_INST_3x(sha1m, Sha1m_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sha1p, Sha1p_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sha1su0, Sha1su0_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sha1su1, Sha1su1_v, Vec, Vec); + + //! \} + + //! \name SHA2 Instructions + //! \{ + + ASMJIT_INST_3x(sha256h, Sha256h_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sha256h2, Sha256h2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sha256su0, Sha256su0_v, Vec, Vec); + ASMJIT_INST_3x(sha256su1, Sha256su1_v, Vec, Vec, Vec); + + //! \} + + //! \name RDMA Instructions (ARMv8.1-A) + //! 
\{ + + ASMJIT_INST_3x(sqrdmlah, Sqrdmlah_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqrdmlsh, Sqrdmlsh_v, Vec, Vec, Vec); + + //! \} + + //! \name FCMA Instruction (ARMv8.3-A) + //! \{ + + ASMJIT_INST_4x(fcadd, Fcadd_v, Vec, Vec, Vec, Imm); + ASMJIT_INST_4x(fcmla, Fcmla_v, Vec, Vec, Vec, Imm); + + //! \} + + //! \name JSCVT Instruction (ARMv8.3-A) + //! \{ + + ASMJIT_INST_2x(fjcvtzs, Fjcvtzs_v, Gp, Vec); + + //! \} + + //! \name FHM Instructions + //! \{ + + ASMJIT_INST_3x(fmlal, Fmlal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmlal2, Fmlal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmlsl, Fmlsl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmlsl2, Fmlsl2_v, Vec, Vec, Vec); + + + //! \} + + //! \name SHA3 Instructions (ARMv8.4-A, optional in ARMv8.2-A) + //! \{ + + ASMJIT_INST_4x(bcax, Bcax_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_4x(eor3, Eor3_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(rax1, Rax1_v, Vec, Vec, Vec); + ASMJIT_INST_4x(xar, Xar_v, Vec, Vec, Vec, Imm); + + //! \} + + //! \name SHA512 Instructions (ARMv8.4-A) + //! \{ + + ASMJIT_INST_3x(sha512h, Sha512h_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sha512h2, Sha512h2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sha512su0, Sha512su0_v, Vec, Vec); + ASMJIT_INST_3x(sha512su1, Sha512su1_v, Vec, Vec, Vec); + + //! \} + + //! \name SM3 Instructions (ARMv8.4-A) + //! \{ + + ASMJIT_INST_3x(sm3partw1, Sm3partw1_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3partw2, Sm3partw2_v, Vec, Vec, Vec); + ASMJIT_INST_4x(sm3ss1, Sm3ss1_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3tt1a, Sm3tt1a_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3tt1b, Sm3tt1b_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3tt2a, Sm3tt2a_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3tt2b, Sm3tt2b_v, Vec, Vec, Vec); + + //! \} + + //! \name SM4 Instructions (ARMv8.4-A) + //! \{ + + ASMJIT_INST_2x(sm4e, Sm4e_v, Vec, Vec); + ASMJIT_INST_3x(sm4ekey, Sm4ekey_v, Vec, Vec, Vec); + + //! \} + + //! \name DOTPROD Instructions (ARMv8.4-A, optional in ARMv8.2-A) + //! 
\{ + + ASMJIT_INST_3x(sdot, Sdot_v, Vec, Vec, Vec); + ASMJIT_INST_3x(udot, Udot_v, Vec, Vec, Vec); + + //! \} + + //! \name BF16 Instructions (ARMv8.6-A) + //! \{ + + ASMJIT_INST_2x(bfcvt, Bfcvt_v, Vec, Vec); + ASMJIT_INST_2x(bfcvtn, Bfcvtn_v, Vec, Vec); + ASMJIT_INST_2x(bfcvtn2, Bfcvtn2_v, Vec, Vec); + ASMJIT_INST_3x(bfmlalb, Bfmlalb_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bfmlalt, Bfmlalt_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bfmmla, Bfmmla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bfdot, Bfdot_v, Vec, Vec, Vec); + + //! \} + + //! \name I8MM Instructions (ARMv8.6-A) + //! \{ + + ASMJIT_INST_3x(smmla, Smmla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sudot, Sudot_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ummla, Ummla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usdot, Usdot_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usmmla, Usmmla_v, Vec, Vec, Vec); + + //! \} +}; + +//! Emitter (ARM). +//! +//! \note This class cannot be instantiated, you can only cast to it and use it as emitter that emits to either +//! `a64::Assembler`, `a64::Builder`, or `a64::Compiler` (use with caution with `a64::Compiler` as it requires +//! virtual registers). +class Emitter : public BaseEmitter, public EmitterExplicitT { + ASMJIT_NONCONSTRUCTIBLE(Emitter) +}; + +//! \} + +#undef ASMJIT_INST_0x +#undef ASMJIT_INST_1x +#undef ASMJIT_INST_2x +#undef ASMJIT_INST_3x +#undef ASMJIT_INST_4x +#undef ASMJIT_INST_5x +#undef ASMJIT_INST_6x +#undef ASMJIT_INST_1cc + +ASMJIT_END_SUB_NAMESPACE + +// Restore undefined MSVC AArch64 macros. 
+#if defined(ASMJIT_RESTORE_MSVC_AARCH64_MACROS) + #pragma pop_macro("mvn") +#endif + +#endif // ASMJIT_ARM_A64EMITTER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64globals.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64globals.h new file mode 100644 index 0000000000000000000000000000000000000000..8093885b32c1a049c07272f59a1a7e96a87e7ab8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64globals.h @@ -0,0 +1,1895 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64GLOBALS_H_INCLUDED +#define ASMJIT_ARM_A64GLOBALS_H_INCLUDED + +#include "../arm/armglobals.h" + +//! \namespace asmjit::a64 +//! \ingroup asmjit_a64 +//! +//! AArch64 backend. + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! AArch64 instruction. +//! +//! \note Only used to hold ARM-specific enumerations and static functions. +struct Inst { + //! Instruction id. + enum Id : uint32_t { + // ${InstId:Begin} + kIdNone = 0, //!< Instruction ''. + kIdAdc, //!< Instruction 'adc'. + kIdAdcs, //!< Instruction 'adcs'. + kIdAdd, //!< Instruction 'add'. + kIdAddg, //!< Instruction 'addg'. + kIdAdds, //!< Instruction 'adds'. + kIdAdr, //!< Instruction 'adr'. + kIdAdrp, //!< Instruction 'adrp'. + kIdAnd, //!< Instruction 'and'. + kIdAnds, //!< Instruction 'ands'. + kIdAsr, //!< Instruction 'asr'. + kIdAsrv, //!< Instruction 'asrv'. + kIdAt, //!< Instruction 'at'. + kIdAutda, //!< Instruction 'autda'. + kIdAutdza, //!< Instruction 'autdza'. + kIdAutdb, //!< Instruction 'autdb'. + kIdAutdzb, //!< Instruction 'autdzb'. + kIdAutia, //!< Instruction 'autia'. + kIdAutia1716, //!< Instruction 'autia1716'. + kIdAutiasp, //!< Instruction 'autiasp'. + kIdAutiaz, //!< Instruction 'autiaz'. + kIdAutib, //!< Instruction 'autib'. + kIdAutib1716, //!< Instruction 'autib1716'. 
+ kIdAutibsp, //!< Instruction 'autibsp'. + kIdAutibz, //!< Instruction 'autibz'. + kIdAutiza, //!< Instruction 'autiza'. + kIdAutizb, //!< Instruction 'autizb'. + kIdAxflag, //!< Instruction 'axflag'. + kIdB, //!< Instruction 'b'. + kIdBfc, //!< Instruction 'bfc'. + kIdBfi, //!< Instruction 'bfi'. + kIdBfm, //!< Instruction 'bfm'. + kIdBfxil, //!< Instruction 'bfxil'. + kIdBic, //!< Instruction 'bic'. + kIdBics, //!< Instruction 'bics'. + kIdBl, //!< Instruction 'bl'. + kIdBlr, //!< Instruction 'blr'. + kIdBr, //!< Instruction 'br'. + kIdBrk, //!< Instruction 'brk'. + kIdCas, //!< Instruction 'cas'. + kIdCasa, //!< Instruction 'casa'. + kIdCasab, //!< Instruction 'casab'. + kIdCasah, //!< Instruction 'casah'. + kIdCasal, //!< Instruction 'casal'. + kIdCasalb, //!< Instruction 'casalb'. + kIdCasalh, //!< Instruction 'casalh'. + kIdCasb, //!< Instruction 'casb'. + kIdCash, //!< Instruction 'cash'. + kIdCasl, //!< Instruction 'casl'. + kIdCaslb, //!< Instruction 'caslb'. + kIdCaslh, //!< Instruction 'caslh'. + kIdCasp, //!< Instruction 'casp'. + kIdCaspa, //!< Instruction 'caspa'. + kIdCaspal, //!< Instruction 'caspal'. + kIdCaspl, //!< Instruction 'caspl'. + kIdCbnz, //!< Instruction 'cbnz'. + kIdCbz, //!< Instruction 'cbz'. + kIdCcmn, //!< Instruction 'ccmn'. + kIdCcmp, //!< Instruction 'ccmp'. + kIdCfinv, //!< Instruction 'cfinv'. + kIdCinc, //!< Instruction 'cinc'. + kIdCinv, //!< Instruction 'cinv'. + kIdClrex, //!< Instruction 'clrex'. + kIdCls, //!< Instruction 'cls'. + kIdClz, //!< Instruction 'clz'. + kIdCmn, //!< Instruction 'cmn'. + kIdCmp, //!< Instruction 'cmp'. + kIdCmpp, //!< Instruction 'cmpp'. + kIdCneg, //!< Instruction 'cneg'. + kIdCrc32b, //!< Instruction 'crc32b'. + kIdCrc32cb, //!< Instruction 'crc32cb'. + kIdCrc32ch, //!< Instruction 'crc32ch'. + kIdCrc32cw, //!< Instruction 'crc32cw'. + kIdCrc32cx, //!< Instruction 'crc32cx'. + kIdCrc32h, //!< Instruction 'crc32h'. + kIdCrc32w, //!< Instruction 'crc32w'. + kIdCrc32x, //!< Instruction 'crc32x'. 
+ kIdCsdb, //!< Instruction 'csdb'. + kIdCsel, //!< Instruction 'csel'. + kIdCset, //!< Instruction 'cset'. + kIdCsetm, //!< Instruction 'csetm'. + kIdCsinc, //!< Instruction 'csinc'. + kIdCsinv, //!< Instruction 'csinv'. + kIdCsneg, //!< Instruction 'csneg'. + kIdDc, //!< Instruction 'dc'. + kIdDcps1, //!< Instruction 'dcps1'. + kIdDcps2, //!< Instruction 'dcps2'. + kIdDcps3, //!< Instruction 'dcps3'. + kIdDgh, //!< Instruction 'dgh'. + kIdDmb, //!< Instruction 'dmb'. + kIdDrps, //!< Instruction 'drps'. + kIdDsb, //!< Instruction 'dsb'. + kIdEon, //!< Instruction 'eon'. + kIdEor, //!< Instruction 'eor'. + kIdEsb, //!< Instruction 'esb'. + kIdExtr, //!< Instruction 'extr'. + kIdEret, //!< Instruction 'eret'. + kIdGmi, //!< Instruction 'gmi'. + kIdHint, //!< Instruction 'hint'. + kIdHlt, //!< Instruction 'hlt'. + kIdHvc, //!< Instruction 'hvc'. + kIdIc, //!< Instruction 'ic'. + kIdIsb, //!< Instruction 'isb'. + kIdLdadd, //!< Instruction 'ldadd'. + kIdLdadda, //!< Instruction 'ldadda'. + kIdLdaddab, //!< Instruction 'ldaddab'. + kIdLdaddah, //!< Instruction 'ldaddah'. + kIdLdaddal, //!< Instruction 'ldaddal'. + kIdLdaddalb, //!< Instruction 'ldaddalb'. + kIdLdaddalh, //!< Instruction 'ldaddalh'. + kIdLdaddb, //!< Instruction 'ldaddb'. + kIdLdaddh, //!< Instruction 'ldaddh'. + kIdLdaddl, //!< Instruction 'ldaddl'. + kIdLdaddlb, //!< Instruction 'ldaddlb'. + kIdLdaddlh, //!< Instruction 'ldaddlh'. + kIdLdar, //!< Instruction 'ldar'. + kIdLdarb, //!< Instruction 'ldarb'. + kIdLdarh, //!< Instruction 'ldarh'. + kIdLdaxp, //!< Instruction 'ldaxp'. + kIdLdaxr, //!< Instruction 'ldaxr'. + kIdLdaxrb, //!< Instruction 'ldaxrb'. + kIdLdaxrh, //!< Instruction 'ldaxrh'. + kIdLdclr, //!< Instruction 'ldclr'. + kIdLdclra, //!< Instruction 'ldclra'. + kIdLdclrab, //!< Instruction 'ldclrab'. + kIdLdclrah, //!< Instruction 'ldclrah'. + kIdLdclral, //!< Instruction 'ldclral'. + kIdLdclralb, //!< Instruction 'ldclralb'. + kIdLdclralh, //!< Instruction 'ldclralh'. 
+ kIdLdclrb, //!< Instruction 'ldclrb'. + kIdLdclrh, //!< Instruction 'ldclrh'. + kIdLdclrl, //!< Instruction 'ldclrl'. + kIdLdclrlb, //!< Instruction 'ldclrlb'. + kIdLdclrlh, //!< Instruction 'ldclrlh'. + kIdLdeor, //!< Instruction 'ldeor'. + kIdLdeora, //!< Instruction 'ldeora'. + kIdLdeorab, //!< Instruction 'ldeorab'. + kIdLdeorah, //!< Instruction 'ldeorah'. + kIdLdeoral, //!< Instruction 'ldeoral'. + kIdLdeoralb, //!< Instruction 'ldeoralb'. + kIdLdeoralh, //!< Instruction 'ldeoralh'. + kIdLdeorb, //!< Instruction 'ldeorb'. + kIdLdeorh, //!< Instruction 'ldeorh'. + kIdLdeorl, //!< Instruction 'ldeorl'. + kIdLdeorlb, //!< Instruction 'ldeorlb'. + kIdLdeorlh, //!< Instruction 'ldeorlh'. + kIdLdg, //!< Instruction 'ldg'. + kIdLdgm, //!< Instruction 'ldgm'. + kIdLdlar, //!< Instruction 'ldlar'. + kIdLdlarb, //!< Instruction 'ldlarb'. + kIdLdlarh, //!< Instruction 'ldlarh'. + kIdLdnp, //!< Instruction 'ldnp'. + kIdLdp, //!< Instruction 'ldp'. + kIdLdpsw, //!< Instruction 'ldpsw'. + kIdLdr, //!< Instruction 'ldr'. + kIdLdraa, //!< Instruction 'ldraa'. + kIdLdrab, //!< Instruction 'ldrab'. + kIdLdrb, //!< Instruction 'ldrb'. + kIdLdrh, //!< Instruction 'ldrh'. + kIdLdrsb, //!< Instruction 'ldrsb'. + kIdLdrsh, //!< Instruction 'ldrsh'. + kIdLdrsw, //!< Instruction 'ldrsw'. + kIdLdset, //!< Instruction 'ldset'. + kIdLdseta, //!< Instruction 'ldseta'. + kIdLdsetab, //!< Instruction 'ldsetab'. + kIdLdsetah, //!< Instruction 'ldsetah'. + kIdLdsetal, //!< Instruction 'ldsetal'. + kIdLdsetalb, //!< Instruction 'ldsetalb'. + kIdLdsetalh, //!< Instruction 'ldsetalh'. + kIdLdsetb, //!< Instruction 'ldsetb'. + kIdLdseth, //!< Instruction 'ldseth'. + kIdLdsetl, //!< Instruction 'ldsetl'. + kIdLdsetlb, //!< Instruction 'ldsetlb'. + kIdLdsetlh, //!< Instruction 'ldsetlh'. + kIdLdsmax, //!< Instruction 'ldsmax'. + kIdLdsmaxa, //!< Instruction 'ldsmaxa'. + kIdLdsmaxab, //!< Instruction 'ldsmaxab'. + kIdLdsmaxah, //!< Instruction 'ldsmaxah'. 
+ kIdLdsmaxal, //!< Instruction 'ldsmaxal'. + kIdLdsmaxalb, //!< Instruction 'ldsmaxalb'. + kIdLdsmaxalh, //!< Instruction 'ldsmaxalh'. + kIdLdsmaxb, //!< Instruction 'ldsmaxb'. + kIdLdsmaxh, //!< Instruction 'ldsmaxh'. + kIdLdsmaxl, //!< Instruction 'ldsmaxl'. + kIdLdsmaxlb, //!< Instruction 'ldsmaxlb'. + kIdLdsmaxlh, //!< Instruction 'ldsmaxlh'. + kIdLdsmin, //!< Instruction 'ldsmin'. + kIdLdsmina, //!< Instruction 'ldsmina'. + kIdLdsminab, //!< Instruction 'ldsminab'. + kIdLdsminah, //!< Instruction 'ldsminah'. + kIdLdsminal, //!< Instruction 'ldsminal'. + kIdLdsminalb, //!< Instruction 'ldsminalb'. + kIdLdsminalh, //!< Instruction 'ldsminalh'. + kIdLdsminb, //!< Instruction 'ldsminb'. + kIdLdsminh, //!< Instruction 'ldsminh'. + kIdLdsminl, //!< Instruction 'ldsminl'. + kIdLdsminlb, //!< Instruction 'ldsminlb'. + kIdLdsminlh, //!< Instruction 'ldsminlh'. + kIdLdtr, //!< Instruction 'ldtr'. + kIdLdtrb, //!< Instruction 'ldtrb'. + kIdLdtrh, //!< Instruction 'ldtrh'. + kIdLdtrsb, //!< Instruction 'ldtrsb'. + kIdLdtrsh, //!< Instruction 'ldtrsh'. + kIdLdtrsw, //!< Instruction 'ldtrsw'. + kIdLdumax, //!< Instruction 'ldumax'. + kIdLdumaxa, //!< Instruction 'ldumaxa'. + kIdLdumaxab, //!< Instruction 'ldumaxab'. + kIdLdumaxah, //!< Instruction 'ldumaxah'. + kIdLdumaxal, //!< Instruction 'ldumaxal'. + kIdLdumaxalb, //!< Instruction 'ldumaxalb'. + kIdLdumaxalh, //!< Instruction 'ldumaxalh'. + kIdLdumaxb, //!< Instruction 'ldumaxb'. + kIdLdumaxh, //!< Instruction 'ldumaxh'. + kIdLdumaxl, //!< Instruction 'ldumaxl'. + kIdLdumaxlb, //!< Instruction 'ldumaxlb'. + kIdLdumaxlh, //!< Instruction 'ldumaxlh'. + kIdLdumin, //!< Instruction 'ldumin'. + kIdLdumina, //!< Instruction 'ldumina'. + kIdLduminab, //!< Instruction 'lduminab'. + kIdLduminah, //!< Instruction 'lduminah'. + kIdLduminal, //!< Instruction 'lduminal'. + kIdLduminalb, //!< Instruction 'lduminalb'. + kIdLduminalh, //!< Instruction 'lduminalh'. + kIdLduminb, //!< Instruction 'lduminb'. 
+ kIdLduminh, //!< Instruction 'lduminh'. + kIdLduminl, //!< Instruction 'lduminl'. + kIdLduminlb, //!< Instruction 'lduminlb'. + kIdLduminlh, //!< Instruction 'lduminlh'. + kIdLdur, //!< Instruction 'ldur'. + kIdLdurb, //!< Instruction 'ldurb'. + kIdLdurh, //!< Instruction 'ldurh'. + kIdLdursb, //!< Instruction 'ldursb'. + kIdLdursh, //!< Instruction 'ldursh'. + kIdLdursw, //!< Instruction 'ldursw'. + kIdLdxp, //!< Instruction 'ldxp'. + kIdLdxr, //!< Instruction 'ldxr'. + kIdLdxrb, //!< Instruction 'ldxrb'. + kIdLdxrh, //!< Instruction 'ldxrh'. + kIdLsl, //!< Instruction 'lsl'. + kIdLslv, //!< Instruction 'lslv'. + kIdLsr, //!< Instruction 'lsr'. + kIdLsrv, //!< Instruction 'lsrv'. + kIdMadd, //!< Instruction 'madd'. + kIdMneg, //!< Instruction 'mneg'. + kIdMov, //!< Instruction 'mov'. + kIdMovk, //!< Instruction 'movk'. + kIdMovn, //!< Instruction 'movn'. + kIdMovz, //!< Instruction 'movz'. + kIdMrs, //!< Instruction 'mrs'. + kIdMsr, //!< Instruction 'msr'. + kIdMsub, //!< Instruction 'msub'. + kIdMul, //!< Instruction 'mul'. + kIdMvn, //!< Instruction 'mvn'. + kIdNeg, //!< Instruction 'neg'. + kIdNegs, //!< Instruction 'negs'. + kIdNgc, //!< Instruction 'ngc'. + kIdNgcs, //!< Instruction 'ngcs'. + kIdNop, //!< Instruction 'nop'. + kIdOrn, //!< Instruction 'orn'. + kIdOrr, //!< Instruction 'orr'. + kIdPacda, //!< Instruction 'pacda'. + kIdPacdb, //!< Instruction 'pacdb'. + kIdPacdza, //!< Instruction 'pacdza'. + kIdPacdzb, //!< Instruction 'pacdzb'. + kIdPacga, //!< Instruction 'pacga'. + kIdPrfm, //!< Instruction 'prfm'. + kIdPssbb, //!< Instruction 'pssbb'. + kIdRbit, //!< Instruction 'rbit'. + kIdRet, //!< Instruction 'ret'. + kIdRev, //!< Instruction 'rev'. + kIdRev16, //!< Instruction 'rev16'. + kIdRev32, //!< Instruction 'rev32'. + kIdRev64, //!< Instruction 'rev64'. + kIdRor, //!< Instruction 'ror'. + kIdRorv, //!< Instruction 'rorv'. + kIdSbc, //!< Instruction 'sbc'. + kIdSbcs, //!< Instruction 'sbcs'. + kIdSbfiz, //!< Instruction 'sbfiz'. 
+ kIdSbfm, //!< Instruction 'sbfm'. + kIdSbfx, //!< Instruction 'sbfx'. + kIdSdiv, //!< Instruction 'sdiv'. + kIdSetf8, //!< Instruction 'setf8'. + kIdSetf16, //!< Instruction 'setf16'. + kIdSev, //!< Instruction 'sev'. + kIdSevl, //!< Instruction 'sevl'. + kIdSmaddl, //!< Instruction 'smaddl'. + kIdSmc, //!< Instruction 'smc'. + kIdSmnegl, //!< Instruction 'smnegl'. + kIdSmsubl, //!< Instruction 'smsubl'. + kIdSmulh, //!< Instruction 'smulh'. + kIdSmull, //!< Instruction 'smull'. + kIdSsbb, //!< Instruction 'ssbb'. + kIdSt2g, //!< Instruction 'st2g'. + kIdStadd, //!< Instruction 'stadd'. + kIdStaddl, //!< Instruction 'staddl'. + kIdStaddb, //!< Instruction 'staddb'. + kIdStaddlb, //!< Instruction 'staddlb'. + kIdStaddh, //!< Instruction 'staddh'. + kIdStaddlh, //!< Instruction 'staddlh'. + kIdStclr, //!< Instruction 'stclr'. + kIdStclrl, //!< Instruction 'stclrl'. + kIdStclrb, //!< Instruction 'stclrb'. + kIdStclrlb, //!< Instruction 'stclrlb'. + kIdStclrh, //!< Instruction 'stclrh'. + kIdStclrlh, //!< Instruction 'stclrlh'. + kIdSteor, //!< Instruction 'steor'. + kIdSteorl, //!< Instruction 'steorl'. + kIdSteorb, //!< Instruction 'steorb'. + kIdSteorlb, //!< Instruction 'steorlb'. + kIdSteorh, //!< Instruction 'steorh'. + kIdSteorlh, //!< Instruction 'steorlh'. + kIdStg, //!< Instruction 'stg'. + kIdStgm, //!< Instruction 'stgm'. + kIdStgp, //!< Instruction 'stgp'. + kIdStllr, //!< Instruction 'stllr'. + kIdStllrb, //!< Instruction 'stllrb'. + kIdStllrh, //!< Instruction 'stllrh'. + kIdStlr, //!< Instruction 'stlr'. + kIdStlrb, //!< Instruction 'stlrb'. + kIdStlrh, //!< Instruction 'stlrh'. + kIdStlxp, //!< Instruction 'stlxp'. + kIdStlxr, //!< Instruction 'stlxr'. + kIdStlxrb, //!< Instruction 'stlxrb'. + kIdStlxrh, //!< Instruction 'stlxrh'. + kIdStnp, //!< Instruction 'stnp'. + kIdStp, //!< Instruction 'stp'. + kIdStr, //!< Instruction 'str'. + kIdStrb, //!< Instruction 'strb'. + kIdStrh, //!< Instruction 'strh'. + kIdStset, //!< Instruction 'stset'. 
+ kIdStsetl, //!< Instruction 'stsetl'. + kIdStsetb, //!< Instruction 'stsetb'. + kIdStsetlb, //!< Instruction 'stsetlb'. + kIdStseth, //!< Instruction 'stseth'. + kIdStsetlh, //!< Instruction 'stsetlh'. + kIdStsmax, //!< Instruction 'stsmax'. + kIdStsmaxl, //!< Instruction 'stsmaxl'. + kIdStsmaxb, //!< Instruction 'stsmaxb'. + kIdStsmaxlb, //!< Instruction 'stsmaxlb'. + kIdStsmaxh, //!< Instruction 'stsmaxh'. + kIdStsmaxlh, //!< Instruction 'stsmaxlh'. + kIdStsmin, //!< Instruction 'stsmin'. + kIdStsminl, //!< Instruction 'stsminl'. + kIdStsminb, //!< Instruction 'stsminb'. + kIdStsminlb, //!< Instruction 'stsminlb'. + kIdStsminh, //!< Instruction 'stsminh'. + kIdStsminlh, //!< Instruction 'stsminlh'. + kIdSttr, //!< Instruction 'sttr'. + kIdSttrb, //!< Instruction 'sttrb'. + kIdSttrh, //!< Instruction 'sttrh'. + kIdStumax, //!< Instruction 'stumax'. + kIdStumaxl, //!< Instruction 'stumaxl'. + kIdStumaxb, //!< Instruction 'stumaxb'. + kIdStumaxlb, //!< Instruction 'stumaxlb'. + kIdStumaxh, //!< Instruction 'stumaxh'. + kIdStumaxlh, //!< Instruction 'stumaxlh'. + kIdStumin, //!< Instruction 'stumin'. + kIdStuminl, //!< Instruction 'stuminl'. + kIdStuminb, //!< Instruction 'stuminb'. + kIdStuminlb, //!< Instruction 'stuminlb'. + kIdStuminh, //!< Instruction 'stuminh'. + kIdStuminlh, //!< Instruction 'stuminlh'. + kIdStur, //!< Instruction 'stur'. + kIdSturb, //!< Instruction 'sturb'. + kIdSturh, //!< Instruction 'sturh'. + kIdStxp, //!< Instruction 'stxp'. + kIdStxr, //!< Instruction 'stxr'. + kIdStxrb, //!< Instruction 'stxrb'. + kIdStxrh, //!< Instruction 'stxrh'. + kIdStz2g, //!< Instruction 'stz2g'. + kIdStzg, //!< Instruction 'stzg'. + kIdStzgm, //!< Instruction 'stzgm'. + kIdSub, //!< Instruction 'sub'. + kIdSubg, //!< Instruction 'subg'. + kIdSubp, //!< Instruction 'subp'. + kIdSubps, //!< Instruction 'subps'. + kIdSubs, //!< Instruction 'subs'. + kIdSvc, //!< Instruction 'svc'. + kIdSwp, //!< Instruction 'swp'. + kIdSwpa, //!< Instruction 'swpa'. 
+ kIdSwpab, //!< Instruction 'swpab'. + kIdSwpah, //!< Instruction 'swpah'. + kIdSwpal, //!< Instruction 'swpal'. + kIdSwpalb, //!< Instruction 'swpalb'. + kIdSwpalh, //!< Instruction 'swpalh'. + kIdSwpb, //!< Instruction 'swpb'. + kIdSwph, //!< Instruction 'swph'. + kIdSwpl, //!< Instruction 'swpl'. + kIdSwplb, //!< Instruction 'swplb'. + kIdSwplh, //!< Instruction 'swplh'. + kIdSxtb, //!< Instruction 'sxtb'. + kIdSxth, //!< Instruction 'sxth'. + kIdSxtw, //!< Instruction 'sxtw'. + kIdSys, //!< Instruction 'sys'. + kIdTlbi, //!< Instruction 'tlbi'. + kIdTst, //!< Instruction 'tst'. + kIdTbnz, //!< Instruction 'tbnz'. + kIdTbz, //!< Instruction 'tbz'. + kIdUbfiz, //!< Instruction 'ubfiz'. + kIdUbfm, //!< Instruction 'ubfm'. + kIdUbfx, //!< Instruction 'ubfx'. + kIdUdf, //!< Instruction 'udf'. + kIdUdiv, //!< Instruction 'udiv'. + kIdUmaddl, //!< Instruction 'umaddl'. + kIdUmnegl, //!< Instruction 'umnegl'. + kIdUmull, //!< Instruction 'umull'. + kIdUmulh, //!< Instruction 'umulh'. + kIdUmsubl, //!< Instruction 'umsubl'. + kIdUxtb, //!< Instruction 'uxtb'. + kIdUxth, //!< Instruction 'uxth'. + kIdWfe, //!< Instruction 'wfe'. + kIdWfi, //!< Instruction 'wfi'. + kIdXaflag, //!< Instruction 'xaflag'. + kIdXpacd, //!< Instruction 'xpacd'. + kIdXpaci, //!< Instruction 'xpaci'. + kIdXpaclri, //!< Instruction 'xpaclri'. + kIdYield, //!< Instruction 'yield'. + kIdAbs_v, //!< Instruction 'abs' {ASIMD}. + kIdAdd_v, //!< Instruction 'add' {ASIMD}. + kIdAddhn_v, //!< Instruction 'addhn' {ASIMD}. + kIdAddhn2_v, //!< Instruction 'addhn2' {ASIMD}. + kIdAddp_v, //!< Instruction 'addp' {ASIMD}. + kIdAddv_v, //!< Instruction 'addv' {ASIMD}. + kIdAesd_v, //!< Instruction 'aesd' {ASIMD}. + kIdAese_v, //!< Instruction 'aese' {ASIMD}. + kIdAesimc_v, //!< Instruction 'aesimc' {ASIMD}. + kIdAesmc_v, //!< Instruction 'aesmc' {ASIMD}. + kIdAnd_v, //!< Instruction 'and' {ASIMD}. + kIdBcax_v, //!< Instruction 'bcax' {ASIMD}. + kIdBfcvt_v, //!< Instruction 'bfcvt' {ASIMD}. 
+ kIdBfcvtn_v, //!< Instruction 'bfcvtn' {ASIMD}. + kIdBfcvtn2_v, //!< Instruction 'bfcvtn2' {ASIMD}. + kIdBfdot_v, //!< Instruction 'bfdot' {ASIMD}. + kIdBfmlalb_v, //!< Instruction 'bfmlalb' {ASIMD}. + kIdBfmlalt_v, //!< Instruction 'bfmlalt' {ASIMD}. + kIdBfmmla_v, //!< Instruction 'bfmmla' {ASIMD}. + kIdBic_v, //!< Instruction 'bic' {ASIMD}. + kIdBif_v, //!< Instruction 'bif' {ASIMD}. + kIdBit_v, //!< Instruction 'bit' {ASIMD}. + kIdBsl_v, //!< Instruction 'bsl' {ASIMD}. + kIdCls_v, //!< Instruction 'cls' {ASIMD}. + kIdClz_v, //!< Instruction 'clz' {ASIMD}. + kIdCmeq_v, //!< Instruction 'cmeq' {ASIMD}. + kIdCmge_v, //!< Instruction 'cmge' {ASIMD}. + kIdCmgt_v, //!< Instruction 'cmgt' {ASIMD}. + kIdCmhi_v, //!< Instruction 'cmhi' {ASIMD}. + kIdCmhs_v, //!< Instruction 'cmhs' {ASIMD}. + kIdCmle_v, //!< Instruction 'cmle' {ASIMD}. + kIdCmlt_v, //!< Instruction 'cmlt' {ASIMD}. + kIdCmtst_v, //!< Instruction 'cmtst' {ASIMD}. + kIdCnt_v, //!< Instruction 'cnt' {ASIMD}. + kIdDup_v, //!< Instruction 'dup' {ASIMD}. + kIdEor_v, //!< Instruction 'eor' {ASIMD}. + kIdEor3_v, //!< Instruction 'eor3' {ASIMD}. + kIdExt_v, //!< Instruction 'ext' {ASIMD}. + kIdFabd_v, //!< Instruction 'fabd' {ASIMD}. + kIdFabs_v, //!< Instruction 'fabs' {ASIMD}. + kIdFacge_v, //!< Instruction 'facge' {ASIMD}. + kIdFacgt_v, //!< Instruction 'facgt' {ASIMD}. + kIdFadd_v, //!< Instruction 'fadd' {ASIMD}. + kIdFaddp_v, //!< Instruction 'faddp' {ASIMD}. + kIdFcadd_v, //!< Instruction 'fcadd' {ASIMD}. + kIdFccmp_v, //!< Instruction 'fccmp' {ASIMD}. + kIdFccmpe_v, //!< Instruction 'fccmpe' {ASIMD}. + kIdFcmeq_v, //!< Instruction 'fcmeq' {ASIMD}. + kIdFcmge_v, //!< Instruction 'fcmge' {ASIMD}. + kIdFcmgt_v, //!< Instruction 'fcmgt' {ASIMD}. + kIdFcmla_v, //!< Instruction 'fcmla' {ASIMD}. + kIdFcmle_v, //!< Instruction 'fcmle' {ASIMD}. + kIdFcmlt_v, //!< Instruction 'fcmlt' {ASIMD}. + kIdFcmp_v, //!< Instruction 'fcmp' {ASIMD}. + kIdFcmpe_v, //!< Instruction 'fcmpe' {ASIMD}. 
+ kIdFcsel_v, //!< Instruction 'fcsel' {ASIMD}. + kIdFcvt_v, //!< Instruction 'fcvt' {ASIMD}. + kIdFcvtas_v, //!< Instruction 'fcvtas' {ASIMD}. + kIdFcvtau_v, //!< Instruction 'fcvtau' {ASIMD}. + kIdFcvtl_v, //!< Instruction 'fcvtl' {ASIMD}. + kIdFcvtl2_v, //!< Instruction 'fcvtl2' {ASIMD}. + kIdFcvtms_v, //!< Instruction 'fcvtms' {ASIMD}. + kIdFcvtmu_v, //!< Instruction 'fcvtmu' {ASIMD}. + kIdFcvtn_v, //!< Instruction 'fcvtn' {ASIMD}. + kIdFcvtn2_v, //!< Instruction 'fcvtn2' {ASIMD}. + kIdFcvtns_v, //!< Instruction 'fcvtns' {ASIMD}. + kIdFcvtnu_v, //!< Instruction 'fcvtnu' {ASIMD}. + kIdFcvtps_v, //!< Instruction 'fcvtps' {ASIMD}. + kIdFcvtpu_v, //!< Instruction 'fcvtpu' {ASIMD}. + kIdFcvtxn_v, //!< Instruction 'fcvtxn' {ASIMD}. + kIdFcvtxn2_v, //!< Instruction 'fcvtxn2' {ASIMD}. + kIdFcvtzs_v, //!< Instruction 'fcvtzs' {ASIMD}. + kIdFcvtzu_v, //!< Instruction 'fcvtzu' {ASIMD}. + kIdFdiv_v, //!< Instruction 'fdiv' {ASIMD}. + kIdFjcvtzs_v, //!< Instruction 'fjcvtzs' {ASIMD}. + kIdFmadd_v, //!< Instruction 'fmadd' {ASIMD}. + kIdFmax_v, //!< Instruction 'fmax' {ASIMD}. + kIdFmaxnm_v, //!< Instruction 'fmaxnm' {ASIMD}. + kIdFmaxnmp_v, //!< Instruction 'fmaxnmp' {ASIMD}. + kIdFmaxnmv_v, //!< Instruction 'fmaxnmv' {ASIMD}. + kIdFmaxp_v, //!< Instruction 'fmaxp' {ASIMD}. + kIdFmaxv_v, //!< Instruction 'fmaxv' {ASIMD}. + kIdFmin_v, //!< Instruction 'fmin' {ASIMD}. + kIdFminnm_v, //!< Instruction 'fminnm' {ASIMD}. + kIdFminnmp_v, //!< Instruction 'fminnmp' {ASIMD}. + kIdFminnmv_v, //!< Instruction 'fminnmv' {ASIMD}. + kIdFminp_v, //!< Instruction 'fminp' {ASIMD}. + kIdFminv_v, //!< Instruction 'fminv' {ASIMD}. + kIdFmla_v, //!< Instruction 'fmla' {ASIMD}. + kIdFmlal_v, //!< Instruction 'fmlal' {ASIMD}. + kIdFmlal2_v, //!< Instruction 'fmlal2' {ASIMD}. + kIdFmls_v, //!< Instruction 'fmls' {ASIMD}. + kIdFmlsl_v, //!< Instruction 'fmlsl' {ASIMD}. + kIdFmlsl2_v, //!< Instruction 'fmlsl2' {ASIMD}. + kIdFmov_v, //!< Instruction 'fmov' {ASIMD}. 
+ kIdFmsub_v, //!< Instruction 'fmsub' {ASIMD}. + kIdFmul_v, //!< Instruction 'fmul' {ASIMD}. + kIdFmulx_v, //!< Instruction 'fmulx' {ASIMD}. + kIdFneg_v, //!< Instruction 'fneg' {ASIMD}. + kIdFnmadd_v, //!< Instruction 'fnmadd' {ASIMD}. + kIdFnmsub_v, //!< Instruction 'fnmsub' {ASIMD}. + kIdFnmul_v, //!< Instruction 'fnmul' {ASIMD}. + kIdFrecpe_v, //!< Instruction 'frecpe' {ASIMD}. + kIdFrecps_v, //!< Instruction 'frecps' {ASIMD}. + kIdFrecpx_v, //!< Instruction 'frecpx' {ASIMD}. + kIdFrint32x_v, //!< Instruction 'frint32x' {ASIMD}. + kIdFrint32z_v, //!< Instruction 'frint32z' {ASIMD}. + kIdFrint64x_v, //!< Instruction 'frint64x' {ASIMD}. + kIdFrint64z_v, //!< Instruction 'frint64z' {ASIMD}. + kIdFrinta_v, //!< Instruction 'frinta' {ASIMD}. + kIdFrinti_v, //!< Instruction 'frinti' {ASIMD}. + kIdFrintm_v, //!< Instruction 'frintm' {ASIMD}. + kIdFrintn_v, //!< Instruction 'frintn' {ASIMD}. + kIdFrintp_v, //!< Instruction 'frintp' {ASIMD}. + kIdFrintx_v, //!< Instruction 'frintx' {ASIMD}. + kIdFrintz_v, //!< Instruction 'frintz' {ASIMD}. + kIdFrsqrte_v, //!< Instruction 'frsqrte' {ASIMD}. + kIdFrsqrts_v, //!< Instruction 'frsqrts' {ASIMD}. + kIdFsqrt_v, //!< Instruction 'fsqrt' {ASIMD}. + kIdFsub_v, //!< Instruction 'fsub' {ASIMD}. + kIdIns_v, //!< Instruction 'ins' {ASIMD}. + kIdLd1_v, //!< Instruction 'ld1' {ASIMD}. + kIdLd1r_v, //!< Instruction 'ld1r' {ASIMD}. + kIdLd2_v, //!< Instruction 'ld2' {ASIMD}. + kIdLd2r_v, //!< Instruction 'ld2r' {ASIMD}. + kIdLd3_v, //!< Instruction 'ld3' {ASIMD}. + kIdLd3r_v, //!< Instruction 'ld3r' {ASIMD}. + kIdLd4_v, //!< Instruction 'ld4' {ASIMD}. + kIdLd4r_v, //!< Instruction 'ld4r' {ASIMD}. + kIdLdnp_v, //!< Instruction 'ldnp' {ASIMD}. + kIdLdp_v, //!< Instruction 'ldp' {ASIMD}. + kIdLdr_v, //!< Instruction 'ldr' {ASIMD}. + kIdLdur_v, //!< Instruction 'ldur' {ASIMD}. + kIdMla_v, //!< Instruction 'mla' {ASIMD}. + kIdMls_v, //!< Instruction 'mls' {ASIMD}. + kIdMov_v, //!< Instruction 'mov' {ASIMD}. 
+ kIdMovi_v, //!< Instruction 'movi' {ASIMD}. + kIdMul_v, //!< Instruction 'mul' {ASIMD}. + kIdMvn_v, //!< Instruction 'mvn' {ASIMD}. + kIdMvni_v, //!< Instruction 'mvni' {ASIMD}. + kIdNeg_v, //!< Instruction 'neg' {ASIMD}. + kIdNot_v, //!< Instruction 'not' {ASIMD}. + kIdOrn_v, //!< Instruction 'orn' {ASIMD}. + kIdOrr_v, //!< Instruction 'orr' {ASIMD}. + kIdPmul_v, //!< Instruction 'pmul' {ASIMD}. + kIdPmull_v, //!< Instruction 'pmull' {ASIMD}. + kIdPmull2_v, //!< Instruction 'pmull2' {ASIMD}. + kIdRaddhn_v, //!< Instruction 'raddhn' {ASIMD}. + kIdRaddhn2_v, //!< Instruction 'raddhn2' {ASIMD}. + kIdRax1_v, //!< Instruction 'rax1' {ASIMD}. + kIdRbit_v, //!< Instruction 'rbit' {ASIMD}. + kIdRev16_v, //!< Instruction 'rev16' {ASIMD}. + kIdRev32_v, //!< Instruction 'rev32' {ASIMD}. + kIdRev64_v, //!< Instruction 'rev64' {ASIMD}. + kIdRshrn_v, //!< Instruction 'rshrn' {ASIMD}. + kIdRshrn2_v, //!< Instruction 'rshrn2' {ASIMD}. + kIdRsubhn_v, //!< Instruction 'rsubhn' {ASIMD}. + kIdRsubhn2_v, //!< Instruction 'rsubhn2' {ASIMD}. + kIdSaba_v, //!< Instruction 'saba' {ASIMD}. + kIdSabal_v, //!< Instruction 'sabal' {ASIMD}. + kIdSabal2_v, //!< Instruction 'sabal2' {ASIMD}. + kIdSabd_v, //!< Instruction 'sabd' {ASIMD}. + kIdSabdl_v, //!< Instruction 'sabdl' {ASIMD}. + kIdSabdl2_v, //!< Instruction 'sabdl2' {ASIMD}. + kIdSadalp_v, //!< Instruction 'sadalp' {ASIMD}. + kIdSaddl_v, //!< Instruction 'saddl' {ASIMD}. + kIdSaddl2_v, //!< Instruction 'saddl2' {ASIMD}. + kIdSaddlp_v, //!< Instruction 'saddlp' {ASIMD}. + kIdSaddlv_v, //!< Instruction 'saddlv' {ASIMD}. + kIdSaddw_v, //!< Instruction 'saddw' {ASIMD}. + kIdSaddw2_v, //!< Instruction 'saddw2' {ASIMD}. + kIdScvtf_v, //!< Instruction 'scvtf' {ASIMD}. + kIdSdot_v, //!< Instruction 'sdot' {ASIMD}. + kIdSha1c_v, //!< Instruction 'sha1c' {ASIMD}. + kIdSha1h_v, //!< Instruction 'sha1h' {ASIMD}. + kIdSha1m_v, //!< Instruction 'sha1m' {ASIMD}. + kIdSha1p_v, //!< Instruction 'sha1p' {ASIMD}. 
+ kIdSha1su0_v, //!< Instruction 'sha1su0' {ASIMD}. + kIdSha1su1_v, //!< Instruction 'sha1su1' {ASIMD}. + kIdSha256h_v, //!< Instruction 'sha256h' {ASIMD}. + kIdSha256h2_v, //!< Instruction 'sha256h2' {ASIMD}. + kIdSha256su0_v, //!< Instruction 'sha256su0' {ASIMD}. + kIdSha256su1_v, //!< Instruction 'sha256su1' {ASIMD}. + kIdSha512h_v, //!< Instruction 'sha512h' {ASIMD}. + kIdSha512h2_v, //!< Instruction 'sha512h2' {ASIMD}. + kIdSha512su0_v, //!< Instruction 'sha512su0' {ASIMD}. + kIdSha512su1_v, //!< Instruction 'sha512su1' {ASIMD}. + kIdShadd_v, //!< Instruction 'shadd' {ASIMD}. + kIdShl_v, //!< Instruction 'shl' {ASIMD}. + kIdShll_v, //!< Instruction 'shll' {ASIMD}. + kIdShll2_v, //!< Instruction 'shll2' {ASIMD}. + kIdShrn_v, //!< Instruction 'shrn' {ASIMD}. + kIdShrn2_v, //!< Instruction 'shrn2' {ASIMD}. + kIdShsub_v, //!< Instruction 'shsub' {ASIMD}. + kIdSli_v, //!< Instruction 'sli' {ASIMD}. + kIdSm3partw1_v, //!< Instruction 'sm3partw1' {ASIMD}. + kIdSm3partw2_v, //!< Instruction 'sm3partw2' {ASIMD}. + kIdSm3ss1_v, //!< Instruction 'sm3ss1' {ASIMD}. + kIdSm3tt1a_v, //!< Instruction 'sm3tt1a' {ASIMD}. + kIdSm3tt1b_v, //!< Instruction 'sm3tt1b' {ASIMD}. + kIdSm3tt2a_v, //!< Instruction 'sm3tt2a' {ASIMD}. + kIdSm3tt2b_v, //!< Instruction 'sm3tt2b' {ASIMD}. + kIdSm4e_v, //!< Instruction 'sm4e' {ASIMD}. + kIdSm4ekey_v, //!< Instruction 'sm4ekey' {ASIMD}. + kIdSmax_v, //!< Instruction 'smax' {ASIMD}. + kIdSmaxp_v, //!< Instruction 'smaxp' {ASIMD}. + kIdSmaxv_v, //!< Instruction 'smaxv' {ASIMD}. + kIdSmin_v, //!< Instruction 'smin' {ASIMD}. + kIdSminp_v, //!< Instruction 'sminp' {ASIMD}. + kIdSminv_v, //!< Instruction 'sminv' {ASIMD}. + kIdSmlal_v, //!< Instruction 'smlal' {ASIMD}. + kIdSmlal2_v, //!< Instruction 'smlal2' {ASIMD}. + kIdSmlsl_v, //!< Instruction 'smlsl' {ASIMD}. + kIdSmlsl2_v, //!< Instruction 'smlsl2' {ASIMD}. + kIdSmmla_v, //!< Instruction 'smmla' {ASIMD}. + kIdSmov_v, //!< Instruction 'smov' {ASIMD}. 
+ kIdSmull_v, //!< Instruction 'smull' {ASIMD}. + kIdSmull2_v, //!< Instruction 'smull2' {ASIMD}. + kIdSqabs_v, //!< Instruction 'sqabs' {ASIMD}. + kIdSqadd_v, //!< Instruction 'sqadd' {ASIMD}. + kIdSqdmlal_v, //!< Instruction 'sqdmlal' {ASIMD}. + kIdSqdmlal2_v, //!< Instruction 'sqdmlal2' {ASIMD}. + kIdSqdmlsl_v, //!< Instruction 'sqdmlsl' {ASIMD}. + kIdSqdmlsl2_v, //!< Instruction 'sqdmlsl2' {ASIMD}. + kIdSqdmulh_v, //!< Instruction 'sqdmulh' {ASIMD}. + kIdSqdmull_v, //!< Instruction 'sqdmull' {ASIMD}. + kIdSqdmull2_v, //!< Instruction 'sqdmull2' {ASIMD}. + kIdSqneg_v, //!< Instruction 'sqneg' {ASIMD}. + kIdSqrdmlah_v, //!< Instruction 'sqrdmlah' {ASIMD}. + kIdSqrdmlsh_v, //!< Instruction 'sqrdmlsh' {ASIMD}. + kIdSqrdmulh_v, //!< Instruction 'sqrdmulh' {ASIMD}. + kIdSqrshl_v, //!< Instruction 'sqrshl' {ASIMD}. + kIdSqrshrn_v, //!< Instruction 'sqrshrn' {ASIMD}. + kIdSqrshrn2_v, //!< Instruction 'sqrshrn2' {ASIMD}. + kIdSqrshrun_v, //!< Instruction 'sqrshrun' {ASIMD}. + kIdSqrshrun2_v, //!< Instruction 'sqrshrun2' {ASIMD}. + kIdSqshl_v, //!< Instruction 'sqshl' {ASIMD}. + kIdSqshlu_v, //!< Instruction 'sqshlu' {ASIMD}. + kIdSqshrn_v, //!< Instruction 'sqshrn' {ASIMD}. + kIdSqshrn2_v, //!< Instruction 'sqshrn2' {ASIMD}. + kIdSqshrun_v, //!< Instruction 'sqshrun' {ASIMD}. + kIdSqshrun2_v, //!< Instruction 'sqshrun2' {ASIMD}. + kIdSqsub_v, //!< Instruction 'sqsub' {ASIMD}. + kIdSqxtn_v, //!< Instruction 'sqxtn' {ASIMD}. + kIdSqxtn2_v, //!< Instruction 'sqxtn2' {ASIMD}. + kIdSqxtun_v, //!< Instruction 'sqxtun' {ASIMD}. + kIdSqxtun2_v, //!< Instruction 'sqxtun2' {ASIMD}. + kIdSrhadd_v, //!< Instruction 'srhadd' {ASIMD}. + kIdSri_v, //!< Instruction 'sri' {ASIMD}. + kIdSrshl_v, //!< Instruction 'srshl' {ASIMD}. + kIdSrshr_v, //!< Instruction 'srshr' {ASIMD}. + kIdSrsra_v, //!< Instruction 'srsra' {ASIMD}. + kIdSshl_v, //!< Instruction 'sshl' {ASIMD}. + kIdSshll_v, //!< Instruction 'sshll' {ASIMD}. + kIdSshll2_v, //!< Instruction 'sshll2' {ASIMD}. 
+ kIdSshr_v, //!< Instruction 'sshr' {ASIMD}. + kIdSsra_v, //!< Instruction 'ssra' {ASIMD}. + kIdSsubl_v, //!< Instruction 'ssubl' {ASIMD}. + kIdSsubl2_v, //!< Instruction 'ssubl2' {ASIMD}. + kIdSsubw_v, //!< Instruction 'ssubw' {ASIMD}. + kIdSsubw2_v, //!< Instruction 'ssubw2' {ASIMD}. + kIdSt1_v, //!< Instruction 'st1' {ASIMD}. + kIdSt2_v, //!< Instruction 'st2' {ASIMD}. + kIdSt3_v, //!< Instruction 'st3' {ASIMD}. + kIdSt4_v, //!< Instruction 'st4' {ASIMD}. + kIdStnp_v, //!< Instruction 'stnp' {ASIMD}. + kIdStp_v, //!< Instruction 'stp' {ASIMD}. + kIdStr_v, //!< Instruction 'str' {ASIMD}. + kIdStur_v, //!< Instruction 'stur' {ASIMD}. + kIdSub_v, //!< Instruction 'sub' {ASIMD}. + kIdSubhn_v, //!< Instruction 'subhn' {ASIMD}. + kIdSubhn2_v, //!< Instruction 'subhn2' {ASIMD}. + kIdSudot_v, //!< Instruction 'sudot' {ASIMD}. + kIdSuqadd_v, //!< Instruction 'suqadd' {ASIMD}. + kIdSxtl_v, //!< Instruction 'sxtl' {ASIMD}. + kIdSxtl2_v, //!< Instruction 'sxtl2' {ASIMD}. + kIdTbl_v, //!< Instruction 'tbl' {ASIMD}. + kIdTbx_v, //!< Instruction 'tbx' {ASIMD}. + kIdTrn1_v, //!< Instruction 'trn1' {ASIMD}. + kIdTrn2_v, //!< Instruction 'trn2' {ASIMD}. + kIdUaba_v, //!< Instruction 'uaba' {ASIMD}. + kIdUabal_v, //!< Instruction 'uabal' {ASIMD}. + kIdUabal2_v, //!< Instruction 'uabal2' {ASIMD}. + kIdUabd_v, //!< Instruction 'uabd' {ASIMD}. + kIdUabdl_v, //!< Instruction 'uabdl' {ASIMD}. + kIdUabdl2_v, //!< Instruction 'uabdl2' {ASIMD}. + kIdUadalp_v, //!< Instruction 'uadalp' {ASIMD}. + kIdUaddl_v, //!< Instruction 'uaddl' {ASIMD}. + kIdUaddl2_v, //!< Instruction 'uaddl2' {ASIMD}. + kIdUaddlp_v, //!< Instruction 'uaddlp' {ASIMD}. + kIdUaddlv_v, //!< Instruction 'uaddlv' {ASIMD}. + kIdUaddw_v, //!< Instruction 'uaddw' {ASIMD}. + kIdUaddw2_v, //!< Instruction 'uaddw2' {ASIMD}. + kIdUcvtf_v, //!< Instruction 'ucvtf' {ASIMD}. + kIdUdot_v, //!< Instruction 'udot' {ASIMD}. + kIdUhadd_v, //!< Instruction 'uhadd' {ASIMD}. + kIdUhsub_v, //!< Instruction 'uhsub' {ASIMD}. 
+ kIdUmax_v, //!< Instruction 'umax' {ASIMD}. + kIdUmaxp_v, //!< Instruction 'umaxp' {ASIMD}. + kIdUmaxv_v, //!< Instruction 'umaxv' {ASIMD}. + kIdUmin_v, //!< Instruction 'umin' {ASIMD}. + kIdUminp_v, //!< Instruction 'uminp' {ASIMD}. + kIdUminv_v, //!< Instruction 'uminv' {ASIMD}. + kIdUmlal_v, //!< Instruction 'umlal' {ASIMD}. + kIdUmlal2_v, //!< Instruction 'umlal2' {ASIMD}. + kIdUmlsl_v, //!< Instruction 'umlsl' {ASIMD}. + kIdUmlsl2_v, //!< Instruction 'umlsl2' {ASIMD}. + kIdUmmla_v, //!< Instruction 'ummla' {ASIMD}. + kIdUmov_v, //!< Instruction 'umov' {ASIMD}. + kIdUmull_v, //!< Instruction 'umull' {ASIMD}. + kIdUmull2_v, //!< Instruction 'umull2' {ASIMD}. + kIdUqadd_v, //!< Instruction 'uqadd' {ASIMD}. + kIdUqrshl_v, //!< Instruction 'uqrshl' {ASIMD}. + kIdUqrshrn_v, //!< Instruction 'uqrshrn' {ASIMD}. + kIdUqrshrn2_v, //!< Instruction 'uqrshrn2' {ASIMD}. + kIdUqshl_v, //!< Instruction 'uqshl' {ASIMD}. + kIdUqshrn_v, //!< Instruction 'uqshrn' {ASIMD}. + kIdUqshrn2_v, //!< Instruction 'uqshrn2' {ASIMD}. + kIdUqsub_v, //!< Instruction 'uqsub' {ASIMD}. + kIdUqxtn_v, //!< Instruction 'uqxtn' {ASIMD}. + kIdUqxtn2_v, //!< Instruction 'uqxtn2' {ASIMD}. + kIdUrecpe_v, //!< Instruction 'urecpe' {ASIMD}. + kIdUrhadd_v, //!< Instruction 'urhadd' {ASIMD}. + kIdUrshl_v, //!< Instruction 'urshl' {ASIMD}. + kIdUrshr_v, //!< Instruction 'urshr' {ASIMD}. + kIdUrsqrte_v, //!< Instruction 'ursqrte' {ASIMD}. + kIdUrsra_v, //!< Instruction 'ursra' {ASIMD}. + kIdUsdot_v, //!< Instruction 'usdot' {ASIMD}. + kIdUshl_v, //!< Instruction 'ushl' {ASIMD}. + kIdUshll_v, //!< Instruction 'ushll' {ASIMD}. + kIdUshll2_v, //!< Instruction 'ushll2' {ASIMD}. + kIdUshr_v, //!< Instruction 'ushr' {ASIMD}. + kIdUsmmla_v, //!< Instruction 'usmmla' {ASIMD}. + kIdUsqadd_v, //!< Instruction 'usqadd' {ASIMD}. + kIdUsra_v, //!< Instruction 'usra' {ASIMD}. + kIdUsubl_v, //!< Instruction 'usubl' {ASIMD}. + kIdUsubl2_v, //!< Instruction 'usubl2' {ASIMD}. + kIdUsubw_v, //!< Instruction 'usubw' {ASIMD}. 
+ kIdUsubw2_v, //!< Instruction 'usubw2' {ASIMD}. + kIdUxtl_v, //!< Instruction 'uxtl' {ASIMD}. + kIdUxtl2_v, //!< Instruction 'uxtl2' {ASIMD}. + kIdUzp1_v, //!< Instruction 'uzp1' {ASIMD}. + kIdUzp2_v, //!< Instruction 'uzp2' {ASIMD}. + kIdXar_v, //!< Instruction 'xar' {ASIMD}. + kIdXtn_v, //!< Instruction 'xtn' {ASIMD}. + kIdXtn2_v, //!< Instruction 'xtn2' {ASIMD}. + kIdZip1_v, //!< Instruction 'zip1' {ASIMD}. + kIdZip2_v, //!< Instruction 'zip2' {ASIMD}. + _kIdCount + // ${InstId:End} + }; + + //! Tests whether the `instId` is defined (counts also Inst::kIdNone, which must be zero). + static ASMJIT_INLINE_NODEBUG bool isDefinedId(InstId instId) noexcept { return (instId & uint32_t(InstIdParts::kRealId)) < _kIdCount; } +}; + +namespace Predicate { + +//! Address translate options (AT). +namespace AT { + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + enum Value : uint32_t { + kS1E1R = encode(0b000, 0b0111, 0b1000, 0b000), + kS1E2R = encode(0b100, 0b0111, 0b1000, 0b000), + kS1E3R = encode(0b110, 0b0111, 0b1000, 0b000), + kS1E1W = encode(0b000, 0b0111, 0b1000, 0b001), + kS1E2W = encode(0b100, 0b0111, 0b1000, 0b001), + kS1E3W = encode(0b110, 0b0111, 0b1000, 0b001), + kS1E0R = encode(0b000, 0b0111, 0b1000, 0b010), + kS1E0W = encode(0b000, 0b0111, 0b1000, 0b011), + kS12E1R = encode(0b100, 0b0111, 0b1000, 0b100), + kS12E1W = encode(0b100, 0b0111, 0b1000, 0b101), + kS12E0R = encode(0b100, 0b0111, 0b1000, 0b110), + kS12E0W = encode(0b100, 0b0111, 0b1000, 0b111), + kS1E1RP = encode(0b000, 0b0111, 0b1001, 0b000), + kS1E1WP = encode(0b000, 0b0111, 0b1001, 0b001) + }; +} + +//! Data barrier options (DMB/DSB). +namespace DB { + //! Data barrier immediate values. + enum Value : uint32_t { + //! Waits only for loads to complete, and only applies to the outer shareable domain. + kOSHLD = 0x01u, + //! 
Waits only for stores to complete, and only applies to the outer shareable domain. + kOSHST = 0x02u, + //! Only applies to the outer shareable domain. + kOSH = 0x03u, + + //! Waits only for loads to complete and only applies out to the point of unification. + kNSHLD = 0x05u, + //! Waits only for stores to complete and only applies out to the point of unification. + kNSHST = 0x06u, + //! Only applies out to the point of unification. + kNSH = 0x07u, + + //! Waits only for loads to complete, and only applies to the inner shareable domain. + kISHLD = 0x09u, + //! Waits only for stores to complete, and only applies to the inner shareable domain. + kISHST = 0x0Au, + //! Only applies to the inner shareable domain. + kISH = 0x0Bu, + + //! Waits only for loads to complete. + kLD = 0x0Du, + //! Waits only for stores to complete. + kST = 0x0Eu, + //! Full system memory barrier operation. + kSY = 0x0Fu + }; +} + +//! Data cache maintenance options. +namespace DC { + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + //! Data cache maintenance immediate values. 
+ enum Value : uint32_t { + kZVA = encode(0b011, 0b0111, 0b0100, 0b001), + kIVAC = encode(0b000, 0b0111, 0b0110, 0b001), + kISW = encode(0b000, 0b0111, 0b0110, 0b010), + kCVAC = encode(0b011, 0b0111, 0b1010, 0b001), + kCSW = encode(0b000, 0b0111, 0b1010, 0b010), + kCVAU = encode(0b011, 0b0111, 0b1011, 0b001), + kCIVAC = encode(0b011, 0b0111, 0b1110, 0b001), + kCISW = encode(0b000, 0b0111, 0b1110, 0b010), + kCVAP = encode(0b011, 0b0111, 0b1100, 0b001), + kCVADP = encode(0b011, 0b0111, 0b1101, 0b001), + kIGVAC = encode(0b000, 0b0111, 0b0110, 0b011), + kIGSW = encode(0b000, 0b0111, 0b0110, 0b100), + kCGSW = encode(0b000, 0b0111, 0b1010, 0b100), + kCIGSW = encode(0b000, 0b0111, 0b1110, 0b100), + kCGVAC = encode(0b011, 0b0111, 0b1010, 0b011), + kCGVAP = encode(0b011, 0b0111, 0b1100, 0b011), + kCGVADP = encode(0b011, 0b0111, 0b1101, 0b011), + kCIGVAC = encode(0b011, 0b0111, 0b1110, 0b011), + kGVA = encode(0b011, 0b0111, 0b0100, 0b011), + kIGDVAC = encode(0b000, 0b0111, 0b0110, 0b101), + kIGDSW = encode(0b000, 0b0111, 0b0110, 0b110), + kCGDSW = encode(0b000, 0b0111, 0b1010, 0b110), + kCIGDSW = encode(0b000, 0b0111, 0b1110, 0b110), + kCGDVAC = encode(0b011, 0b0111, 0b1010, 0b101), + kCGDVAP = encode(0b011, 0b0111, 0b1100, 0b101), + kCGDVADP = encode(0b011, 0b0111, 0b1101, 0b101), + kCIGDVAC = encode(0b011, 0b0111, 0b1110, 0b101), + kGZVA = encode(0b011, 0b0111, 0b0100, 0b100) + }; +} + +//! Instruction cache maintenance options. +namespace IC { + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + //! Instruction cache maintenance immediate values. + enum Value : uint32_t { + kIALLUIS = encode(0b000, 0b0111, 0b0001, 0b000), + kIALLU = encode(0b000, 0b0111, 0b0101, 0b000), + kIVAU = encode(0b011, 0b0111, 0b0101, 0b001) + }; +} + +//! Instruction-fetch barrier options. +namespace ISB { + //! Instruction-fetch barrier immediate values. 
+ enum Value : uint32_t { + kSY = 0xF + }; +} + +//! Prefetch options. +namespace PRFOp { + //! Prefetch immediate values. + enum Value : uint32_t { + kPLDL1KEEP = 0x00, + kPLDL1STRM = 0x01, + kPLDL2KEEP = 0x02, + kPLDL2STRM = 0x03, + kPLDL3KEEP = 0x04, + kPLDL3STRM = 0x05, + kPLIL1KEEP = 0x08, + kPLIL1STRM = 0x09, + kPLIL2KEEP = 0x0A, + kPLIL2STRM = 0x0B, + kPLIL3KEEP = 0x0C, + kPLIL3STRM = 0x0D, + kPSTL1KEEP = 0x10, + kPSTL1STRM = 0x11, + kPSTL2KEEP = 0x12, + kPSTL2STRM = 0x13, + kPSTL3KEEP = 0x14, + kPSTL3STRM = 0x15 + }; +} + +//! PSB instruction options. +namespace PSB { + //! PSB immediate values. + enum Value : uint32_t { + kCSYNC = 0x11u + }; +} + +namespace TLBI { + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + enum Value : uint32_t { + kIPAS2E1IS = encode(0b100, 0b1000, 0b0000, 0b001), + kIPAS2LE1IS = encode(0b100, 0b1000, 0b0000, 0b101), + kVMALLE1IS = encode(0b000, 0b1000, 0b0011, 0b000), + kALLE2IS = encode(0b100, 0b1000, 0b0011, 0b000), + kALLE3IS = encode(0b110, 0b1000, 0b0011, 0b000), + kVAE1IS = encode(0b000, 0b1000, 0b0011, 0b001), + kVAE2IS = encode(0b100, 0b1000, 0b0011, 0b001), + kVAE3IS = encode(0b110, 0b1000, 0b0011, 0b001), + kASIDE1IS = encode(0b000, 0b1000, 0b0011, 0b010), + kVAAE1IS = encode(0b000, 0b1000, 0b0011, 0b011), + kALLE1IS = encode(0b100, 0b1000, 0b0011, 0b100), + kVALE1IS = encode(0b000, 0b1000, 0b0011, 0b101), + kVALE2IS = encode(0b100, 0b1000, 0b0011, 0b101), + kVALE3IS = encode(0b110, 0b1000, 0b0011, 0b101), + kVMALLS12E1IS = encode(0b100, 0b1000, 0b0011, 0b110), + kVAALE1IS = encode(0b000, 0b1000, 0b0011, 0b111), + kIPAS2E1 = encode(0b100, 0b1000, 0b0100, 0b001), + kIPAS2LE1 = encode(0b100, 0b1000, 0b0100, 0b101), + kVMALLE1 = encode(0b000, 0b1000, 0b0111, 0b000), + kALLE2 = encode(0b100, 0b1000, 0b0111, 0b000), + kALLE3 = encode(0b110, 0b1000, 0b0111, 0b000), + kVAE1 = 
encode(0b000, 0b1000, 0b0111, 0b001), + kVAE2 = encode(0b100, 0b1000, 0b0111, 0b001), + kVAE3 = encode(0b110, 0b1000, 0b0111, 0b001), + kASIDE1 = encode(0b000, 0b1000, 0b0111, 0b010), + kVAAE1 = encode(0b000, 0b1000, 0b0111, 0b011), + kALLE1 = encode(0b100, 0b1000, 0b0111, 0b100), + kVALE1 = encode(0b000, 0b1000, 0b0111, 0b101), + kVALE2 = encode(0b100, 0b1000, 0b0111, 0b101), + kVALE3 = encode(0b110, 0b1000, 0b0111, 0b101), + kVMALLS12E1 = encode(0b100, 0b1000, 0b0111, 0b110), + kVAALE1 = encode(0b000, 0b1000, 0b0111, 0b111), + + kVMALLE1OS = encode(0b000, 0b1000, 0b0001, 0b000), + kVAE1OS = encode(0b000, 0b1000, 0b0001, 0b001), + kASIDE1OS = encode(0b000, 0b1000, 0b0001, 0b010), + kVAAE1OS = encode(0b000, 0b1000, 0b0001, 0b011), + kVALE1OS = encode(0b000, 0b1000, 0b0001, 0b101), + kVAALE1OS = encode(0b000, 0b1000, 0b0001, 0b111), + kIPAS2E1OS = encode(0b100, 0b1000, 0b0100, 0b000), + kIPAS2LE1OS = encode(0b100, 0b1000, 0b0100, 0b100), + kVAE2OS = encode(0b100, 0b1000, 0b0001, 0b001), + kVALE2OS = encode(0b100, 0b1000, 0b0001, 0b101), + kVMALLS12E1OS = encode(0b100, 0b1000, 0b0001, 0b110), + kVAE3OS = encode(0b110, 0b1000, 0b0001, 0b001), + kVALE3OS = encode(0b110, 0b1000, 0b0001, 0b101), + kALLE2OS = encode(0b100, 0b1000, 0b0001, 0b000), + kALLE1OS = encode(0b100, 0b1000, 0b0001, 0b100), + kALLE3OS = encode(0b110, 0b1000, 0b0001, 0b000), + + kRVAE1 = encode(0b000, 0b1000, 0b0110, 0b001), + kRVAAE1 = encode(0b000, 0b1000, 0b0110, 0b011), + kRVALE1 = encode(0b000, 0b1000, 0b0110, 0b101), + kRVAALE1 = encode(0b000, 0b1000, 0b0110, 0b111), + kRVAE1IS = encode(0b000, 0b1000, 0b0010, 0b001), + kRVAAE1IS = encode(0b000, 0b1000, 0b0010, 0b011), + kRVALE1IS = encode(0b000, 0b1000, 0b0010, 0b101), + kRVAALE1IS = encode(0b000, 0b1000, 0b0010, 0b111), + kRVAE1OS = encode(0b000, 0b1000, 0b0101, 0b001), + kRVAAE1OS = encode(0b000, 0b1000, 0b0101, 0b011), + kRVALE1OS = encode(0b000, 0b1000, 0b0101, 0b101), + kRVAALE1OS = encode(0b000, 0b1000, 0b0101, 0b111), + kRIPAS2E1IS = 
encode(0b100, 0b1000, 0b0000, 0b010), + kRIPAS2LE1IS = encode(0b100, 0b1000, 0b0000, 0b110), + kRIPAS2E1 = encode(0b100, 0b1000, 0b0100, 0b010), + kRIPAS2LE1 = encode(0b100, 0b1000, 0b0100, 0b110), + kRIPAS2E1OS = encode(0b100, 0b1000, 0b0100, 0b011), + kRIPAS2LE1OS = encode(0b100, 0b1000, 0b0100, 0b111), + kRVAE2 = encode(0b100, 0b1000, 0b0110, 0b001), + kRVALE2 = encode(0b100, 0b1000, 0b0110, 0b101), + kRVAE2IS = encode(0b100, 0b1000, 0b0010, 0b001), + kRVALE2IS = encode(0b100, 0b1000, 0b0010, 0b101), + kRVAE2OS = encode(0b100, 0b1000, 0b0101, 0b001), + kRVALE2OS = encode(0b100, 0b1000, 0b0101, 0b101), + kRVAE3 = encode(0b110, 0b1000, 0b0110, 0b001), + kRVALE3 = encode(0b110, 0b1000, 0b0110, 0b101), + kRVAE3IS = encode(0b110, 0b1000, 0b0010, 0b001), + kRVALE3IS = encode(0b110, 0b1000, 0b0010, 0b101), + kRVAE3OS = encode(0b110, 0b1000, 0b0101, 0b001), + kRVALE3OS = encode(0b110, 0b1000, 0b0101, 0b101), + }; +} + +//! Trace synchronization barrier options. +namespace TSB { + //! Trace synchronization immediate values. + enum Value : uint32_t { + kCSYNC = 0 + }; +} + +//! Processor state access through MSR. +namespace PState { + //! Encodes a pstate from `op0` and `op1`. + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op0, uint32_t op1) noexcept { + return (op0 << 3) | (op1 << 0); + } + + //! Processor state access immediates. + enum Value : uint32_t { + kSPSel = encode(0b000, 0b101), + kDAIFSet = encode(0b011, 0b110), + kDAIFClr = encode(0b011, 0b111), + kPAN = encode(0b000, 0b100), + kUAO = encode(0b000, 0b011), + kDIT = encode(0b011, 0b010), + kSSBS = encode(0b011, 0b001), + kTCO = encode(0b011, 0b100) + }; +}; + +//! System register identifiers and utilities (MSR/MRS). +namespace SysReg { + //! System register fields. + struct Fields { + uint8_t op0; + uint8_t op1; + uint8_t cRn; + uint8_t cRm; + uint8_t op2; + }; + + //! Encodes a system register from `op0`, `op1`, `cRn`, `cRm`, and `op2` fields. 
+ static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op0, uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op0 << 14) | (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + //! Encodes a system register from `fields`. + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(const Fields& fields) noexcept { + return encode(fields.op0, fields.op1, fields.cRn, fields.cRm, fields.op2); + } + + //! Decodes a system register to \ref Fields. + static ASMJIT_INLINE_NODEBUG constexpr Fields decode(uint32_t id) noexcept { + return Fields { + uint8_t((id >> 14) & 0x3u), + uint8_t((id >> 11) & 0x7u), + uint8_t((id >> 7) & 0xFu), + uint8_t((id >> 3) & 0xFu), + uint8_t((id >> 0) & 0x7u) + }; + } + + //! System register identifiers. + enum Id : uint32_t { + kACTLR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b001), // RW + kACTLR_EL2 = encode(0b11, 0b100, 0b0001, 0b0000, 0b001), // RW + kACTLR_EL3 = encode(0b11, 0b110, 0b0001, 0b0000, 0b001), // RW + kAFSR0_EL1 = encode(0b11, 0b000, 0b0101, 0b0001, 0b000), // RW + kAFSR0_EL12 = encode(0b11, 0b101, 0b0101, 0b0001, 0b000), // RW + kAFSR0_EL2 = encode(0b11, 0b100, 0b0101, 0b0001, 0b000), // RW + kAFSR0_EL3 = encode(0b11, 0b110, 0b0101, 0b0001, 0b000), // RW + kAFSR1_EL1 = encode(0b11, 0b000, 0b0101, 0b0001, 0b001), // RW + kAFSR1_EL12 = encode(0b11, 0b101, 0b0101, 0b0001, 0b001), // RW + kAFSR1_EL2 = encode(0b11, 0b100, 0b0101, 0b0001, 0b001), // RW + kAFSR1_EL3 = encode(0b11, 0b110, 0b0101, 0b0001, 0b001), // RW + kAIDR_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b111), // RO + kAMAIR_EL1 = encode(0b11, 0b000, 0b1010, 0b0011, 0b000), // RW + kAMAIR_EL12 = encode(0b11, 0b101, 0b1010, 0b0011, 0b000), // RW + kAMAIR_EL2 = encode(0b11, 0b100, 0b1010, 0b0011, 0b000), // RW + kAMAIR_EL3 = encode(0b11, 0b110, 0b1010, 0b0011, 0b000), // RW + kAMCFGR_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b001), // RO + kAMCGCR_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b010), // RO + kAMCNTENCLR0_EL0 = 
encode(0b11, 0b011, 0b1101, 0b0010, 0b100), // RW + kAMCNTENCLR1_EL0 = encode(0b11, 0b011, 0b1101, 0b0011, 0b000), // RW + kAMCNTENSET0_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b101), // RW + kAMCNTENSET1_EL0 = encode(0b11, 0b011, 0b1101, 0b0011, 0b001), // RW + kAMCR_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b000), // RW + kAMEVCNTR00_EL0 = encode(0b11, 0b011, 0b1101, 0b0100, 0b000), // RW + kAMEVCNTR01_EL0 = encode(0b11, 0b011, 0b1101, 0b0100, 0b001), // RW + kAMEVCNTR02_EL0 = encode(0b11, 0b011, 0b1101, 0b0100, 0b010), // RW + kAMEVCNTR03_EL0 = encode(0b11, 0b011, 0b1101, 0b0100, 0b011), // RW + kAMEVCNTR10_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b000), // RW + kAMEVCNTR110_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b010), // RW + kAMEVCNTR111_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b011), // RW + kAMEVCNTR112_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b100), // RW + kAMEVCNTR113_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b101), // RW + kAMEVCNTR114_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b110), // RW + kAMEVCNTR115_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b111), // RW + kAMEVCNTR11_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b001), // RW + kAMEVCNTR12_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b010), // RW + kAMEVCNTR13_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b011), // RW + kAMEVCNTR14_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b100), // RW + kAMEVCNTR15_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b101), // RW + kAMEVCNTR16_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b110), // RW + kAMEVCNTR17_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b111), // RW + kAMEVCNTR18_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b000), // RW + kAMEVCNTR19_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b001), // RW + kAMEVTYPER00_EL0 = encode(0b11, 0b011, 0b1101, 0b0110, 0b000), // RO + kAMEVTYPER01_EL0 = encode(0b11, 0b011, 0b1101, 0b0110, 0b001), // RO + kAMEVTYPER02_EL0 = encode(0b11, 0b011, 0b1101, 0b0110, 0b010), // RO + kAMEVTYPER03_EL0 = encode(0b11, 0b011, 0b1101, 
0b0110, 0b011), // RO + kAMEVTYPER10_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b000), // RW + kAMEVTYPER110_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b010), // RW + kAMEVTYPER111_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b011), // RW + kAMEVTYPER112_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b100), // RW + kAMEVTYPER113_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b101), // RW + kAMEVTYPER114_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b110), // RW + kAMEVTYPER115_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b111), // RW + kAMEVTYPER11_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b001), // RW + kAMEVTYPER12_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b010), // RW + kAMEVTYPER13_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b011), // RW + kAMEVTYPER14_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b100), // RW + kAMEVTYPER15_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b101), // RW + kAMEVTYPER16_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b110), // RW + kAMEVTYPER17_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b111), // RW + kAMEVTYPER18_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b000), // RW + kAMEVTYPER19_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b001), // RW + kAMUSERENR_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b011), // RW + kAPDAKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0010, 0b001), // RW + kAPDAKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0010, 0b000), // RW + kAPDBKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0010, 0b011), // RW + kAPDBKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0010, 0b010), // RW + kAPGAKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0011, 0b001), // RW + kAPGAKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0011, 0b000), // RW + kAPIAKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0001, 0b001), // RW + kAPIAKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0001, 0b000), // RW + kAPIBKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0001, 0b011), // RW + kAPIBKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0001, 0b010), // RW + kCCSIDR2_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b010), // RO + 
kCCSIDR_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b000), // RO + kCLIDR_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b001), // RO + kCNTFRQ_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b000), // RW + kCNTHCTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0001, 0b000), // RW + kCNTHPS_CTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0101, 0b001), // RW + kCNTHPS_CVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0101, 0b010), // RW + kCNTHPS_TVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0101, 0b000), // RW + kCNTHP_CTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0010, 0b001), // RW + kCNTHP_CVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0010, 0b010), // RW + kCNTHP_TVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0010, 0b000), // RW + kCNTHVS_CTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0100, 0b001), // RW + kCNTHVS_CVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0100, 0b010), // RW + kCNTHVS_TVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0100, 0b000), // RW + kCNTHV_CTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0011, 0b001), // RW + kCNTHV_CVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0011, 0b010), // RW + kCNTHV_TVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0011, 0b000), // RW + kCNTISCALE_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b101), // RW + kCNTKCTL_EL1 = encode(0b11, 0b000, 0b1110, 0b0001, 0b000), // RW + kCNTKCTL_EL12 = encode(0b11, 0b101, 0b1110, 0b0001, 0b000), // RW + kCNTPCTSS_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b101), // RW + kCNTPCT_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b001), // RO + kCNTPOFF_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b110), // RW + kCNTPS_CTL_EL1 = encode(0b11, 0b111, 0b1110, 0b0010, 0b001), // RW + kCNTPS_CVAL_EL1 = encode(0b11, 0b111, 0b1110, 0b0010, 0b010), // RW + kCNTPS_TVAL_EL1 = encode(0b11, 0b111, 0b1110, 0b0010, 0b000), // RW + kCNTP_CTL_EL0 = encode(0b11, 0b011, 0b1110, 0b0010, 0b001), // RW + kCNTP_CTL_EL02 = encode(0b11, 0b101, 0b1110, 0b0010, 0b001), // RW + kCNTP_CVAL_EL0 = encode(0b11, 0b011, 0b1110, 0b0010, 0b010), // RW + kCNTP_CVAL_EL02 = encode(0b11, 0b101, 0b1110, 0b0010, 0b010), // RW + 
kCNTP_TVAL_EL0 = encode(0b11, 0b011, 0b1110, 0b0010, 0b000), // RW + kCNTP_TVAL_EL02 = encode(0b11, 0b101, 0b1110, 0b0010, 0b000), // RW + kCNTSCALE_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b100), // RW + kCNTVCTSS_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b110), // RW + kCNTVCT_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b010), // RO + kCNTVFRQ_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b111), // RW + kCNTVOFF_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b011), // RW + kCNTV_CTL_EL0 = encode(0b11, 0b011, 0b1110, 0b0011, 0b001), // RW + kCNTV_CTL_EL02 = encode(0b11, 0b101, 0b1110, 0b0011, 0b001), // RW + kCNTV_CVAL_EL0 = encode(0b11, 0b011, 0b1110, 0b0011, 0b010), // RW + kCNTV_CVAL_EL02 = encode(0b11, 0b101, 0b1110, 0b0011, 0b010), // RW + kCNTV_TVAL_EL0 = encode(0b11, 0b011, 0b1110, 0b0011, 0b000), // RW + kCNTV_TVAL_EL02 = encode(0b11, 0b101, 0b1110, 0b0011, 0b000), // RW + kCONTEXTIDR_EL1 = encode(0b11, 0b000, 0b1101, 0b0000, 0b001), // RW + kCONTEXTIDR_EL12 = encode(0b11, 0b101, 0b1101, 0b0000, 0b001), // RW + kCONTEXTIDR_EL2 = encode(0b11, 0b100, 0b1101, 0b0000, 0b001), // RW + kCPACR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b010), // RW + kCPACR_EL12 = encode(0b11, 0b101, 0b0001, 0b0000, 0b010), // RW + kCPM_IOACC_CTL_EL3 = encode(0b11, 0b111, 0b1111, 0b0010, 0b000), // RW + kCPTR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b010), // RW + kCPTR_EL3 = encode(0b11, 0b110, 0b0001, 0b0001, 0b010), // RW + kCSSELR_EL1 = encode(0b11, 0b010, 0b0000, 0b0000, 0b000), // RW + kCTR_EL0 = encode(0b11, 0b011, 0b0000, 0b0000, 0b001), // RO + kCurrentEL = encode(0b11, 0b000, 0b0100, 0b0010, 0b010), // RO + kDACR32_EL2 = encode(0b11, 0b100, 0b0011, 0b0000, 0b000), // RW + kDAIF = encode(0b11, 0b011, 0b0100, 0b0010, 0b001), // RW + kDBGAUTHSTATUS_EL1 = encode(0b10, 0b000, 0b0111, 0b1110, 0b110), // RO + kDBGBCR0_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b101), // RW + kDBGBCR10_EL1 = encode(0b10, 0b000, 0b0000, 0b1010, 0b101), // RW + kDBGBCR11_EL1 = encode(0b10, 0b000, 
0b0000, 0b1011, 0b101), // RW + kDBGBCR12_EL1 = encode(0b10, 0b000, 0b0000, 0b1100, 0b101), // RW + kDBGBCR13_EL1 = encode(0b10, 0b000, 0b0000, 0b1101, 0b101), // RW + kDBGBCR14_EL1 = encode(0b10, 0b000, 0b0000, 0b1110, 0b101), // RW + kDBGBCR15_EL1 = encode(0b10, 0b000, 0b0000, 0b1111, 0b101), // RW + kDBGBCR1_EL1 = encode(0b10, 0b000, 0b0000, 0b0001, 0b101), // RW + kDBGBCR2_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b101), // RW + kDBGBCR3_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b101), // RW + kDBGBCR4_EL1 = encode(0b10, 0b000, 0b0000, 0b0100, 0b101), // RW + kDBGBCR5_EL1 = encode(0b10, 0b000, 0b0000, 0b0101, 0b101), // RW + kDBGBCR6_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b101), // RW + kDBGBCR7_EL1 = encode(0b10, 0b000, 0b0000, 0b0111, 0b101), // RW + kDBGBCR8_EL1 = encode(0b10, 0b000, 0b0000, 0b1000, 0b101), // RW + kDBGBCR9_EL1 = encode(0b10, 0b000, 0b0000, 0b1001, 0b101), // RW + kDBGBVR0_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b100), // RW + kDBGBVR10_EL1 = encode(0b10, 0b000, 0b0000, 0b1010, 0b100), // RW + kDBGBVR11_EL1 = encode(0b10, 0b000, 0b0000, 0b1011, 0b100), // RW + kDBGBVR12_EL1 = encode(0b10, 0b000, 0b0000, 0b1100, 0b100), // RW + kDBGBVR13_EL1 = encode(0b10, 0b000, 0b0000, 0b1101, 0b100), // RW + kDBGBVR14_EL1 = encode(0b10, 0b000, 0b0000, 0b1110, 0b100), // RW + kDBGBVR15_EL1 = encode(0b10, 0b000, 0b0000, 0b1111, 0b100), // RW + kDBGBVR1_EL1 = encode(0b10, 0b000, 0b0000, 0b0001, 0b100), // RW + kDBGBVR2_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b100), // RW + kDBGBVR3_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b100), // RW + kDBGBVR4_EL1 = encode(0b10, 0b000, 0b0000, 0b0100, 0b100), // RW + kDBGBVR5_EL1 = encode(0b10, 0b000, 0b0000, 0b0101, 0b100), // RW + kDBGBVR6_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b100), // RW + kDBGBVR7_EL1 = encode(0b10, 0b000, 0b0000, 0b0111, 0b100), // RW + kDBGBVR8_EL1 = encode(0b10, 0b000, 0b0000, 0b1000, 0b100), // RW + kDBGBVR9_EL1 = encode(0b10, 0b000, 0b0000, 0b1001, 0b100), // RW + 
kDBGCLAIMCLR_EL1 = encode(0b10, 0b000, 0b0111, 0b1001, 0b110), // RW + kDBGCLAIMSET_EL1 = encode(0b10, 0b000, 0b0111, 0b1000, 0b110), // RW + kDBGDTRRX_EL0 = encode(0b10, 0b011, 0b0000, 0b0101, 0b000), // RO + kDBGDTRTX_EL0 = encode(0b10, 0b011, 0b0000, 0b0101, 0b000), // WO + kDBGDTR_EL0 = encode(0b10, 0b011, 0b0000, 0b0100, 0b000), // RW + kDBGPRCR_EL1 = encode(0b10, 0b000, 0b0001, 0b0100, 0b100), // RW + kDBGVCR32_EL2 = encode(0b10, 0b100, 0b0000, 0b0111, 0b000), // RW + kDBGWCR0_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b111), // RW + kDBGWCR10_EL1 = encode(0b10, 0b000, 0b0000, 0b1010, 0b111), // RW + kDBGWCR11_EL1 = encode(0b10, 0b000, 0b0000, 0b1011, 0b111), // RW + kDBGWCR12_EL1 = encode(0b10, 0b000, 0b0000, 0b1100, 0b111), // RW + kDBGWCR13_EL1 = encode(0b10, 0b000, 0b0000, 0b1101, 0b111), // RW + kDBGWCR14_EL1 = encode(0b10, 0b000, 0b0000, 0b1110, 0b111), // RW + kDBGWCR15_EL1 = encode(0b10, 0b000, 0b0000, 0b1111, 0b111), // RW + kDBGWCR1_EL1 = encode(0b10, 0b000, 0b0000, 0b0001, 0b111), // RW + kDBGWCR2_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b111), // RW + kDBGWCR3_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b111), // RW + kDBGWCR4_EL1 = encode(0b10, 0b000, 0b0000, 0b0100, 0b111), // RW + kDBGWCR5_EL1 = encode(0b10, 0b000, 0b0000, 0b0101, 0b111), // RW + kDBGWCR6_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b111), // RW + kDBGWCR7_EL1 = encode(0b10, 0b000, 0b0000, 0b0111, 0b111), // RW + kDBGWCR8_EL1 = encode(0b10, 0b000, 0b0000, 0b1000, 0b111), // RW + kDBGWCR9_EL1 = encode(0b10, 0b000, 0b0000, 0b1001, 0b111), // RW + kDBGWVR0_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b110), // RW + kDBGWVR10_EL1 = encode(0b10, 0b000, 0b0000, 0b1010, 0b110), // RW + kDBGWVR11_EL1 = encode(0b10, 0b000, 0b0000, 0b1011, 0b110), // RW + kDBGWVR12_EL1 = encode(0b10, 0b000, 0b0000, 0b1100, 0b110), // RW + kDBGWVR13_EL1 = encode(0b10, 0b000, 0b0000, 0b1101, 0b110), // RW + kDBGWVR14_EL1 = encode(0b10, 0b000, 0b0000, 0b1110, 0b110), // RW + kDBGWVR15_EL1 = encode(0b10, 0b000, 
0b0000, 0b1111, 0b110), // RW + kDBGWVR1_EL1 = encode(0b10, 0b000, 0b0000, 0b0001, 0b110), // RW + kDBGWVR2_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b110), // RW + kDBGWVR3_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b110), // RW + kDBGWVR4_EL1 = encode(0b10, 0b000, 0b0000, 0b0100, 0b110), // RW + kDBGWVR5_EL1 = encode(0b10, 0b000, 0b0000, 0b0101, 0b110), // RW + kDBGWVR6_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b110), // RW + kDBGWVR7_EL1 = encode(0b10, 0b000, 0b0000, 0b0111, 0b110), // RW + kDBGWVR8_EL1 = encode(0b10, 0b000, 0b0000, 0b1000, 0b110), // RW + kDBGWVR9_EL1 = encode(0b10, 0b000, 0b0000, 0b1001, 0b110), // RW + kDCZID_EL0 = encode(0b11, 0b011, 0b0000, 0b0000, 0b111), // RO + kDISR_EL1 = encode(0b11, 0b000, 0b1100, 0b0001, 0b001), // RW + kDIT = encode(0b11, 0b011, 0b0100, 0b0010, 0b101), // RW + kDLR_EL0 = encode(0b11, 0b011, 0b0100, 0b0101, 0b001), // RW + kDSPSR_EL0 = encode(0b11, 0b011, 0b0100, 0b0101, 0b000), // RW + kELR_EL1 = encode(0b11, 0b000, 0b0100, 0b0000, 0b001), // RW + kELR_EL12 = encode(0b11, 0b101, 0b0100, 0b0000, 0b001), // RW + kELR_EL2 = encode(0b11, 0b100, 0b0100, 0b0000, 0b001), // RW + kELR_EL3 = encode(0b11, 0b110, 0b0100, 0b0000, 0b001), // RW + kERRIDR_EL1 = encode(0b11, 0b000, 0b0101, 0b0011, 0b000), // RO + kERRSELR_EL1 = encode(0b11, 0b000, 0b0101, 0b0011, 0b001), // RW + kERXADDR_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b011), // RW + kERXCTLR_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b001), // RW + kERXFR_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b000), // RO + kERXMISC0_EL1 = encode(0b11, 0b000, 0b0101, 0b0101, 0b000), // RW + kERXMISC1_EL1 = encode(0b11, 0b000, 0b0101, 0b0101, 0b001), // RW + kERXMISC2_EL1 = encode(0b11, 0b000, 0b0101, 0b0101, 0b010), // RW + kERXMISC3_EL1 = encode(0b11, 0b000, 0b0101, 0b0101, 0b011), // RW + kERXPFGCDN_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b110), // RW + kERXPFGCTL_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b101), // RW + kERXPFGF_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 
0b100), // RO + kERXSTATUS_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b010), // RW + kESR_EL1 = encode(0b11, 0b000, 0b0101, 0b0010, 0b000), // RW + kESR_EL12 = encode(0b11, 0b101, 0b0101, 0b0010, 0b000), // RW + kESR_EL2 = encode(0b11, 0b100, 0b0101, 0b0010, 0b000), // RW + kESR_EL3 = encode(0b11, 0b110, 0b0101, 0b0010, 0b000), // RW + kFAR_EL1 = encode(0b11, 0b000, 0b0110, 0b0000, 0b000), // RW + kFAR_EL12 = encode(0b11, 0b101, 0b0110, 0b0000, 0b000), // RW + kFAR_EL2 = encode(0b11, 0b100, 0b0110, 0b0000, 0b000), // RW + kFAR_EL3 = encode(0b11, 0b110, 0b0110, 0b0000, 0b000), // RW + kFPCR = encode(0b11, 0b011, 0b0100, 0b0100, 0b000), // RW + kFPEXC32_EL2 = encode(0b11, 0b100, 0b0101, 0b0011, 0b000), // RW + kFPSR = encode(0b11, 0b011, 0b0100, 0b0100, 0b001), // RW + kGCR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b110), // RW + kGMID_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b100), // RO + kHACR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b111), // RW + kHCR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b000), // RW + kHDFGRTR_EL2 = encode(0b11, 0b100, 0b0011, 0b0001, 0b100), // RW + kHDFGWTR_EL2 = encode(0b11, 0b100, 0b0011, 0b0001, 0b101), // RW + kHFGITR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b110), // RW + kHFGRTR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b100), // RW + kHFGWTR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b101), // RW + kHPFAR_EL2 = encode(0b11, 0b100, 0b0110, 0b0000, 0b100), // RW + kHSTR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b011), // RW + kICC_AP0R0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b100), // RW + kICC_AP0R1_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b101), // RW + kICC_AP0R2_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b110), // RW + kICC_AP0R3_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b111), // RW + kICC_AP1R0_EL1 = encode(0b11, 0b000, 0b1100, 0b1001, 0b000), // RW + kICC_AP1R1_EL1 = encode(0b11, 0b000, 0b1100, 0b1001, 0b001), // RW + kICC_AP1R2_EL1 = encode(0b11, 0b000, 0b1100, 0b1001, 0b010), // RW + kICC_AP1R3_EL1 = 
encode(0b11, 0b000, 0b1100, 0b1001, 0b011), // RW + kICC_ASGI1R_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b110), // WO + kICC_BPR0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b011), // RW + kICC_BPR1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b011), // RW + kICC_CTLR_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b100), // RW + kICC_CTLR_EL3 = encode(0b11, 0b110, 0b1100, 0b1100, 0b100), // RW + kICC_DIR_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b001), // WO + kICC_EOIR0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b001), // WO + kICC_EOIR1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b001), // WO + kICC_HPPIR0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b010), // RO + kICC_HPPIR1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b010), // RO + kICC_IAR0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b000), // RO + kICC_IAR1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b000), // RO + kICC_IGRPEN0_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b110), // RW + kICC_IGRPEN1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b111), // RW + kICC_IGRPEN1_EL3 = encode(0b11, 0b110, 0b1100, 0b1100, 0b111), // RW + kICC_PMR_EL1 = encode(0b11, 0b000, 0b0100, 0b0110, 0b000), // RW + kICC_RPR_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b011), // RO + kICC_SGI0R_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b111), // WO + kICC_SGI1R_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b101), // WO + kICC_SRE_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b101), // RW + kICC_SRE_EL2 = encode(0b11, 0b100, 0b1100, 0b1001, 0b101), // RW + kICC_SRE_EL3 = encode(0b11, 0b110, 0b1100, 0b1100, 0b101), // RW + kICH_AP0R0_EL2 = encode(0b11, 0b100, 0b1100, 0b1000, 0b000), // RW + kICH_AP0R1_EL2 = encode(0b11, 0b100, 0b1100, 0b1000, 0b001), // RW + kICH_AP0R2_EL2 = encode(0b11, 0b100, 0b1100, 0b1000, 0b010), // RW + kICH_AP0R3_EL2 = encode(0b11, 0b100, 0b1100, 0b1000, 0b011), // RW + kICH_AP1R0_EL2 = encode(0b11, 0b100, 0b1100, 0b1001, 0b000), // RW + kICH_AP1R1_EL2 = encode(0b11, 0b100, 0b1100, 0b1001, 0b001), // RW + kICH_AP1R2_EL2 = 
encode(0b11, 0b100, 0b1100, 0b1001, 0b010), // RW + kICH_AP1R3_EL2 = encode(0b11, 0b100, 0b1100, 0b1001, 0b011), // RW + kICH_EISR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b011), // RO + kICH_ELRSR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b101), // RO + kICH_HCR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b000), // RW + kICH_LR0_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b000), // RW + kICH_LR10_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b010), // RW + kICH_LR11_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b011), // RW + kICH_LR12_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b100), // RW + kICH_LR13_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b101), // RW + kICH_LR14_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b110), // RW + kICH_LR15_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b111), // RW + kICH_LR1_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b001), // RW + kICH_LR2_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b010), // RW + kICH_LR3_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b011), // RW + kICH_LR4_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b100), // RW + kICH_LR5_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b101), // RW + kICH_LR6_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b110), // RW + kICH_LR7_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b111), // RW + kICH_LR8_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b000), // RW + kICH_LR9_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b001), // RW + kICH_MISR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b010), // RO + kICH_VMCR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b111), // RW + kICH_VTR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b001), // RO + kID_AA64AFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0101, 0b100), // RO + kID_AA64AFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0101, 0b101), // RO + kID_AA64DFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0101, 0b000), // RO + kID_AA64DFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0101, 0b001), // RO + kID_AA64ISAR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0110, 0b000), // RO + kID_AA64ISAR1_EL1 = encode(0b11, 
0b000, 0b0000, 0b0110, 0b001), // RO + kID_AA64ISAR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0110, 0b010), // RO + kID_AA64MMFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b000), // RO + kID_AA64MMFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b001), // RO + kID_AA64MMFR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b010), // RO + kID_AA64MMFR3_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b011), // RO + kID_AA64MMFR4_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b100), // RO + kID_AA64PFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0100, 0b000), // RO + kID_AA64PFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0100, 0b001), // RO + kID_AA64ZFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0100, 0b100), // RO + kID_AFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b011), // RO + kID_DFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b010), // RO + kID_ISAR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b000), // RO + kID_ISAR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b001), // RO + kID_ISAR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b010), // RO + kID_ISAR3_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b011), // RO + kID_ISAR4_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b100), // RO + kID_ISAR5_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b101), // RO + kID_ISAR6_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b111), // RO + kID_MMFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b100), // RO + kID_MMFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b101), // RO + kID_MMFR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b110), // RO + kID_MMFR3_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b111), // RO + kID_MMFR4_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b110), // RO + kID_MMFR5_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b110), // RO + kID_PFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b000), // RO + kID_PFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b001), // RO + kID_PFR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b100), // RO + kIFSR32_EL2 = encode(0b11, 0b100, 0b0101, 0b0000, 0b001), // RW + kISR_EL1 = encode(0b11, 0b000, 
0b1100, 0b0001, 0b000), // RO + kLORC_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b011), // RW + kLOREA_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b001), // RW + kLORID_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b111), // RO + kLORN_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b010), // RW + kLORSA_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b000), // RW + kMAIR_EL1 = encode(0b11, 0b000, 0b1010, 0b0010, 0b000), // RW + kMAIR_EL12 = encode(0b11, 0b101, 0b1010, 0b0010, 0b000), // RW + kMAIR_EL2 = encode(0b11, 0b100, 0b1010, 0b0010, 0b000), // RW + kMAIR_EL3 = encode(0b11, 0b110, 0b1010, 0b0010, 0b000), // RW + kMDCCINT_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b000), // RW + kMDCCSR_EL0 = encode(0b10, 0b011, 0b0000, 0b0001, 0b000), // RO + kMDCR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b001), // RW + kMDCR_EL3 = encode(0b11, 0b110, 0b0001, 0b0011, 0b001), // RW + kMDRAR_EL1 = encode(0b10, 0b000, 0b0001, 0b0000, 0b000), // RO + kMDSCR_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b010), // RW + kMIDR_EL1 = encode(0b11, 0b000, 0b0000, 0b0000, 0b000), // RO + kMPAM0_EL1 = encode(0b11, 0b000, 0b1010, 0b0101, 0b001), // RW + kMPAM1_EL1 = encode(0b11, 0b000, 0b1010, 0b0101, 0b000), // RW + kMPAM1_EL12 = encode(0b11, 0b101, 0b1010, 0b0101, 0b000), // RW + kMPAM2_EL2 = encode(0b11, 0b100, 0b1010, 0b0101, 0b000), // RW + kMPAM3_EL3 = encode(0b11, 0b110, 0b1010, 0b0101, 0b000), // RW + kMPAMHCR_EL2 = encode(0b11, 0b100, 0b1010, 0b0100, 0b000), // RW + kMPAMIDR_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b100), // RO + kMPAMVPM0_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b000), // RW + kMPAMVPM1_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b001), // RW + kMPAMVPM2_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b010), // RW + kMPAMVPM3_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b011), // RW + kMPAMVPM4_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b100), // RW + kMPAMVPM5_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b101), // RW + kMPAMVPM6_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b110), // 
RW + kMPAMVPM7_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b111), // RW + kMPAMVPMV_EL2 = encode(0b11, 0b100, 0b1010, 0b0100, 0b001), // RW + kMPIDR_EL1 = encode(0b11, 0b000, 0b0000, 0b0000, 0b101), // RO + kMVFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b000), // RO + kMVFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b001), // RO + kMVFR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b010), // RO + kNZCV = encode(0b11, 0b011, 0b0100, 0b0010, 0b000), // RW + kOSDLR_EL1 = encode(0b10, 0b000, 0b0001, 0b0011, 0b100), // RW + kOSDTRRX_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b010), // RW + kOSDTRTX_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b010), // RW + kOSECCR_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b010), // RW + kOSLAR_EL1 = encode(0b10, 0b000, 0b0001, 0b0000, 0b100), // WO + kOSLSR_EL1 = encode(0b10, 0b000, 0b0001, 0b0001, 0b100), // RO + kPAN = encode(0b11, 0b000, 0b0100, 0b0010, 0b011), // RW + kPAR_EL1 = encode(0b11, 0b000, 0b0111, 0b0100, 0b000), // RW + kPMBIDR_EL1 = encode(0b11, 0b000, 0b1001, 0b1010, 0b111), // RO + kPMBLIMITR_EL1 = encode(0b11, 0b000, 0b1001, 0b1010, 0b000), // RW + kPMBPTR_EL1 = encode(0b11, 0b000, 0b1001, 0b1010, 0b001), // RW + kPMBSR_EL1 = encode(0b11, 0b000, 0b1001, 0b1010, 0b011), // RW + kPMCCFILTR_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b111), // RW + kPMCCNTR_EL0 = encode(0b11, 0b011, 0b1001, 0b1101, 0b000), // RW + kPMCEID0_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b110), // RO + kPMCEID1_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b111), // RO + kPMCNTENCLR_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b010), // RW + kPMCNTENSET_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b001), // RW + kPMCR_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b000), // RW + kPMEVCNTR0_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b000), // RW + kPMEVCNTR10_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b010), // RW + kPMEVCNTR11_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b011), // RW + kPMEVCNTR12_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b100), // RW + 
kPMEVCNTR13_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b101), // RW + kPMEVCNTR14_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b110), // RW + kPMEVCNTR15_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b111), // RW + kPMEVCNTR16_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b000), // RW + kPMEVCNTR17_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b001), // RW + kPMEVCNTR18_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b010), // RW + kPMEVCNTR19_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b011), // RW + kPMEVCNTR1_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b001), // RW + kPMEVCNTR20_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b100), // RW + kPMEVCNTR21_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b101), // RW + kPMEVCNTR22_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b110), // RW + kPMEVCNTR23_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b111), // RW + kPMEVCNTR24_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b000), // RW + kPMEVCNTR25_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b001), // RW + kPMEVCNTR26_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b010), // RW + kPMEVCNTR27_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b011), // RW + kPMEVCNTR28_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b100), // RW + kPMEVCNTR29_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b101), // RW + kPMEVCNTR2_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b010), // RW + kPMEVCNTR30_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b110), // RW + kPMEVCNTR3_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b011), // RW + kPMEVCNTR4_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b100), // RW + kPMEVCNTR5_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b101), // RW + kPMEVCNTR6_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b110), // RW + kPMEVCNTR7_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b111), // RW + kPMEVCNTR8_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b000), // RW + kPMEVCNTR9_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b001), // RW + kPMEVTYPER0_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b000), // RW + kPMEVTYPER10_EL0 = encode(0b11, 0b011, 0b1110, 
0b1101, 0b010), // RW + kPMEVTYPER11_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b011), // RW + kPMEVTYPER12_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b100), // RW + kPMEVTYPER13_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b101), // RW + kPMEVTYPER14_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b110), // RW + kPMEVTYPER15_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b111), // RW + kPMEVTYPER16_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b000), // RW + kPMEVTYPER17_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b001), // RW + kPMEVTYPER18_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b010), // RW + kPMEVTYPER19_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b011), // RW + kPMEVTYPER1_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b001), // RW + kPMEVTYPER20_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b100), // RW + kPMEVTYPER21_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b101), // RW + kPMEVTYPER22_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b110), // RW + kPMEVTYPER23_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b111), // RW + kPMEVTYPER24_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b000), // RW + kPMEVTYPER25_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b001), // RW + kPMEVTYPER26_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b010), // RW + kPMEVTYPER27_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b011), // RW + kPMEVTYPER28_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b100), // RW + kPMEVTYPER29_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b101), // RW + kPMEVTYPER2_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b010), // RW + kPMEVTYPER30_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b110), // RW + kPMEVTYPER3_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b011), // RW + kPMEVTYPER4_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b100), // RW + kPMEVTYPER5_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b101), // RW + kPMEVTYPER6_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b110), // RW + kPMEVTYPER7_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b111), // RW + kPMEVTYPER8_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b000), // 
RW + kPMEVTYPER9_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b001), // RW + kPMINTENCLR_EL1 = encode(0b11, 0b000, 0b1001, 0b1110, 0b010), // RW + kPMINTENSET_EL1 = encode(0b11, 0b000, 0b1001, 0b1110, 0b001), // RW + kPMMIR_EL1 = encode(0b11, 0b000, 0b1001, 0b1110, 0b110), // RW + kPMOVSCLR_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b011), // RW + kPMOVSSET_EL0 = encode(0b11, 0b011, 0b1001, 0b1110, 0b011), // RW + kPMSCR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b000), // RW + kPMSCR_EL12 = encode(0b11, 0b101, 0b1001, 0b1001, 0b000), // RW + kPMSCR_EL2 = encode(0b11, 0b100, 0b1001, 0b1001, 0b000), // RW + kPMSELR_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b101), // RW + kPMSEVFR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b101), // RW + kPMSFCR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b100), // RW + kPMSICR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b010), // RW + kPMSIDR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b111), // RO + kPMSIRR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b011), // RW + kPMSLATFR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b110), // RW + kPMSWINC_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b100), // WO + kPMUSERENR_EL0 = encode(0b11, 0b011, 0b1001, 0b1110, 0b000), // RW + kPMXEVCNTR_EL0 = encode(0b11, 0b011, 0b1001, 0b1101, 0b010), // RW + kPMXEVTYPER_EL0 = encode(0b11, 0b011, 0b1001, 0b1101, 0b001), // RW + kREVIDR_EL1 = encode(0b11, 0b000, 0b0000, 0b0000, 0b110), // RO + kRGSR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b101), // RW + kRMR_EL1 = encode(0b11, 0b000, 0b1100, 0b0000, 0b010), // RW + kRMR_EL2 = encode(0b11, 0b100, 0b1100, 0b0000, 0b010), // RW + kRMR_EL3 = encode(0b11, 0b110, 0b1100, 0b0000, 0b010), // RW + kRNDR = encode(0b11, 0b011, 0b0010, 0b0100, 0b000), // RO + kRNDRRS = encode(0b11, 0b011, 0b0010, 0b0100, 0b001), // RO + kRVBAR_EL1 = encode(0b11, 0b000, 0b1100, 0b0000, 0b001), // RO + kRVBAR_EL2 = encode(0b11, 0b100, 0b1100, 0b0000, 0b001), // RO + kRVBAR_EL3 = encode(0b11, 0b110, 0b1100, 0b0000, 0b001), // RO + kSCR_EL3 = 
encode(0b11, 0b110, 0b0001, 0b0001, 0b000), // RW + kSCTLR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b000), // RW + kSCTLR_EL12 = encode(0b11, 0b101, 0b0001, 0b0000, 0b000), // RW + kSCTLR_EL2 = encode(0b11, 0b100, 0b0001, 0b0000, 0b000), // RW + kSCTLR_EL3 = encode(0b11, 0b110, 0b0001, 0b0000, 0b000), // RW + kSCXTNUM_EL0 = encode(0b11, 0b011, 0b1101, 0b0000, 0b111), // RW + kSCXTNUM_EL1 = encode(0b11, 0b000, 0b1101, 0b0000, 0b111), // RW + kSCXTNUM_EL12 = encode(0b11, 0b101, 0b1101, 0b0000, 0b111), // RW + kSCXTNUM_EL2 = encode(0b11, 0b100, 0b1101, 0b0000, 0b111), // RW + kSCXTNUM_EL3 = encode(0b11, 0b110, 0b1101, 0b0000, 0b111), // RW + kSDER32_EL2 = encode(0b11, 0b100, 0b0001, 0b0011, 0b001), // RW + kSDER32_EL3 = encode(0b11, 0b110, 0b0001, 0b0001, 0b001), // RW + kSPSR_EL1 = encode(0b11, 0b000, 0b0100, 0b0000, 0b000), // RW + kSPSR_EL12 = encode(0b11, 0b101, 0b0100, 0b0000, 0b000), // RW + kSPSR_EL2 = encode(0b11, 0b100, 0b0100, 0b0000, 0b000), // RW + kSPSR_EL3 = encode(0b11, 0b110, 0b0100, 0b0000, 0b000), // RW + kSPSR_abt = encode(0b11, 0b100, 0b0100, 0b0011, 0b001), // RW + kSPSR_fiq = encode(0b11, 0b100, 0b0100, 0b0011, 0b011), // RW + kSPSR_irq = encode(0b11, 0b100, 0b0100, 0b0011, 0b000), // RW + kSPSR_und = encode(0b11, 0b100, 0b0100, 0b0011, 0b010), // RW + kSPSel = encode(0b11, 0b000, 0b0100, 0b0010, 0b000), // RW + kSP_EL0 = encode(0b11, 0b000, 0b0100, 0b0001, 0b000), // RW + kSP_EL1 = encode(0b11, 0b100, 0b0100, 0b0001, 0b000), // RW + kSP_EL2 = encode(0b11, 0b110, 0b0100, 0b0001, 0b000), // RW + kSSBS = encode(0b11, 0b011, 0b0100, 0b0010, 0b110), // RW + kTCO = encode(0b11, 0b011, 0b0100, 0b0010, 0b111), // RW + kTCR_EL1 = encode(0b11, 0b000, 0b0010, 0b0000, 0b010), // RW + kTCR_EL12 = encode(0b11, 0b101, 0b0010, 0b0000, 0b010), // RW + kTCR_EL2 = encode(0b11, 0b100, 0b0010, 0b0000, 0b010), // RW + kTCR_EL3 = encode(0b11, 0b110, 0b0010, 0b0000, 0b010), // RW + kTEECR32_EL1 = encode(0b10, 0b010, 0b0000, 0b0000, 0b000), // RW + kTEEHBR32_EL1 = 
encode(0b10, 0b010, 0b0001, 0b0000, 0b000), // RW + kTFSRE0_EL1 = encode(0b11, 0b000, 0b0101, 0b0110, 0b001), // RW + kTFSR_EL1 = encode(0b11, 0b000, 0b0101, 0b0110, 0b000), // RW + kTFSR_EL12 = encode(0b11, 0b101, 0b0101, 0b0110, 0b000), // RW + kTFSR_EL2 = encode(0b11, 0b100, 0b0101, 0b0110, 0b000), // RW + kTFSR_EL3 = encode(0b11, 0b110, 0b0101, 0b0110, 0b000), // RW + kTPIDRRO_EL0 = encode(0b11, 0b011, 0b1101, 0b0000, 0b011), // RW + kTPIDR_EL0 = encode(0b11, 0b011, 0b1101, 0b0000, 0b010), // RW + kTPIDR_EL1 = encode(0b11, 0b000, 0b1101, 0b0000, 0b100), // RW + kTPIDR_EL2 = encode(0b11, 0b100, 0b1101, 0b0000, 0b010), // RW + kTPIDR_EL3 = encode(0b11, 0b110, 0b1101, 0b0000, 0b010), // RW + kTRBBASER_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b010), // RW + kTRBIDR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b111), // RO + kTRBLIMITR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b000), // RW + kTRBMAR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b100), // RW + kTRBPTR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b001), // RW + kTRBSR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b011), // RW + kTRBTRG_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b110), // RW + kTRCACATR0 = encode(0b10, 0b001, 0b0010, 0b0000, 0b010), // RW + kTRCACATR1 = encode(0b10, 0b001, 0b0010, 0b0010, 0b010), // RW + kTRCACATR10 = encode(0b10, 0b001, 0b0010, 0b0100, 0b011), // RW + kTRCACATR11 = encode(0b10, 0b001, 0b0010, 0b0110, 0b011), // RW + kTRCACATR12 = encode(0b10, 0b001, 0b0010, 0b1000, 0b011), // RW + kTRCACATR13 = encode(0b10, 0b001, 0b0010, 0b1010, 0b011), // RW + kTRCACATR14 = encode(0b10, 0b001, 0b0010, 0b1100, 0b011), // RW + kTRCACATR15 = encode(0b10, 0b001, 0b0010, 0b1110, 0b011), // RW + kTRCACATR2 = encode(0b10, 0b001, 0b0010, 0b0100, 0b010), // RW + kTRCACATR3 = encode(0b10, 0b001, 0b0010, 0b0110, 0b010), // RW + kTRCACATR4 = encode(0b10, 0b001, 0b0010, 0b1000, 0b010), // RW + kTRCACATR5 = encode(0b10, 0b001, 0b0010, 0b1010, 0b010), // RW + kTRCACATR6 = encode(0b10, 0b001, 0b0010, 
0b1100, 0b010), // RW + kTRCACATR7 = encode(0b10, 0b001, 0b0010, 0b1110, 0b010), // RW + kTRCACATR8 = encode(0b10, 0b001, 0b0010, 0b0000, 0b011), // RW + kTRCACATR9 = encode(0b10, 0b001, 0b0010, 0b0010, 0b011), // RW + kTRCACVR0 = encode(0b10, 0b001, 0b0010, 0b0000, 0b000), // RW + kTRCACVR1 = encode(0b10, 0b001, 0b0010, 0b0010, 0b000), // RW + kTRCACVR10 = encode(0b10, 0b001, 0b0010, 0b0100, 0b001), // RW + kTRCACVR11 = encode(0b10, 0b001, 0b0010, 0b0110, 0b001), // RW + kTRCACVR12 = encode(0b10, 0b001, 0b0010, 0b1000, 0b001), // RW + kTRCACVR13 = encode(0b10, 0b001, 0b0010, 0b1010, 0b001), // RW + kTRCACVR14 = encode(0b10, 0b001, 0b0010, 0b1100, 0b001), // RW + kTRCACVR15 = encode(0b10, 0b001, 0b0010, 0b1110, 0b001), // RW + kTRCACVR2 = encode(0b10, 0b001, 0b0010, 0b0100, 0b000), // RW + kTRCACVR3 = encode(0b10, 0b001, 0b0010, 0b0110, 0b000), // RW + kTRCACVR4 = encode(0b10, 0b001, 0b0010, 0b1000, 0b000), // RW + kTRCACVR5 = encode(0b10, 0b001, 0b0010, 0b1010, 0b000), // RW + kTRCACVR6 = encode(0b10, 0b001, 0b0010, 0b1100, 0b000), // RW + kTRCACVR7 = encode(0b10, 0b001, 0b0010, 0b1110, 0b000), // RW + kTRCACVR8 = encode(0b10, 0b001, 0b0010, 0b0000, 0b001), // RW + kTRCACVR9 = encode(0b10, 0b001, 0b0010, 0b0010, 0b001), // RW + kTRCAUTHSTATUS = encode(0b10, 0b001, 0b0111, 0b1110, 0b110), // RO + kTRCAUXCTLR = encode(0b10, 0b001, 0b0000, 0b0110, 0b000), // RW + kTRCBBCTLR = encode(0b10, 0b001, 0b0000, 0b1111, 0b000), // RW + kTRCCCCTLR = encode(0b10, 0b001, 0b0000, 0b1110, 0b000), // RW + kTRCCIDCCTLR0 = encode(0b10, 0b001, 0b0011, 0b0000, 0b010), // RW + kTRCCIDCCTLR1 = encode(0b10, 0b001, 0b0011, 0b0001, 0b010), // RW + kTRCCIDCVR0 = encode(0b10, 0b001, 0b0011, 0b0000, 0b000), // RW + kTRCCIDCVR1 = encode(0b10, 0b001, 0b0011, 0b0010, 0b000), // RW + kTRCCIDCVR2 = encode(0b10, 0b001, 0b0011, 0b0100, 0b000), // RW + kTRCCIDCVR3 = encode(0b10, 0b001, 0b0011, 0b0110, 0b000), // RW + kTRCCIDCVR4 = encode(0b10, 0b001, 0b0011, 0b1000, 0b000), // RW + kTRCCIDCVR5 = 
encode(0b10, 0b001, 0b0011, 0b1010, 0b000), // RW + kTRCCIDCVR6 = encode(0b10, 0b001, 0b0011, 0b1100, 0b000), // RW + kTRCCIDCVR7 = encode(0b10, 0b001, 0b0011, 0b1110, 0b000), // RW + kTRCCIDR0 = encode(0b10, 0b001, 0b0111, 0b1100, 0b111), // RO + kTRCCIDR1 = encode(0b10, 0b001, 0b0111, 0b1101, 0b111), // RO + kTRCCIDR2 = encode(0b10, 0b001, 0b0111, 0b1110, 0b111), // RO + kTRCCIDR3 = encode(0b10, 0b001, 0b0111, 0b1111, 0b111), // RO + kTRCCLAIMCLR = encode(0b10, 0b001, 0b0111, 0b1001, 0b110), // RW + kTRCCLAIMSET = encode(0b10, 0b001, 0b0111, 0b1000, 0b110), // RW + kTRCCNTCTLR0 = encode(0b10, 0b001, 0b0000, 0b0100, 0b101), // RW + kTRCCNTCTLR1 = encode(0b10, 0b001, 0b0000, 0b0101, 0b101), // RW + kTRCCNTCTLR2 = encode(0b10, 0b001, 0b0000, 0b0110, 0b101), // RW + kTRCCNTCTLR3 = encode(0b10, 0b001, 0b0000, 0b0111, 0b101), // RW + kTRCCNTRLDVR0 = encode(0b10, 0b001, 0b0000, 0b0000, 0b101), // RW + kTRCCNTRLDVR1 = encode(0b10, 0b001, 0b0000, 0b0001, 0b101), // RW + kTRCCNTRLDVR2 = encode(0b10, 0b001, 0b0000, 0b0010, 0b101), // RW + kTRCCNTRLDVR3 = encode(0b10, 0b001, 0b0000, 0b0011, 0b101), // RW + kTRCCNTVR0 = encode(0b10, 0b001, 0b0000, 0b1000, 0b101), // RW + kTRCCNTVR1 = encode(0b10, 0b001, 0b0000, 0b1001, 0b101), // RW + kTRCCNTVR2 = encode(0b10, 0b001, 0b0000, 0b1010, 0b101), // RW + kTRCCNTVR3 = encode(0b10, 0b001, 0b0000, 0b1011, 0b101), // RW + kTRCCONFIGR = encode(0b10, 0b001, 0b0000, 0b0100, 0b000), // RW + kTRCDEVAFF0 = encode(0b10, 0b001, 0b0111, 0b1010, 0b110), // RO + kTRCDEVAFF1 = encode(0b10, 0b001, 0b0111, 0b1011, 0b110), // RO + kTRCDEVARCH = encode(0b10, 0b001, 0b0111, 0b1111, 0b110), // RO + kTRCDEVID = encode(0b10, 0b001, 0b0111, 0b0010, 0b111), // RO + kTRCDEVTYPE = encode(0b10, 0b001, 0b0111, 0b0011, 0b111), // RO + kTRCDVCMR0 = encode(0b10, 0b001, 0b0010, 0b0000, 0b110), // RW + kTRCDVCMR1 = encode(0b10, 0b001, 0b0010, 0b0100, 0b110), // RW + kTRCDVCMR2 = encode(0b10, 0b001, 0b0010, 0b1000, 0b110), // RW + kTRCDVCMR3 = encode(0b10, 0b001, 
0b0010, 0b1100, 0b110), // RW + kTRCDVCMR4 = encode(0b10, 0b001, 0b0010, 0b0000, 0b111), // RW + kTRCDVCMR5 = encode(0b10, 0b001, 0b0010, 0b0100, 0b111), // RW + kTRCDVCMR6 = encode(0b10, 0b001, 0b0010, 0b1000, 0b111), // RW + kTRCDVCMR7 = encode(0b10, 0b001, 0b0010, 0b1100, 0b111), // RW + kTRCDVCVR0 = encode(0b10, 0b001, 0b0010, 0b0000, 0b100), // RW + kTRCDVCVR1 = encode(0b10, 0b001, 0b0010, 0b0100, 0b100), // RW + kTRCDVCVR2 = encode(0b10, 0b001, 0b0010, 0b1000, 0b100), // RW + kTRCDVCVR3 = encode(0b10, 0b001, 0b0010, 0b1100, 0b100), // RW + kTRCDVCVR4 = encode(0b10, 0b001, 0b0010, 0b0000, 0b101), // RW + kTRCDVCVR5 = encode(0b10, 0b001, 0b0010, 0b0100, 0b101), // RW + kTRCDVCVR6 = encode(0b10, 0b001, 0b0010, 0b1000, 0b101), // RW + kTRCDVCVR7 = encode(0b10, 0b001, 0b0010, 0b1100, 0b101), // RW + kTRCEVENTCTL0R = encode(0b10, 0b001, 0b0000, 0b1000, 0b000), // RW + kTRCEVENTCTL1R = encode(0b10, 0b001, 0b0000, 0b1001, 0b000), // RW + kTRCEXTINSELR = encode(0b10, 0b001, 0b0000, 0b1000, 0b100), // RW + kTRCEXTINSELR0 = encode(0b10, 0b001, 0b0000, 0b1000, 0b100), // RW + kTRCEXTINSELR1 = encode(0b10, 0b001, 0b0000, 0b1001, 0b100), // RW + kTRCEXTINSELR2 = encode(0b10, 0b001, 0b0000, 0b1010, 0b100), // RW + kTRCEXTINSELR3 = encode(0b10, 0b001, 0b0000, 0b1011, 0b100), // RW + kTRCIDR0 = encode(0b10, 0b001, 0b0000, 0b1000, 0b111), // RO + kTRCIDR1 = encode(0b10, 0b001, 0b0000, 0b1001, 0b111), // RO + kTRCIDR10 = encode(0b10, 0b001, 0b0000, 0b0010, 0b110), // RO + kTRCIDR11 = encode(0b10, 0b001, 0b0000, 0b0011, 0b110), // RO + kTRCIDR12 = encode(0b10, 0b001, 0b0000, 0b0100, 0b110), // RO + kTRCIDR13 = encode(0b10, 0b001, 0b0000, 0b0101, 0b110), // RO + kTRCIDR2 = encode(0b10, 0b001, 0b0000, 0b1010, 0b111), // RO + kTRCIDR3 = encode(0b10, 0b001, 0b0000, 0b1011, 0b111), // RO + kTRCIDR4 = encode(0b10, 0b001, 0b0000, 0b1100, 0b111), // RO + kTRCIDR5 = encode(0b10, 0b001, 0b0000, 0b1101, 0b111), // RO + kTRCIDR6 = encode(0b10, 0b001, 0b0000, 0b1110, 0b111), // RO + kTRCIDR7 
= encode(0b10, 0b001, 0b0000, 0b1111, 0b111), // RO + kTRCIDR8 = encode(0b10, 0b001, 0b0000, 0b0000, 0b110), // RO + kTRCIDR9 = encode(0b10, 0b001, 0b0000, 0b0001, 0b110), // RO + kTRCIMSPEC0 = encode(0b10, 0b001, 0b0000, 0b0000, 0b111), // RW + kTRCIMSPEC1 = encode(0b10, 0b001, 0b0000, 0b0001, 0b111), // RW + kTRCIMSPEC2 = encode(0b10, 0b001, 0b0000, 0b0010, 0b111), // RW + kTRCIMSPEC3 = encode(0b10, 0b001, 0b0000, 0b0011, 0b111), // RW + kTRCIMSPEC4 = encode(0b10, 0b001, 0b0000, 0b0100, 0b111), // RW + kTRCIMSPEC5 = encode(0b10, 0b001, 0b0000, 0b0101, 0b111), // RW + kTRCIMSPEC6 = encode(0b10, 0b001, 0b0000, 0b0110, 0b111), // RW + kTRCIMSPEC7 = encode(0b10, 0b001, 0b0000, 0b0111, 0b111), // RW + kTRCITCTRL = encode(0b10, 0b001, 0b0111, 0b0000, 0b100), // RW + kTRCLAR = encode(0b10, 0b001, 0b0111, 0b1100, 0b110), // WO + kTRCLSR = encode(0b10, 0b001, 0b0111, 0b1101, 0b110), // RO + kTRCOSLAR = encode(0b10, 0b001, 0b0001, 0b0000, 0b100), // WO + kTRCOSLSR = encode(0b10, 0b001, 0b0001, 0b0001, 0b100), // RO + kTRCPDCR = encode(0b10, 0b001, 0b0001, 0b0100, 0b100), // RW + kTRCPDSR = encode(0b10, 0b001, 0b0001, 0b0101, 0b100), // RO + kTRCPIDR0 = encode(0b10, 0b001, 0b0111, 0b1000, 0b111), // RO + kTRCPIDR1 = encode(0b10, 0b001, 0b0111, 0b1001, 0b111), // RO + kTRCPIDR2 = encode(0b10, 0b001, 0b0111, 0b1010, 0b111), // RO + kTRCPIDR3 = encode(0b10, 0b001, 0b0111, 0b1011, 0b111), // RO + kTRCPIDR4 = encode(0b10, 0b001, 0b0111, 0b0100, 0b111), // RO + kTRCPIDR5 = encode(0b10, 0b001, 0b0111, 0b0101, 0b111), // RO + kTRCPIDR6 = encode(0b10, 0b001, 0b0111, 0b0110, 0b111), // RO + kTRCPIDR7 = encode(0b10, 0b001, 0b0111, 0b0111, 0b111), // RO + kTRCPRGCTLR = encode(0b10, 0b001, 0b0000, 0b0001, 0b000), // RW + kTRCPROCSELR = encode(0b10, 0b001, 0b0000, 0b0010, 0b000), // RW + kTRCQCTLR = encode(0b10, 0b001, 0b0000, 0b0001, 0b001), // RW + kTRCRSCTLR10 = encode(0b10, 0b001, 0b0001, 0b1010, 0b000), // RW + kTRCRSCTLR11 = encode(0b10, 0b001, 0b0001, 0b1011, 0b000), // RW + 
kTRCRSCTLR12 = encode(0b10, 0b001, 0b0001, 0b1100, 0b000), // RW + kTRCRSCTLR13 = encode(0b10, 0b001, 0b0001, 0b1101, 0b000), // RW + kTRCRSCTLR14 = encode(0b10, 0b001, 0b0001, 0b1110, 0b000), // RW + kTRCRSCTLR15 = encode(0b10, 0b001, 0b0001, 0b1111, 0b000), // RW + kTRCRSCTLR16 = encode(0b10, 0b001, 0b0001, 0b0000, 0b001), // RW + kTRCRSCTLR17 = encode(0b10, 0b001, 0b0001, 0b0001, 0b001), // RW + kTRCRSCTLR18 = encode(0b10, 0b001, 0b0001, 0b0010, 0b001), // RW + kTRCRSCTLR19 = encode(0b10, 0b001, 0b0001, 0b0011, 0b001), // RW + kTRCRSCTLR2 = encode(0b10, 0b001, 0b0001, 0b0010, 0b000), // RW + kTRCRSCTLR20 = encode(0b10, 0b001, 0b0001, 0b0100, 0b001), // RW + kTRCRSCTLR21 = encode(0b10, 0b001, 0b0001, 0b0101, 0b001), // RW + kTRCRSCTLR22 = encode(0b10, 0b001, 0b0001, 0b0110, 0b001), // RW + kTRCRSCTLR23 = encode(0b10, 0b001, 0b0001, 0b0111, 0b001), // RW + kTRCRSCTLR24 = encode(0b10, 0b001, 0b0001, 0b1000, 0b001), // RW + kTRCRSCTLR25 = encode(0b10, 0b001, 0b0001, 0b1001, 0b001), // RW + kTRCRSCTLR26 = encode(0b10, 0b001, 0b0001, 0b1010, 0b001), // RW + kTRCRSCTLR27 = encode(0b10, 0b001, 0b0001, 0b1011, 0b001), // RW + kTRCRSCTLR28 = encode(0b10, 0b001, 0b0001, 0b1100, 0b001), // RW + kTRCRSCTLR29 = encode(0b10, 0b001, 0b0001, 0b1101, 0b001), // RW + kTRCRSCTLR3 = encode(0b10, 0b001, 0b0001, 0b0011, 0b000), // RW + kTRCRSCTLR30 = encode(0b10, 0b001, 0b0001, 0b1110, 0b001), // RW + kTRCRSCTLR31 = encode(0b10, 0b001, 0b0001, 0b1111, 0b001), // RW + kTRCRSCTLR4 = encode(0b10, 0b001, 0b0001, 0b0100, 0b000), // RW + kTRCRSCTLR5 = encode(0b10, 0b001, 0b0001, 0b0101, 0b000), // RW + kTRCRSCTLR6 = encode(0b10, 0b001, 0b0001, 0b0110, 0b000), // RW + kTRCRSCTLR7 = encode(0b10, 0b001, 0b0001, 0b0111, 0b000), // RW + kTRCRSCTLR8 = encode(0b10, 0b001, 0b0001, 0b1000, 0b000), // RW + kTRCRSCTLR9 = encode(0b10, 0b001, 0b0001, 0b1001, 0b000), // RW + kTRCRSR = encode(0b10, 0b001, 0b0000, 0b1010, 0b000), // RW + kTRCSEQEVR0 = encode(0b10, 0b001, 0b0000, 0b0000, 0b100), // RW + 
kTRCSEQEVR1 = encode(0b10, 0b001, 0b0000, 0b0001, 0b100), // RW + kTRCSEQEVR2 = encode(0b10, 0b001, 0b0000, 0b0010, 0b100), // RW + kTRCSEQRSTEVR = encode(0b10, 0b001, 0b0000, 0b0110, 0b100), // RW + kTRCSEQSTR = encode(0b10, 0b001, 0b0000, 0b0111, 0b100), // RW + kTRCSSCCR0 = encode(0b10, 0b001, 0b0001, 0b0000, 0b010), // RW + kTRCSSCCR1 = encode(0b10, 0b001, 0b0001, 0b0001, 0b010), // RW + kTRCSSCCR2 = encode(0b10, 0b001, 0b0001, 0b0010, 0b010), // RW + kTRCSSCCR3 = encode(0b10, 0b001, 0b0001, 0b0011, 0b010), // RW + kTRCSSCCR4 = encode(0b10, 0b001, 0b0001, 0b0100, 0b010), // RW + kTRCSSCCR5 = encode(0b10, 0b001, 0b0001, 0b0101, 0b010), // RW + kTRCSSCCR6 = encode(0b10, 0b001, 0b0001, 0b0110, 0b010), // RW + kTRCSSCCR7 = encode(0b10, 0b001, 0b0001, 0b0111, 0b010), // RW + kTRCSSCSR0 = encode(0b10, 0b001, 0b0001, 0b1000, 0b010), // RW + kTRCSSCSR1 = encode(0b10, 0b001, 0b0001, 0b1001, 0b010), // RW + kTRCSSCSR2 = encode(0b10, 0b001, 0b0001, 0b1010, 0b010), // RW + kTRCSSCSR3 = encode(0b10, 0b001, 0b0001, 0b1011, 0b010), // RW + kTRCSSCSR4 = encode(0b10, 0b001, 0b0001, 0b1100, 0b010), // RW + kTRCSSCSR5 = encode(0b10, 0b001, 0b0001, 0b1101, 0b010), // RW + kTRCSSCSR6 = encode(0b10, 0b001, 0b0001, 0b1110, 0b010), // RW + kTRCSSCSR7 = encode(0b10, 0b001, 0b0001, 0b1111, 0b010), // RW + kTRCSSPCICR0 = encode(0b10, 0b001, 0b0001, 0b0000, 0b011), // RW + kTRCSSPCICR1 = encode(0b10, 0b001, 0b0001, 0b0001, 0b011), // RW + kTRCSSPCICR2 = encode(0b10, 0b001, 0b0001, 0b0010, 0b011), // RW + kTRCSSPCICR3 = encode(0b10, 0b001, 0b0001, 0b0011, 0b011), // RW + kTRCSSPCICR4 = encode(0b10, 0b001, 0b0001, 0b0100, 0b011), // RW + kTRCSSPCICR5 = encode(0b10, 0b001, 0b0001, 0b0101, 0b011), // RW + kTRCSSPCICR6 = encode(0b10, 0b001, 0b0001, 0b0110, 0b011), // RW + kTRCSSPCICR7 = encode(0b10, 0b001, 0b0001, 0b0111, 0b011), // RW + kTRCSTALLCTLR = encode(0b10, 0b001, 0b0000, 0b1011, 0b000), // RW + kTRCSTATR = encode(0b10, 0b001, 0b0000, 0b0011, 0b000), // RO + kTRCSYNCPR = encode(0b10, 
0b001, 0b0000, 0b1101, 0b000), // RW + kTRCTRACEIDR = encode(0b10, 0b001, 0b0000, 0b0000, 0b001), // RW + kTRCTSCTLR = encode(0b10, 0b001, 0b0000, 0b1100, 0b000), // RW + kTRCVDARCCTLR = encode(0b10, 0b001, 0b0000, 0b1010, 0b010), // RW + kTRCVDCTLR = encode(0b10, 0b001, 0b0000, 0b1000, 0b010), // RW + kTRCVDSACCTLR = encode(0b10, 0b001, 0b0000, 0b1001, 0b010), // RW + kTRCVICTLR = encode(0b10, 0b001, 0b0000, 0b0000, 0b010), // RW + kTRCVIIECTLR = encode(0b10, 0b001, 0b0000, 0b0001, 0b010), // RW + kTRCVIPCSSCTLR = encode(0b10, 0b001, 0b0000, 0b0011, 0b010), // RW + kTRCVISSCTLR = encode(0b10, 0b001, 0b0000, 0b0010, 0b010), // RW + kTRCVMIDCCTLR0 = encode(0b10, 0b001, 0b0011, 0b0010, 0b010), // RW + kTRCVMIDCCTLR1 = encode(0b10, 0b001, 0b0011, 0b0011, 0b010), // RW + kTRCVMIDCVR0 = encode(0b10, 0b001, 0b0011, 0b0000, 0b001), // RW + kTRCVMIDCVR1 = encode(0b10, 0b001, 0b0011, 0b0010, 0b001), // RW + kTRCVMIDCVR2 = encode(0b10, 0b001, 0b0011, 0b0100, 0b001), // RW + kTRCVMIDCVR3 = encode(0b10, 0b001, 0b0011, 0b0110, 0b001), // RW + kTRCVMIDCVR4 = encode(0b10, 0b001, 0b0011, 0b1000, 0b001), // RW + kTRCVMIDCVR5 = encode(0b10, 0b001, 0b0011, 0b1010, 0b001), // RW + kTRCVMIDCVR6 = encode(0b10, 0b001, 0b0011, 0b1100, 0b001), // RW + kTRCVMIDCVR7 = encode(0b10, 0b001, 0b0011, 0b1110, 0b001), // RW + kTRFCR_EL1 = encode(0b11, 0b000, 0b0001, 0b0010, 0b001), // RW + kTRFCR_EL12 = encode(0b11, 0b101, 0b0001, 0b0010, 0b001), // RW + kTRFCR_EL2 = encode(0b11, 0b100, 0b0001, 0b0010, 0b001), // RW + kTTBR0_EL1 = encode(0b11, 0b000, 0b0010, 0b0000, 0b000), // RW + kTTBR0_EL12 = encode(0b11, 0b101, 0b0010, 0b0000, 0b000), // RW + kTTBR0_EL2 = encode(0b11, 0b100, 0b0010, 0b0000, 0b000), // RW + kTTBR0_EL3 = encode(0b11, 0b110, 0b0010, 0b0000, 0b000), // RW + kTTBR1_EL1 = encode(0b11, 0b000, 0b0010, 0b0000, 0b001), // RW + kTTBR1_EL12 = encode(0b11, 0b101, 0b0010, 0b0000, 0b001), // RW + kTTBR1_EL2 = encode(0b11, 0b100, 0b0010, 0b0000, 0b001), // RW + kUAO = encode(0b11, 0b000, 
0b0100, 0b0010, 0b100), // RW + kVBAR_EL1 = encode(0b11, 0b000, 0b1100, 0b0000, 0b000), // RW + kVBAR_EL12 = encode(0b11, 0b101, 0b1100, 0b0000, 0b000), // RW + kVBAR_EL2 = encode(0b11, 0b100, 0b1100, 0b0000, 0b000), // RW + kVBAR_EL3 = encode(0b11, 0b110, 0b1100, 0b0000, 0b000), // RW + kVDISR_EL2 = encode(0b11, 0b100, 0b1100, 0b0001, 0b001), // RW + kVMPIDR_EL2 = encode(0b11, 0b100, 0b0000, 0b0000, 0b101), // RW + kVNCR_EL2 = encode(0b11, 0b100, 0b0010, 0b0010, 0b000), // RW + kVPIDR_EL2 = encode(0b11, 0b100, 0b0000, 0b0000, 0b000), // RW + kVSESR_EL2 = encode(0b11, 0b100, 0b0101, 0b0010, 0b011), // RW + kVSTCR_EL2 = encode(0b11, 0b100, 0b0010, 0b0110, 0b010), // RW + kVSTTBR_EL2 = encode(0b11, 0b100, 0b0010, 0b0110, 0b000), // RW + kVTCR_EL2 = encode(0b11, 0b100, 0b0010, 0b0001, 0b010), // RW + kVTTBR_EL2 = encode(0b11, 0b100, 0b0010, 0b0001, 0b000), // RW + kZCR_EL1 = encode(0b11, 0b000, 0b0001, 0b0010, 0b000), // RW + kZCR_EL12 = encode(0b11, 0b101, 0b0001, 0b0010, 0b000), // RW + kZCR_EL2 = encode(0b11, 0b100, 0b0001, 0b0010, 0b000), // RW + kZCR_EL3 = encode(0b11, 0b110, 0b0001, 0b0010, 0b000) // RW + }; +}; + +} // {Predicate} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_A64GLOBALS_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64instdb.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64instdb.h new file mode 100644 index 0000000000000000000000000000000000000000..a03125401ad987c87bbf5d3428c152e7fe5ea722 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64instdb.h @@ -0,0 +1,72 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64INSTDB_H_INCLUDED +#define ASMJIT_ARM_A64INSTDB_H_INCLUDED + +#include "../arm/a64globals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! Instruction database (AArch64). 
+namespace InstDB { + +//! Instruction flags. +enum InstFlags : uint32_t { + //! The instruction provides conditional execution. + kInstFlagCond = 0x00000001u, + //! SIMD instruction that processes elements in pairs. + kInstFlagPair = 0x00000002u, + //! SIMD instruction that does widening (Long). + kInstFlagLong = 0x00000004u, + //! SIMD instruction that does narrowing (Narrow). + kInstFlagNarrow = 0x00000008u, + //! SIMD element access of half-words can only be used with v0..15. + kInstFlagVH0_15 = 0x00000010u, + + //! Instruction uses consecutive registers if the number of operands is greater than 2. + kInstFlagConsecutive = 0x00000080u +}; + +//! Instruction information (AArch64). +struct InstInfo { + //! Instruction encoding type. + uint32_t _encoding : 8; + //! Index to data specific to each encoding type. + uint32_t _encodingDataIndex : 8; + uint32_t _reserved : 16; + + uint16_t _rwInfoIndex; + uint16_t _flags; + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG uint32_t rwInfoIndex() const noexcept { return _rwInfoIndex; } + ASMJIT_INLINE_NODEBUG uint32_t flags() const noexcept { return _flags; } + + ASMJIT_INLINE_NODEBUG bool hasFlag(uint32_t flag) const { return (_flags & flag) != 0; } + + //! \} +}; + +ASMJIT_VARAPI const InstInfo _instInfoTable[]; + +static inline const InstInfo& infoById(InstId instId) noexcept { + instId &= uint32_t(InstIdParts::kRealId); + ASMJIT_ASSERT(Inst::isDefinedId(instId)); + return _instInfoTable[instId]; +} + +} // {InstDB} + +//! 
\} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_A64INSTDB_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64operand.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64operand.h new file mode 100644 index 0000000000000000000000000000000000000000..c64f20eb4457cb0ae9e376d177452c09c7f8b47b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/a64operand.h @@ -0,0 +1,650 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64OPERAND_H_INCLUDED +#define ASMJIT_ARM_A64OPERAND_H_INCLUDED + +#include "../arm/armoperand.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +class GpW; +class GpX; + +class VecB; +class VecH; +class VecS; +class VecD; +class VecV; + +//! General purpose register (AArch64). +class Gp : public Reg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Gp, Reg) + + //! Special register id. + enum Id : uint32_t { + //! Register that depends on OS, could be used as TLS offset. + kIdOs = 18, + //! Frame pointer register id. + kIdFp = 29, + //! Link register id. + kIdLr = 30, + //! Stack register id. + kIdSp = 31, + //! Zero register id. + //! + //! Although zero register has the same id as stack register it has a special treatment, because we need to be + //! able to distinguish between these two at API level. Some instructions were designed to be used with SP and + //! some other with ZR - so we need a way to distinguish these two to make sure we emit the right thing. + //! + //! The number 63 is not random, when you perform `id & 31` you would always get 31 for both SP and ZR inputs, + //! which is the identifier used by AArch64 ISA to encode either SP or ZR depending on the instruction. + kIdZr = 63 + }; + + //! Test whether this register is ZR register. 
+ ASMJIT_INLINE_NODEBUG constexpr bool isZR() const noexcept { return id() == kIdZr; } + //! Test whether this register is SP register. + ASMJIT_INLINE_NODEBUG constexpr bool isSP() const noexcept { return id() == kIdSp; } + + //! Cast this register to a 32-bit W register (returns a new operand). + ASMJIT_INLINE_NODEBUG GpW w() const noexcept; + //! \overload + ASMJIT_INLINE_NODEBUG GpW r32() const noexcept; + //! Cast this register to a 64-bit X register (returns a new operand). + ASMJIT_INLINE_NODEBUG GpX x() const noexcept; + //! \overload + ASMJIT_INLINE_NODEBUG GpX r64() const noexcept; +}; + +//! 32-bit general purpose W register (AArch64). +class GpW : public Gp { ASMJIT_DEFINE_FINAL_REG(GpW, Gp, RegTraits); }; +//! 64-bit general purpose X register (AArch64). +class GpX : public Gp { ASMJIT_DEFINE_FINAL_REG(GpX, Gp, RegTraits); }; + +#ifndef _DOXYGEN +ASMJIT_INLINE_NODEBUG GpW Gp::w() const noexcept { return GpW(id()); } +ASMJIT_INLINE_NODEBUG GpX Gp::x() const noexcept { return GpX(id()); } +ASMJIT_INLINE_NODEBUG GpW Gp::r32() const noexcept { return GpW(id()); } +ASMJIT_INLINE_NODEBUG GpX Gp::r64() const noexcept { return GpX(id()); } +#endif + +//! Vector element type (AArch64). +enum class VecElementType : uint32_t { + //! No element type specified. + kNone = 0, + //! Byte elements (B8 or B16). + kB, + //! Halfword elements (H4 or H8). + kH, + //! Singleword elements (S2 or S4). + kS, + //! Doubleword elements (D2). + kD, + //! Byte elements grouped by 4 bytes (B4). + //! + //! \note This element-type is only used by few instructions. + kB4, + //! Halfword elements grouped by 2 halfwords (H2). + //! + //! \note This element-type is only used by few instructions. + kH2, + + //! Maximum value of \ref VecElementType + kMaxValue = kH2 +}; + +//! Vector register (AArch64). +class Vec : public BaseVec { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Vec, BaseVec) + + //! \cond + //! Shortcuts. 
+ enum SignatureReg : uint32_t { + kSignatureElementB = uint32_t(VecElementType::kB) << kSignatureRegElementTypeShift, + kSignatureElementH = uint32_t(VecElementType::kH) << kSignatureRegElementTypeShift, + kSignatureElementS = uint32_t(VecElementType::kS) << kSignatureRegElementTypeShift, + kSignatureElementD = uint32_t(VecElementType::kD) << kSignatureRegElementTypeShift, + kSignatureElementB4 = uint32_t(VecElementType::kB4) << kSignatureRegElementTypeShift, + kSignatureElementH2 = uint32_t(VecElementType::kH2) << kSignatureRegElementTypeShift + }; + //! \endcond + + //! Returns whether the register has element type or element index (or both). + ASMJIT_INLINE_NODEBUG constexpr bool hasElementTypeOrIndex() const noexcept { return _signature.hasField(); } + + //! Returns whether the vector register has associated a vector element type. + ASMJIT_INLINE_NODEBUG constexpr bool hasElementType() const noexcept { return _signature.hasField(); } + //! Returns vector element type of the register. + ASMJIT_INLINE_NODEBUG constexpr VecElementType elementType() const noexcept { return VecElementType(_signature.getField()); } + //! Sets vector element type of the register to `elementType`. + ASMJIT_INLINE_NODEBUG void setElementType(VecElementType elementType) noexcept { _signature.setField(uint32_t(elementType)); } + //! Resets vector element type to none. 
+ ASMJIT_INLINE_NODEBUG void resetElementType() noexcept { _signature.setField(0); } + + ASMJIT_INLINE_NODEBUG constexpr bool isVecB8() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementB); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecH4() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementH); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecS2() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementS); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecD1() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature); } + + ASMJIT_INLINE_NODEBUG constexpr bool isVecB16() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementB); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecH8() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementH); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecS4() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementS); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecD2() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementD); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecB4x4() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementB4); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecH2x4() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | 
kSignatureElementH2); } + + //! Creates a cloned register with element access. + ASMJIT_INLINE_NODEBUG Vec at(uint32_t elementIndex) const noexcept { + return Vec((signature() & ~kSignatureRegElementIndexMask) | (elementIndex << kSignatureRegElementIndexShift) | kSignatureRegElementFlagMask, id()); + } + + //! Cast this register to an 8-bit B register (AArch64 only). + ASMJIT_INLINE_NODEBUG VecB b() const noexcept; + //! Cast this register to a 16-bit H register (AArch64 only). + ASMJIT_INLINE_NODEBUG VecH h() const noexcept; + //! Cast this register to a 32-bit S register. + ASMJIT_INLINE_NODEBUG VecS s() const noexcept; + //! Cast this register to a 64-bit D register. + ASMJIT_INLINE_NODEBUG VecD d() const noexcept; + //! Cast this register to a 128-bit Q register. + ASMJIT_INLINE_NODEBUG VecV q() const noexcept; + //! Cast this register to a 128-bit V register. + ASMJIT_INLINE_NODEBUG VecV v() const noexcept; + + //! Casts this register to b (clone). + ASMJIT_INLINE_NODEBUG Vec v8() const noexcept; + //! Casts this register to h (clone). + ASMJIT_INLINE_NODEBUG Vec v16() const noexcept; + //! Casts this register to s (clone). + ASMJIT_INLINE_NODEBUG Vec v32() const noexcept; + //! Casts this register to d (clone). + ASMJIT_INLINE_NODEBUG Vec v64() const noexcept; + //! Casts this register to q (clone). + ASMJIT_INLINE_NODEBUG Vec v128() const noexcept; + + //! Cast this register to a 128-bit V.B[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV b(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.H[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV h(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.S[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV s(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.D[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV d(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.H2[elementIndex] register. 
+ ASMJIT_INLINE_NODEBUG VecV h2(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.B4[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV b4(uint32_t elementIndex) const noexcept; + + //! Cast this register to V.8B. + ASMJIT_INLINE_NODEBUG VecD b8() const noexcept; + //! Cast this register to V.16B. + ASMJIT_INLINE_NODEBUG VecV b16() const noexcept; + //! Cast this register to V.2H. + ASMJIT_INLINE_NODEBUG VecS h2() const noexcept; + //! Cast this register to V.4H. + ASMJIT_INLINE_NODEBUG VecD h4() const noexcept; + //! Cast this register to V.8H. + ASMJIT_INLINE_NODEBUG VecV h8() const noexcept; + //! Cast this register to V.2S. + ASMJIT_INLINE_NODEBUG VecD s2() const noexcept; + //! Cast this register to V.4S. + ASMJIT_INLINE_NODEBUG VecV s4() const noexcept; + //! Cast this register to V.2D. + ASMJIT_INLINE_NODEBUG VecV d2() const noexcept; + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature _makeElementAccessSignature(VecElementType elementType, uint32_t elementIndex) noexcept { + return OperandSignature{ + uint32_t(RegTraits::kSignature) | + uint32_t(kSignatureRegElementFlagMask) | + (uint32_t(elementType) << kSignatureRegElementTypeShift) | + (uint32_t(elementIndex << kSignatureRegElementIndexShift))}; + } +}; + +//! 8-bit view (S) of VFP/SIMD register. +class VecB : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecB, Vec, RegTraits) +}; + +//! 16-bit view (S) of VFP/SIMD register. +class VecH : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecH, Vec, RegTraits) +}; + +//! 32-bit view (S) of VFP/SIMD register. +class VecS : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecS, Vec, RegTraits) +}; + +//! 64-bit view (D) of VFP/SIMD register. +class VecD : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecD, Vec, RegTraits) +}; + +//! 128-bit vector register (Q or V). 
+class VecV : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecV, Vec, RegTraits) +}; + +ASMJIT_INLINE_NODEBUG VecB Vec::b() const noexcept { return VecB(id()); } +ASMJIT_INLINE_NODEBUG VecH Vec::h() const noexcept { return VecH(id()); } +ASMJIT_INLINE_NODEBUG VecS Vec::s() const noexcept { return VecS(id()); } +ASMJIT_INLINE_NODEBUG VecD Vec::d() const noexcept { return VecD(id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::q() const noexcept { return VecV(id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::v() const noexcept { return VecV(id()); } + +ASMJIT_INLINE_NODEBUG Vec Vec::v8() const noexcept { return VecB(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v16() const noexcept { return VecH(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v32() const noexcept { return VecS(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v64() const noexcept { return VecD(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v128() const noexcept { return VecV(id()); } + +ASMJIT_INLINE_NODEBUG VecV Vec::b(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kB, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::h(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kH, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::s(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kS, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::d(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kD, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::h2(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kH2, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::b4(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kB4, elementIndex), id()); } + +ASMJIT_INLINE_NODEBUG VecD Vec::b8() const noexcept { return VecD(OperandSignature{VecD::kSignature | 
kSignatureElementB}, id()); } +ASMJIT_INLINE_NODEBUG VecS Vec::h2() const noexcept { return VecS(OperandSignature{VecS::kSignature | kSignatureElementH}, id()); } +ASMJIT_INLINE_NODEBUG VecD Vec::h4() const noexcept { return VecD(OperandSignature{VecD::kSignature | kSignatureElementH}, id()); } +ASMJIT_INLINE_NODEBUG VecD Vec::s2() const noexcept { return VecD(OperandSignature{VecD::kSignature | kSignatureElementS}, id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::b16() const noexcept { return VecV(OperandSignature{VecV::kSignature | kSignatureElementB}, id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::h8() const noexcept { return VecV(OperandSignature{VecV::kSignature | kSignatureElementH}, id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::s4() const noexcept { return VecV(OperandSignature{VecV::kSignature | kSignatureElementS}, id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::d2() const noexcept { return VecV(OperandSignature{VecV::kSignature | kSignatureElementD}, id()); } + +#ifndef _DOXYGEN +namespace regs { +#endif + +//! Creates a 32-bit W register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpW w(uint32_t id) noexcept { return GpW(id); } +//! Creates a 64-bit X register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpX x(uint32_t id) noexcept { return GpX(id); } + +//! Creates a 32-bit S register operand. +static ASMJIT_INLINE_NODEBUG constexpr VecS s(uint32_t id) noexcept { return VecS(id); } +//! Creates a 64-bit D register operand. +static ASMJIT_INLINE_NODEBUG constexpr VecD d(uint32_t id) noexcept { return VecD(id); } +//! Creates a 1282-bit V register operand. 
+static ASMJIT_INLINE_NODEBUG constexpr VecV v(uint32_t id) noexcept { return VecV(id); } + +static constexpr GpW w0 = GpW(0); +static constexpr GpW w1 = GpW(1); +static constexpr GpW w2 = GpW(2); +static constexpr GpW w3 = GpW(3); +static constexpr GpW w4 = GpW(4); +static constexpr GpW w5 = GpW(5); +static constexpr GpW w6 = GpW(6); +static constexpr GpW w7 = GpW(7); +static constexpr GpW w8 = GpW(8); +static constexpr GpW w9 = GpW(9); +static constexpr GpW w10 = GpW(10); +static constexpr GpW w11 = GpW(11); +static constexpr GpW w12 = GpW(12); +static constexpr GpW w13 = GpW(13); +static constexpr GpW w14 = GpW(14); +static constexpr GpW w15 = GpW(15); +static constexpr GpW w16 = GpW(16); +static constexpr GpW w17 = GpW(17); +static constexpr GpW w18 = GpW(18); +static constexpr GpW w19 = GpW(19); +static constexpr GpW w20 = GpW(20); +static constexpr GpW w21 = GpW(21); +static constexpr GpW w22 = GpW(22); +static constexpr GpW w23 = GpW(23); +static constexpr GpW w24 = GpW(24); +static constexpr GpW w25 = GpW(25); +static constexpr GpW w26 = GpW(26); +static constexpr GpW w27 = GpW(27); +static constexpr GpW w28 = GpW(28); +static constexpr GpW w29 = GpW(29); +static constexpr GpW w30 = GpW(30); +static constexpr GpW wzr = GpW(Gp::kIdZr); +static constexpr GpW wsp = GpW(Gp::kIdSp); + +static constexpr GpX x0 = GpX(0); +static constexpr GpX x1 = GpX(1); +static constexpr GpX x2 = GpX(2); +static constexpr GpX x3 = GpX(3); +static constexpr GpX x4 = GpX(4); +static constexpr GpX x5 = GpX(5); +static constexpr GpX x6 = GpX(6); +static constexpr GpX x7 = GpX(7); +static constexpr GpX x8 = GpX(8); +static constexpr GpX x9 = GpX(9); +static constexpr GpX x10 = GpX(10); +static constexpr GpX x11 = GpX(11); +static constexpr GpX x12 = GpX(12); +static constexpr GpX x13 = GpX(13); +static constexpr GpX x14 = GpX(14); +static constexpr GpX x15 = GpX(15); +static constexpr GpX x16 = GpX(16); +static constexpr GpX x17 = GpX(17); +static constexpr GpX x18 = GpX(18); +static 
constexpr GpX x19 = GpX(19); +static constexpr GpX x20 = GpX(20); +static constexpr GpX x21 = GpX(21); +static constexpr GpX x22 = GpX(22); +static constexpr GpX x23 = GpX(23); +static constexpr GpX x24 = GpX(24); +static constexpr GpX x25 = GpX(25); +static constexpr GpX x26 = GpX(26); +static constexpr GpX x27 = GpX(27); +static constexpr GpX x28 = GpX(28); +static constexpr GpX x29 = GpX(29); +static constexpr GpX x30 = GpX(30); +static constexpr GpX xzr = GpX(Gp::kIdZr); +static constexpr GpX sp = GpX(Gp::kIdSp); + +static constexpr VecB b0 = VecB(0); +static constexpr VecB b1 = VecB(1); +static constexpr VecB b2 = VecB(2); +static constexpr VecB b3 = VecB(3); +static constexpr VecB b4 = VecB(4); +static constexpr VecB b5 = VecB(5); +static constexpr VecB b6 = VecB(6); +static constexpr VecB b7 = VecB(7); +static constexpr VecB b8 = VecB(8); +static constexpr VecB b9 = VecB(9); +static constexpr VecB b10 = VecB(10); +static constexpr VecB b11 = VecB(11); +static constexpr VecB b12 = VecB(12); +static constexpr VecB b13 = VecB(13); +static constexpr VecB b14 = VecB(14); +static constexpr VecB b15 = VecB(15); +static constexpr VecB b16 = VecB(16); +static constexpr VecB b17 = VecB(17); +static constexpr VecB b18 = VecB(18); +static constexpr VecB b19 = VecB(19); +static constexpr VecB b20 = VecB(20); +static constexpr VecB b21 = VecB(21); +static constexpr VecB b22 = VecB(22); +static constexpr VecB b23 = VecB(23); +static constexpr VecB b24 = VecB(24); +static constexpr VecB b25 = VecB(25); +static constexpr VecB b26 = VecB(26); +static constexpr VecB b27 = VecB(27); +static constexpr VecB b28 = VecB(28); +static constexpr VecB b29 = VecB(29); +static constexpr VecB b30 = VecB(30); +static constexpr VecB b31 = VecB(31); + +static constexpr VecH h0 = VecH(0); +static constexpr VecH h1 = VecH(1); +static constexpr VecH h2 = VecH(2); +static constexpr VecH h3 = VecH(3); +static constexpr VecH h4 = VecH(4); +static constexpr VecH h5 = VecH(5); +static constexpr VecH 
h6 = VecH(6); +static constexpr VecH h7 = VecH(7); +static constexpr VecH h8 = VecH(8); +static constexpr VecH h9 = VecH(9); +static constexpr VecH h10 = VecH(10); +static constexpr VecH h11 = VecH(11); +static constexpr VecH h12 = VecH(12); +static constexpr VecH h13 = VecH(13); +static constexpr VecH h14 = VecH(14); +static constexpr VecH h15 = VecH(15); +static constexpr VecH h16 = VecH(16); +static constexpr VecH h17 = VecH(17); +static constexpr VecH h18 = VecH(18); +static constexpr VecH h19 = VecH(19); +static constexpr VecH h20 = VecH(20); +static constexpr VecH h21 = VecH(21); +static constexpr VecH h22 = VecH(22); +static constexpr VecH h23 = VecH(23); +static constexpr VecH h24 = VecH(24); +static constexpr VecH h25 = VecH(25); +static constexpr VecH h26 = VecH(26); +static constexpr VecH h27 = VecH(27); +static constexpr VecH h28 = VecH(28); +static constexpr VecH h29 = VecH(29); +static constexpr VecH h30 = VecH(30); +static constexpr VecH h31 = VecH(31); + +static constexpr VecS s0 = VecS(0); +static constexpr VecS s1 = VecS(1); +static constexpr VecS s2 = VecS(2); +static constexpr VecS s3 = VecS(3); +static constexpr VecS s4 = VecS(4); +static constexpr VecS s5 = VecS(5); +static constexpr VecS s6 = VecS(6); +static constexpr VecS s7 = VecS(7); +static constexpr VecS s8 = VecS(8); +static constexpr VecS s9 = VecS(9); +static constexpr VecS s10 = VecS(10); +static constexpr VecS s11 = VecS(11); +static constexpr VecS s12 = VecS(12); +static constexpr VecS s13 = VecS(13); +static constexpr VecS s14 = VecS(14); +static constexpr VecS s15 = VecS(15); +static constexpr VecS s16 = VecS(16); +static constexpr VecS s17 = VecS(17); +static constexpr VecS s18 = VecS(18); +static constexpr VecS s19 = VecS(19); +static constexpr VecS s20 = VecS(20); +static constexpr VecS s21 = VecS(21); +static constexpr VecS s22 = VecS(22); +static constexpr VecS s23 = VecS(23); +static constexpr VecS s24 = VecS(24); +static constexpr VecS s25 = VecS(25); +static constexpr 
VecS s26 = VecS(26); +static constexpr VecS s27 = VecS(27); +static constexpr VecS s28 = VecS(28); +static constexpr VecS s29 = VecS(29); +static constexpr VecS s30 = VecS(30); +static constexpr VecS s31 = VecS(31); + +static constexpr VecD d0 = VecD(0); +static constexpr VecD d1 = VecD(1); +static constexpr VecD d2 = VecD(2); +static constexpr VecD d3 = VecD(3); +static constexpr VecD d4 = VecD(4); +static constexpr VecD d5 = VecD(5); +static constexpr VecD d6 = VecD(6); +static constexpr VecD d7 = VecD(7); +static constexpr VecD d8 = VecD(8); +static constexpr VecD d9 = VecD(9); +static constexpr VecD d10 = VecD(10); +static constexpr VecD d11 = VecD(11); +static constexpr VecD d12 = VecD(12); +static constexpr VecD d13 = VecD(13); +static constexpr VecD d14 = VecD(14); +static constexpr VecD d15 = VecD(15); +static constexpr VecD d16 = VecD(16); +static constexpr VecD d17 = VecD(17); +static constexpr VecD d18 = VecD(18); +static constexpr VecD d19 = VecD(19); +static constexpr VecD d20 = VecD(20); +static constexpr VecD d21 = VecD(21); +static constexpr VecD d22 = VecD(22); +static constexpr VecD d23 = VecD(23); +static constexpr VecD d24 = VecD(24); +static constexpr VecD d25 = VecD(25); +static constexpr VecD d26 = VecD(26); +static constexpr VecD d27 = VecD(27); +static constexpr VecD d28 = VecD(28); +static constexpr VecD d29 = VecD(29); +static constexpr VecD d30 = VecD(30); +static constexpr VecD d31 = VecD(31); + +static constexpr VecV q0 = VecV(0); +static constexpr VecV q1 = VecV(1); +static constexpr VecV q2 = VecV(2); +static constexpr VecV q3 = VecV(3); +static constexpr VecV q4 = VecV(4); +static constexpr VecV q5 = VecV(5); +static constexpr VecV q6 = VecV(6); +static constexpr VecV q7 = VecV(7); +static constexpr VecV q8 = VecV(8); +static constexpr VecV q9 = VecV(9); +static constexpr VecV q10 = VecV(10); +static constexpr VecV q11 = VecV(11); +static constexpr VecV q12 = VecV(12); +static constexpr VecV q13 = VecV(13); +static constexpr VecV 
q14 = VecV(14); +static constexpr VecV q15 = VecV(15); +static constexpr VecV q16 = VecV(16); +static constexpr VecV q17 = VecV(17); +static constexpr VecV q18 = VecV(18); +static constexpr VecV q19 = VecV(19); +static constexpr VecV q20 = VecV(20); +static constexpr VecV q21 = VecV(21); +static constexpr VecV q22 = VecV(22); +static constexpr VecV q23 = VecV(23); +static constexpr VecV q24 = VecV(24); +static constexpr VecV q25 = VecV(25); +static constexpr VecV q26 = VecV(26); +static constexpr VecV q27 = VecV(27); +static constexpr VecV q28 = VecV(28); +static constexpr VecV q29 = VecV(29); +static constexpr VecV q30 = VecV(30); +static constexpr VecV q31 = VecV(31); + +static constexpr VecV v0 = VecV(0); +static constexpr VecV v1 = VecV(1); +static constexpr VecV v2 = VecV(2); +static constexpr VecV v3 = VecV(3); +static constexpr VecV v4 = VecV(4); +static constexpr VecV v5 = VecV(5); +static constexpr VecV v6 = VecV(6); +static constexpr VecV v7 = VecV(7); +static constexpr VecV v8 = VecV(8); +static constexpr VecV v9 = VecV(9); +static constexpr VecV v10 = VecV(10); +static constexpr VecV v11 = VecV(11); +static constexpr VecV v12 = VecV(12); +static constexpr VecV v13 = VecV(13); +static constexpr VecV v14 = VecV(14); +static constexpr VecV v15 = VecV(15); +static constexpr VecV v16 = VecV(16); +static constexpr VecV v17 = VecV(17); +static constexpr VecV v18 = VecV(18); +static constexpr VecV v19 = VecV(19); +static constexpr VecV v20 = VecV(20); +static constexpr VecV v21 = VecV(21); +static constexpr VecV v22 = VecV(22); +static constexpr VecV v23 = VecV(23); +static constexpr VecV v24 = VecV(24); +static constexpr VecV v25 = VecV(25); +static constexpr VecV v26 = VecV(26); +static constexpr VecV v27 = VecV(27); +static constexpr VecV v28 = VecV(28); +static constexpr VecV v29 = VecV(29); +static constexpr VecV v30 = VecV(30); +static constexpr VecV v31 = VecV(31); + +#ifndef _DOXYGEN +} // {regs} + +// Make `a64::regs` accessible through `a64` namespace 
as well. +using namespace regs; +#endif + +//! \name Shift Operation Construction +//! \{ + +//! Constructs a `UXTB #value` extend and shift (unsigned byte extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift uxtb(uint32_t value) noexcept { return Shift(ShiftOp::kUXTB, value); } +//! Constructs a `UXTH #value` extend and shift (unsigned hword extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift uxth(uint32_t value) noexcept { return Shift(ShiftOp::kUXTH, value); } +//! Constructs a `UXTW #value` extend and shift (unsigned word extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift uxtw(uint32_t value) noexcept { return Shift(ShiftOp::kUXTW, value); } +//! Constructs a `UXTX #value` extend and shift (unsigned dword extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift uxtx(uint32_t value) noexcept { return Shift(ShiftOp::kUXTX, value); } + +//! Constructs a `SXTB #value` extend and shift (signed byte extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift sxtb(uint32_t value) noexcept { return Shift(ShiftOp::kSXTB, value); } +//! Constructs a `SXTH #value` extend and shift (signed hword extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift sxth(uint32_t value) noexcept { return Shift(ShiftOp::kSXTH, value); } +//! Constructs a `SXTW #value` extend and shift (signed word extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift sxtw(uint32_t value) noexcept { return Shift(ShiftOp::kSXTW, value); } +//! Constructs a `SXTX #value` extend and shift (signed dword extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift sxtx(uint32_t value) noexcept { return Shift(ShiftOp::kSXTX, value); } + +//! \} + +//! \name Memory Operand Construction +//! \{ + +//! Creates `[base, offset]` memory operand (offset mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, int32_t offset = 0) noexcept { + return Mem(base, offset); +} + +//! 
Creates `[base, offset]!` memory operand (pre-index mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_pre(const Gp& base, int32_t offset = 0) noexcept { + return Mem(base, offset, OperandSignature::fromValue(OffsetMode::kPreIndex)); +} + +//! Creates `[base], offset` memory operand (post-index mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_post(const Gp& base, int32_t offset = 0) noexcept { + return Mem(base, offset, OperandSignature::fromValue(OffsetMode::kPostIndex)); +} + +//! Creates `[base, index]` memory operand (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, const Gp& index) noexcept { + return Mem(base, index); +} + +//! Creates `[base, index]!` memory operand (pre-index mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_pre(const Gp& base, const Gp& index) noexcept { + return Mem(base, index, OperandSignature::fromValue(OffsetMode::kPreIndex)); +} + +//! Creates `[base], index` memory operand (post-index mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_post(const Gp& base, const Gp& index) noexcept { + return Mem(base, index, OperandSignature::fromValue(OffsetMode::kPostIndex)); +} + +//! Creates `[base, index, SHIFT_OP #shift]` memory operand (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, const Gp& index, const Shift& shift) noexcept { + return Mem(base, index, shift); +} + +//! Creates `[base, offset]` memory operand (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Label& base, int32_t offset = 0) noexcept { + return Mem(base, offset); +} + +// TODO: [ARM] PC + offset address. +#if 0 +//! Creates `[PC + offset]` (relative) memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const PC& pc, int32_t offset = 0) noexcept { + return Mem(pc, offset); +} +#endif + +//! \} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +//! 
\cond INTERNAL +ASMJIT_BEGIN_NAMESPACE +ASMJIT_DEFINE_TYPE_ID(a64::GpW, TypeId::kInt32); +ASMJIT_DEFINE_TYPE_ID(a64::GpX, TypeId::kInt64); +ASMJIT_DEFINE_TYPE_ID(a64::VecS, TypeId::kFloat32x1); +ASMJIT_DEFINE_TYPE_ID(a64::VecD, TypeId::kFloat64x1); +ASMJIT_DEFINE_TYPE_ID(a64::VecV, TypeId::kInt32x4); +ASMJIT_END_NAMESPACE +//! \endcond + +#endif // ASMJIT_ARM_A64OPERAND_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/armglobals.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/armglobals.h new file mode 100644 index 0000000000000000000000000000000000000000..851f67086014eb2b1426491c18a4c68d4fb3e7ca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/armglobals.h @@ -0,0 +1,17 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_ARMGLOBALS_H_INCLUDED +#define ASMJIT_ARM_ARMGLOBALS_H_INCLUDED + +#include "../core/archcommons.h" +#include "../core/inst.h" + +//! \namespace asmjit::arm +//! \ingroup asmjit_arm +//! +//! API shared between AArch32 & AArch64 backends. + +#endif // ASMJIT_ARM_ARMGLOBALS_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/armoperand.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/armoperand.h new file mode 100644 index 0000000000000000000000000000000000000000..583a3d8c326a7bdfef8aa27c6924794c11b4d8a5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/armoperand.h @@ -0,0 +1,396 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_ARMOPERAND_H_INCLUDED +#define ASMJIT_ARM_ARMOPERAND_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/operand.h" +#include "../core/type.h" +#include "../arm/armglobals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(arm) + +//! 
\addtogroup asmjit_arm +//! \{ + +class Reg; +class Mem; + +//! Register traits (AArch32/AArch64). +//! +//! Register traits contains information about a particular register type. It's used by asmjit to setup register +//! information on-the-fly and to populate tables that contain register information (this way it's possible to +//! change register types and groups without having to reorder these tables). +template +struct RegTraits : public BaseRegTraits {}; + +//! \cond +// <--------------------+------------------------+------------------------+---+------------------+ +// | Reg-Type | Reg-Group |Sz | TypeId | +// <--------------------+------------------------+------------------------+---+------------------+ +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_GpW , RegGroup::kGp , 4 , TypeId::kInt32 ); // AArch32 & AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_GpX , RegGroup::kGp , 8 , TypeId::kInt64 ); // AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecB , RegGroup::kVec , 1 , TypeId::kVoid ); // AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecH , RegGroup::kVec , 2 , TypeId::kVoid ); // AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecS , RegGroup::kVec , 4 , TypeId::kInt32x1 ); // AArch32 & AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecD , RegGroup::kVec , 8 , TypeId::kInt32x2 ); // AArch32 & AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecQ , RegGroup::kVec , 16, TypeId::kInt32x4 ); // AArch32 & AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_PC , RegGroup::kPC , 8 , TypeId::kInt64 ); // AArch64 +//! \endcond + +//! Register operand that can represent AArch32 and AArch64 registers. +class Reg : public BaseReg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Reg, BaseReg) + + //! Gets whether the register is either `R` or `W` register (32-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpR() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is either `R` or `W` register (32-bit). 
+ ASMJIT_INLINE_NODEBUG constexpr bool isGpW() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is an `X` register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpX() const noexcept { return baseSignature() == RegTraits::kSignature; } + + //! Gets whether the register is a VEC-B register (8-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecB() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a VEC-H register (16-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecH() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a VEC-S register (32-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecS() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a VEC-D register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecD() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a VEC-Q register (128-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecQ() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is either VEC-D (64-bit) or VEC-Q (128-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecDOrQ() const noexcept { return uint32_t(type()) - uint32_t(RegType::kARM_VecD) <= 1u; } + //! Gets whether the register is a VEC-V register (128-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecV() const noexcept { return baseSignature() == RegTraits::kSignature; } + + //! Gets whether the register is an 8-bit vector register or view, alias if \ref isVecB(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec8() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a 16-bit vector register or view, alias if \ref isVecH(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec16() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! 
Gets whether the register is a 32-bit vector register or view, alias if \ref isVecS(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec32() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a 64-bit vector register or view, alias if \ref isVecD(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec64() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a 128-bit vector register or view, alias if \ref isVecQ(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec128() const noexcept { return baseSignature() == RegTraits::kSignature; } + + template + ASMJIT_INLINE_NODEBUG void setRegT(uint32_t id) noexcept { + setSignature(RegTraits::kSignature); + setId(id); + } + + ASMJIT_INLINE_NODEBUG void setTypeAndId(RegType type, uint32_t id) noexcept { + setSignature(signatureOf(type)); + setId(id); + } + + static ASMJIT_INLINE_NODEBUG RegGroup groupOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kAArch64).regTypeToGroup(type); } + static ASMJIT_INLINE_NODEBUG TypeId typeIdOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kAArch64).regTypeToTypeId(type); } + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kAArch64).regTypeToSignature(type); } + + template + static ASMJIT_INLINE_NODEBUG RegGroup groupOfT() noexcept { return RegTraits::kGroup; } + + template + static ASMJIT_INLINE_NODEBUG TypeId typeIdOfT() noexcept { return RegTraits::kTypeId; } + + template + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOfT() noexcept { return OperandSignature{RegTraits::kSignature}; } + + static ASMJIT_INLINE_NODEBUG bool isGpW(const Operand_& op) noexcept { return op.as().isGpW(); } + static ASMJIT_INLINE_NODEBUG bool isGpX(const Operand_& op) noexcept { return op.as().isGpX(); } + static ASMJIT_INLINE_NODEBUG bool isVecB(const Operand_& op) noexcept { return op.as().isVecB(); } + static ASMJIT_INLINE_NODEBUG 
bool isVecH(const Operand_& op) noexcept { return op.as().isVecH(); } + static ASMJIT_INLINE_NODEBUG bool isVecS(const Operand_& op) noexcept { return op.as().isVecS(); } + static ASMJIT_INLINE_NODEBUG bool isVecD(const Operand_& op) noexcept { return op.as().isVecD(); } + static ASMJIT_INLINE_NODEBUG bool isVecQ(const Operand_& op) noexcept { return op.as().isVecQ(); } + static ASMJIT_INLINE_NODEBUG bool isVecV(const Operand_& op) noexcept { return op.as().isVecV(); } + + static ASMJIT_INLINE_NODEBUG bool isGpW(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isGpW(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isGpX(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isGpX(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecB(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecB(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecH(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecH(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecS(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecS(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecD(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecD(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecQ(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecQ(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecV(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecV(op)) & unsigned(op.id() == id)); } +}; + +//! Vector register base - a common base for both AArch32 & AArch64 vector register. +class BaseVec : public Reg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(BaseVec, Reg) + + //! Additional signature bits used by a vector register. + enum AdditionalBits : uint32_t { + // Register element type (3 bits). 
+ // |........|........|.XXX....|........| + kSignatureRegElementTypeShift = 12, + kSignatureRegElementTypeMask = 0x07 << kSignatureRegElementTypeShift, + + // Register has element index (1 bit). + // |........|........|X.......|........| + kSignatureRegElementFlagShift = 15, + kSignatureRegElementFlagMask = 0x01 << kSignatureRegElementFlagShift, + + // Register element index (4 bits). + // |........|....XXXX|........|........| + kSignatureRegElementIndexShift = 16, + kSignatureRegElementIndexMask = 0x0F << kSignatureRegElementIndexShift + }; + + //! Returns whether the register has element index (it's an element index access). + ASMJIT_INLINE_NODEBUG constexpr bool hasElementIndex() const noexcept { return _signature.hasField(); } + //! Returns element index of the register. + ASMJIT_INLINE_NODEBUG constexpr uint32_t elementIndex() const noexcept { return _signature.getField(); } + //! Sets element index of the register to `elementType`. + ASMJIT_INLINE_NODEBUG void setElementIndex(uint32_t elementIndex) noexcept { + _signature |= kSignatureRegElementFlagMask; + _signature.setField(elementIndex); + } + //! Resets element index of the register. + ASMJIT_INLINE_NODEBUG void resetElementIndex() noexcept { + _signature &= ~(kSignatureRegElementFlagMask | kSignatureRegElementIndexMask); + } +}; + +//! Memory operand (ARM). +class Mem : public BaseMem { +public: + //! \cond INTERNAL + //! Additional bits of operand's signature used by `arm::Mem`. + enum AdditionalBits : uint32_t { + // Index shift value (5 bits). + // |........|.....XXX|XX......|........| + kSignatureMemShiftValueShift = 14, + kSignatureMemShiftValueMask = 0x1Fu << kSignatureMemShiftValueShift, + + // Index shift operation (4 bits). + // |........|XXXX....|........|........| + kSignatureMemShiftOpShift = 20, + kSignatureMemShiftOpMask = 0x0Fu << kSignatureMemShiftOpShift, + + // Offset mode type (2 bits). 
+ // |......XX|........|........|........| + kSignatureMemOffsetModeShift = 24, + kSignatureMemOffsetModeMask = 0x03u << kSignatureMemOffsetModeShift + }; + //! \endcond + + //! \name Construction & Destruction + //! \{ + + //! Construct a default `Mem` operand, that points to [0]. + ASMJIT_INLINE_NODEBUG constexpr Mem() noexcept + : BaseMem() {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Mem& other) noexcept + : BaseMem(other) {} + + ASMJIT_INLINE_NODEBUG explicit Mem(Globals::NoInit_) noexcept + : BaseMem(Globals::NoInit) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Signature& signature, uint32_t baseId, uint32_t indexId, int32_t offset) noexcept + : BaseMem(signature, baseId, indexId, offset) {} + + ASMJIT_INLINE_NODEBUG constexpr explicit Mem(const Label& base, int32_t off = 0, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(RegType::kLabelTag) | + signature, base.id(), 0, off) {} + + ASMJIT_INLINE_NODEBUG constexpr explicit Mem(const BaseReg& base, int32_t off = 0, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + signature, base.id(), 0, off) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const BaseReg& base, const BaseReg& index, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + Signature::fromMemIndexType(index.type()) | + signature, base.id(), index.id(), 0) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const BaseReg& base, const BaseReg& index, const Shift& shift, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + Signature::fromMemIndexType(index.type()) | + Signature::fromValue(uint32_t(shift.op())) | + Signature::fromValue(shift.value()) | + signature, base.id(), index.id(), 0) {} + + 
ASMJIT_INLINE_NODEBUG constexpr explicit Mem(uint64_t base, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + signature, uint32_t(base >> 32), 0, int32_t(uint32_t(base & 0xFFFFFFFFu))) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Mem& operator=(const Mem& other) noexcept = default; + + //! \} + + //! \name Clone + //! \{ + + //! Clones the memory operand. + ASMJIT_INLINE_NODEBUG constexpr Mem clone() const noexcept { return Mem(*this); } + + //! Gets new memory operand adjusted by `off`. + ASMJIT_INLINE_NODEBUG Mem cloneAdjusted(int64_t off) const noexcept { + Mem result(*this); + result.addOffset(off); + return result; + } + + //! Clones the memory operand and makes it pre-index. + ASMJIT_INLINE_NODEBUG Mem pre() const noexcept { + Mem result(*this); + result.setOffsetMode(OffsetMode::kPreIndex); + return result; + } + + //! Clones the memory operand, applies a given offset `off` and makes it pre-index. + ASMJIT_INLINE_NODEBUG Mem pre(int64_t off) const noexcept { + Mem result(*this); + result.setOffsetMode(OffsetMode::kPreIndex); + result.addOffset(off); + return result; + } + + //! Clones the memory operand and makes it post-index. + ASMJIT_INLINE_NODEBUG Mem post() const noexcept { + Mem result(*this); + result.setOffsetMode(OffsetMode::kPostIndex); + return result; + } + + //! Clones the memory operand, applies a given offset `off` and makes it post-index. + ASMJIT_INLINE_NODEBUG Mem post(int64_t off) const noexcept { + Mem result(*this); + result.setOffsetMode(OffsetMode::kPostIndex); + result.addOffset(off); + return result; + } + + //! \} + + //! \name Base & Index + //! \{ + + //! Converts memory `baseType` and `baseId` to `arm::Reg` instance. + //! + //! The memory must have a valid base register otherwise the result will be wrong. + ASMJIT_INLINE_NODEBUG Reg baseReg() const noexcept { return Reg::fromTypeAndId(baseType(), baseId()); } + + //! 
Converts memory `indexType` and `indexId` to `arm::Reg` instance. + //! + //! The memory must have a valid index register otherwise the result will be wrong. + ASMJIT_INLINE_NODEBUG Reg indexReg() const noexcept { return Reg::fromTypeAndId(indexType(), indexId()); } + + using BaseMem::setIndex; + + ASMJIT_INLINE_NODEBUG void setIndex(const BaseReg& index, uint32_t shift) noexcept { + setIndex(index); + setShift(shift); + } + + ASMJIT_INLINE_NODEBUG void setIndex(const BaseReg& index, Shift shift) noexcept { + setIndex(index); + setShift(shift); + } + + //! \} + + //! \name ARM Specific Features + //! \{ + + //! Gets offset mode. + ASMJIT_INLINE_NODEBUG constexpr OffsetMode offsetMode() const noexcept { return OffsetMode(_signature.getField()); } + //! Sets offset mode to `mode`. + ASMJIT_INLINE_NODEBUG void setOffsetMode(OffsetMode mode) noexcept { _signature.setField(uint32_t(mode)); } + //! Resets offset mode to default (fixed offset, without write-back). + ASMJIT_INLINE_NODEBUG void resetOffsetMode() noexcept { _signature.setField(uint32_t(OffsetMode::kFixed)); } + + //! Tests whether the current memory offset mode is fixed (see \ref OffsetMode::kFixed). + ASMJIT_INLINE_NODEBUG constexpr bool isFixedOffset() const noexcept { return offsetMode() == OffsetMode::kFixed; } + //! Tests whether the current memory offset mode is either pre-index or post-index (write-back is used). + ASMJIT_INLINE_NODEBUG constexpr bool isPreOrPost() const noexcept { return offsetMode() != OffsetMode::kFixed; } + //! Tests whether the current memory offset mode is pre-index (write-back is used). + ASMJIT_INLINE_NODEBUG constexpr bool isPreIndex() const noexcept { return offsetMode() == OffsetMode::kPreIndex; } + //! Tests whether the current memory offset mode is post-index (write-back is used). + ASMJIT_INLINE_NODEBUG constexpr bool isPostIndex() const noexcept { return offsetMode() == OffsetMode::kPostIndex; } + + //! 
Sets offset mode of this memory operand to pre-index (write-back is used). + ASMJIT_INLINE_NODEBUG void makePreIndex() noexcept { setOffsetMode(OffsetMode::kPreIndex); } + //! Sets offset mode of this memory operand to post-index (write-back is used). + ASMJIT_INLINE_NODEBUG void makePostIndex() noexcept { setOffsetMode(OffsetMode::kPostIndex); } + + //! Gets shift operation that is used by index register. + ASMJIT_INLINE_NODEBUG constexpr ShiftOp shiftOp() const noexcept { return ShiftOp(_signature.getField()); } + //! Sets shift operation that is used by index register. + ASMJIT_INLINE_NODEBUG void setShiftOp(ShiftOp sop) noexcept { _signature.setField(uint32_t(sop)); } + //! Resets shift operation that is used by index register to LSL (default value). + ASMJIT_INLINE_NODEBUG void resetShiftOp() noexcept { _signature.setField(uint32_t(ShiftOp::kLSL)); } + + //! Gets whether the memory operand has shift (aka scale) constant. + ASMJIT_INLINE_NODEBUG constexpr bool hasShift() const noexcept { return _signature.hasField(); } + //! Gets the memory operand's shift (aka scale) constant. + ASMJIT_INLINE_NODEBUG constexpr uint32_t shift() const noexcept { return _signature.getField(); } + //! Sets the memory operand's shift (aka scale) constant. + ASMJIT_INLINE_NODEBUG void setShift(uint32_t shift) noexcept { _signature.setField(shift); } + + //! Sets the memory operand's shift and shift operation. + ASMJIT_INLINE_NODEBUG void setShift(Shift shift) noexcept { + _signature.setField(uint32_t(shift.op())); + _signature.setField(shift.value()); + } + + //! Resets the memory operand's shift (aka scale) constant to zero. + ASMJIT_INLINE_NODEBUG void resetShift() noexcept { _signature.setField(0); } + + //! \} +}; + +//! \name Shift Operation Construction +//! \{ + +//! Constructs a `LSL #value` shift (logical shift left). +static ASMJIT_INLINE_NODEBUG constexpr Shift lsl(uint32_t value) noexcept { return Shift(ShiftOp::kLSL, value); } +//! 
Constructs a `LSR #value` shift (logical shift right). +static ASMJIT_INLINE_NODEBUG constexpr Shift lsr(uint32_t value) noexcept { return Shift(ShiftOp::kLSR, value); } +//! Constructs a `ASR #value` shift (arithmetic shift right). +static ASMJIT_INLINE_NODEBUG constexpr Shift asr(uint32_t value) noexcept { return Shift(ShiftOp::kASR, value); } +//! Constructs a `ROR #value` shift (rotate right). +static ASMJIT_INLINE_NODEBUG constexpr Shift ror(uint32_t value) noexcept { return Shift(ShiftOp::kROR, value); } +//! Constructs a `RRX` shift (rotate with carry by 1). +static ASMJIT_INLINE_NODEBUG constexpr Shift rrx() noexcept { return Shift(ShiftOp::kRRX, 0); } +//! Constructs a `MSL #value` shift (logical shift left filling ones). +static ASMJIT_INLINE_NODEBUG constexpr Shift msl(uint32_t value) noexcept { return Shift(ShiftOp::kMSL, value); } + +//! \} + +//! \name Memory Operand Construction +//! \{ + +//! Creates `[base]` absolute memory operand (AArch32 or AArch64). +//! +//! \note The concept of absolute memory operands doesn't exist on ARM, the ISA only provides PC relative addressing. +//! Absolute memory operands can only be used if it's known that the PC relative offset is encodable and that it +//! would be within the limits. Absolute address is also often output from disassemblers, so AsmJit supports it to +//! make it possible to assemble such output back. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(uint64_t base) noexcept { return Mem(base); } + +//! \} + +//! 
\} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_ARMOPERAND_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/armutils.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/armutils.h new file mode 100644 index 0000000000000000000000000000000000000000..8241eda052eb5cc52606ab43634191f9deb8ae90 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/arm/armutils.h @@ -0,0 +1,226 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_ARMUTILS_H_INCLUDED +#define ASMJIT_ARM_ARMUTILS_H_INCLUDED + +#include "../core/support.h" +#include "../arm/armglobals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(arm) + +//! \addtogroup asmjit_arm +//! \{ + +//! Public utilities and helpers for targeting AArch32 and AArch64 architectures. +namespace Utils { + +//! Encodes a 12-bit immediate part of opcode that ise used by a standard 32-bit ARM encoding. +ASMJIT_MAYBE_UNUSED +static inline bool encodeAArch32Imm(uint64_t imm, uint32_t* encodedImmOut) noexcept { + if (imm & 0xFFFFFFFF00000000u) + return false; + + uint32_t v = uint32_t(imm); + uint32_t r = 0; + + if (v <= 0xFFu) { + *encodedImmOut = v; + return true; + } + + // Rotate if there are bits on both ends (LSB and MSB) + // (otherwise we would not be able to calculate the rotation with ctz). + if (v & 0xFF0000FFu) { + v = Support::ror(v, 16); + r = 16u; + } + + uint32_t n = Support::ctz(v) & ~0x1u; + r = (r - n) & 0x1Eu; + v = Support::ror(v, n); + + if (v > 0xFFu) + return false; + + *encodedImmOut = v | (r << 7); + return true; +} + +//! Decomposed fields of a logical immediate value. +struct LogicalImm { + uint32_t n; + uint32_t s; + uint32_t r; +}; + +//! Encodes the given `imm` value of the given `width` to a logical immediate value represented as N, S, and R fields +//! and writes these fields to `out`. +//! +//! Encoding Table: +//! +//! ``` +//! 
+---+--------+--------+------+ +//! | N | ImmS | ImmR | Size | +//! +---+--------+--------+------+ +//! | 1 | ssssss | rrrrrr | 64 | +//! | 0 | 0sssss | .rrrrr | 32 | +//! | 0 | 10ssss | ..rrrr | 16 | +//! | 0 | 110sss | ...rrr | 8 | +//! | 0 | 1110ss | ....rr | 4 | +//! | 0 | 11110s | .....r | 2 | +//! +---+--------+--------+------+ +//! ``` +ASMJIT_MAYBE_UNUSED +static bool encodeLogicalImm(uint64_t imm, uint32_t width, LogicalImm* out) noexcept { + // Determine the element width, which must be 2, 4, 8, 16, 32, or 64 bits. + do { + width /= 2; + uint64_t mask = (uint64_t(1) << width) - 1u; + if ((imm & mask) != ((imm >> width) & mask)) { + width *= 2; + break; + } + } while (width > 2); + + // Patterns of all zeros and all ones are not encodable. + uint64_t lsbMask = Support::lsbMask(width); + imm &= lsbMask; + + if (imm == 0 || imm == lsbMask) + return false; + + // Inspect the pattern and get the most important bit indexes. + // + // oIndex <-+ +-> zIndex + // | | + // |..zeros..|oCount|zCount|..ones..| + // |000000000|111111|000000|11111111| + + uint32_t zIndex = Support::ctz(~imm); + uint64_t zImm = imm ^ ((uint64_t(1) << zIndex) - 1); + uint32_t zCount = (zImm ? Support::ctz(zImm) : width) - zIndex; + + uint32_t oIndex = zIndex + zCount; + uint64_t oImm = ~(zImm ^ Support::lsbMask(oIndex)); + uint32_t oCount = (oImm ? Support::ctz(oImm) : width) - (oIndex); + + // Verify whether the bit-pattern is encodable. + uint64_t mustBeZero = oImm ^ ~Support::lsbMask(oIndex + oCount); + if (mustBeZero != 0 || (zIndex > 0 && width - (oIndex + oCount) != 0)) + return false; + + out->n = width == 64; + out->s = (oCount + zIndex - 1) | (Support::neg(width * 2) & 0x3F); + out->r = width - oIndex; + return true; +} + +//! Returns true if the given `imm` value is encodable as a logical immediate. The `width` argument describes the +//! width of the operation, and must be either 32 or 64. This function can be used to test whether an immediate +//! 
value can be used with AND, ANDS, BIC, BICS, EON, EOR, ORN, and ORR instruction. +ASMJIT_MAYBE_UNUSED +static ASMJIT_INLINE_NODEBUG bool isLogicalImm(uint64_t imm, uint32_t width) noexcept { + LogicalImm dummy; + return encodeLogicalImm(imm, width, &dummy); +} + +//! Returns true if the given `imm` value is encodable as an immediate with `add` and `sub` instructions on AArch64. +//! These two instructions can encode 12-bit immediate value optionally shifted left by 12 bits. +ASMJIT_MAYBE_UNUSED +static ASMJIT_INLINE_NODEBUG bool isAddSubImm(uint64_t imm) noexcept { + return imm <= 0xFFFu || (imm & ~uint64_t(0xFFFu << 12)) == 0; +} + +//! Returns true if the given `imm` value is a byte mask. Byte mask has each byte part of the value set to either +//! 0x00 or 0xFF. Some ARM instructions accept immediates that form a byte-mask and this function can be used to +//! verify that the immediate is encodable before using the value. +template +static ASMJIT_INLINE_NODEBUG bool isByteMaskImm8(const T& imm) noexcept { + constexpr T kMask = T(0x0101010101010101 & Support::allOnes()); + return imm == (imm & kMask) * T(255); +} + +// [.......A|B.......|.......C|D.......|.......E|F.......|.......G|H.......] +static ASMJIT_INLINE_NODEBUG uint32_t encodeImm64ByteMaskToImm8(uint64_t imm) noexcept { + return uint32_t(((imm >> (7 - 0)) & 0b00000011) | // [.......G|H.......] + ((imm >> (23 - 2)) & 0b00001100) | // [.......E|F.......] + ((imm >> (39 - 4)) & 0b00110000) | // [.......C|D.......] + ((imm >> (55 - 6)) & 0b11000000)); // [.......A|B.......] +} +//! \cond +//! A generic implementation that checjs whether a floating point value can be converted to ARM Imm8. 
+template +static ASMJIT_FORCE_INLINE bool isFPImm8Generic(T val) noexcept { + constexpr uint32_t kAllBsMask = Support::lsbMask(kNumBBits); + constexpr uint32_t kB0Pattern = Support::bitMask(kNumBBits - 1); + constexpr uint32_t kB1Pattern = kAllBsMask ^ kB0Pattern; + + T immZ = val & Support::lsbMask(kNumZeroBits); + uint32_t immB = uint32_t(val >> (kNumZeroBits + kNumCDEFGHBits)) & kAllBsMask; + + // ImmZ must be all zeros and ImmB must either be B0 or B1 pattern. + return immZ == 0 && (immB == kB0Pattern || immB == kB1Pattern); +} +//! \endcond + +//! Returns true if the given half precision floating point `val` can be encoded as ARM IMM8 value, which represents +//! a limited set of floating point immediate values, which can be used with FMOV instruction. +//! +//! The floating point must have bits distributed in the following way: +//! +//! ``` +//! [aBbbcdef|gh000000] +//! ``` +static ASMJIT_INLINE_NODEBUG bool isFP16Imm8(uint32_t val) noexcept { return isFPImm8Generic(val); } + +//! Returns true if the given single precision floating point `val` can be encoded as ARM IMM8 value, which represents +//! a limited set of floating point immediate values, which can be used with FMOV instruction. +//! +//! The floating point must have bits distributed in the following way: +//! +//! ``` +//! [aBbbbbbc|defgh000|00000000|00000000] +//! ``` +static ASMJIT_INLINE_NODEBUG bool isFP32Imm8(uint32_t val) noexcept { return isFPImm8Generic(val); } +//! \overload +static ASMJIT_INLINE_NODEBUG bool isFP32Imm8(float val) noexcept { return isFP32Imm8(Support::bitCast(val)); } + +//! Returns true if the given double precision floating point `val` can be encoded as ARM IMM8 value, which represents +//! a limited set of floating point immediate values, which can be used with FMOV instruction. +//! +//! The floating point must have bits distributed in the following way: +//! +//! ``` +//! [aBbbbbbb|bbcdefgh|00000000|00000000|00000000|00000000|00000000|00000000] +//! 
``` +static ASMJIT_INLINE_NODEBUG bool isFP64Imm8(uint64_t val) noexcept { return isFPImm8Generic(val); } +//! \overload +static ASMJIT_INLINE_NODEBUG bool isFP64Imm8(double val) noexcept { return isFP64Imm8(Support::bitCast(val)); } + +//! \cond +template +static ASMJIT_INLINE_NODEBUG uint32_t encodeFPToImm8Generic(T val) noexcept { + uint32_t bits = uint32_t(val >> kNumZeroBits); + return ((bits >> (kNumBBits + kNumCDEFGHBits - 7)) & 0x80u) | (bits & 0x7F); +} +//! \endcond + +//! Encodes a double precision floating point value into IMM8 format. +//! +//! \note This function expects that `isFP64Imm8(val) == true` so it doesn't perform any checks of the value and just +//! rearranges some bits into Imm8 order. +static ASMJIT_INLINE_NODEBUG uint32_t encodeFP64ToImm8(uint64_t val) noexcept { return encodeFPToImm8Generic(val); } +//! \overload +static ASMJIT_INLINE_NODEBUG uint32_t encodeFP64ToImm8(double val) noexcept { return encodeFP64ToImm8(Support::bitCast(val)); } + +} // {Utils} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_ARMUTILS_H_INCLUDED + diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/api-config.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/api-config.h new file mode 100644 index 0000000000000000000000000000000000000000..6c09b348db58ad77da3b67356ee1e8d60705d074 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/api-config.h @@ -0,0 +1,664 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_API_CONFIG_H_INCLUDED +#define ASMJIT_CORE_API_CONFIG_H_INCLUDED + +// AsmJit Library & ABI Version +// ============================ + +//! \addtogroup asmjit_core +//! \{ + +//! Makes a 32-bit integer that represents AsmJit version in `(major << 16) | (minor << 8) | patch` form. 
+#define ASMJIT_LIBRARY_MAKE_VERSION(major, minor, patch) ((major << 16) | (minor << 8) | (patch)) + +//! AsmJit library version, see \ref ASMJIT_LIBRARY_MAKE_VERSION for a version format reference. +#define ASMJIT_LIBRARY_VERSION ASMJIT_LIBRARY_MAKE_VERSION(1, 13, 0) + +//! \def ASMJIT_ABI_NAMESPACE +//! +//! AsmJit ABI namespace is an inline namespace within \ref asmjit namespace. +//! +//! It's used to make sure that when user links to an incompatible version of AsmJit, it won't link. It has also +//! some additional properties as well. When `ASMJIT_ABI_NAMESPACE` is defined by the user it would override the +//! AsmJit default, which makes it possible to use multiple AsmJit libraries within a single project, totally +//! controlled by users. This is useful especially in cases in which some of such library comes from third party. +#if !defined(ASMJIT_ABI_NAMESPACE) + #define ASMJIT_ABI_NAMESPACE _abi_1_13 +#endif // !ASMJIT_ABI_NAMESPACE + +//! \} + +// Global Dependencies +// =================== + +#include +#include +#include // We really want std types as globals, not under 'std' namespace. +#include +#include +#include + +#include +#include +#include +#include + +#if !defined(_WIN32) && !defined(__EMSCRIPTEN__) + #include +#endif + +// Build Options +// ============= + +// NOTE: Doxygen cannot document macros that are not defined, that's why we have to define them and then undefine +// them immediately, so it won't use the macros with its own preprocessor. +#ifdef _DOXYGEN +namespace asmjit { + +//! \addtogroup asmjit_build +//! \{ + +//! Asmjit is embedded, implies \ref ASMJIT_STATIC. +#define ASMJIT_EMBED + +//! Enables static-library build. +#define ASMJIT_STATIC + +//! Defined when AsmJit's build configuration is 'Debug'. +//! +//! \note Can be defined explicitly to bypass auto-detection. +#define ASMJIT_BUILD_DEBUG + +//! Defined when AsmJit's build configuration is 'Release'. +//! +//! \note Can be defined explicitly to bypass auto-detection. 
+#define ASMJIT_BUILD_RELEASE + +//! Disables X86/X64 backends. +#define ASMJIT_NO_X86 + +//! Disables AArch64 backend. +#define ASMJIT_NO_AARCH64 + +//! Disables non-host backends entirely (useful for JIT compilers to minimize the library size). +#define ASMJIT_NO_FOREIGN + +//! Disables deprecated API at compile time (deprecated API won't be available). +#define ASMJIT_NO_DEPRECATED + +//! Disables \ref asmjit_builder functionality completely. +#define ASMJIT_NO_BUILDER + +//! Disables \ref asmjit_compiler functionality completely. +#define ASMJIT_NO_COMPILER + +//! Disables JIT memory management and \ref asmjit::JitRuntime. +#define ASMJIT_NO_JIT + +//! Disables \ref asmjit::Logger and \ref asmjit::Formatter. +#define ASMJIT_NO_LOGGING + +//! Disables everything that contains text. +#define ASMJIT_NO_TEXT + +//! Disables instruction validation API. +#define ASMJIT_NO_VALIDATION + +//! Disables instruction introspection API. +#define ASMJIT_NO_INTROSPECTION + +// Avoid doxygen preprocessor using feature-selection definitions. +#undef ASMJIT_BUILD_EMBED +#undef ASMJIT_BUILD_STATIC +#undef ASMJIT_BUILD_DEBUG +#undef ASMJIT_BUILD_RELEASE +#undef ASMJIT_NO_X86 +#undef ASMJIT_NO_FOREIGN +// (keep ASMJIT_NO_DEPRECATED defined, we don't document deprecated APIs). +#undef ASMJIT_NO_BUILDER +#undef ASMJIT_NO_COMPILER +#undef ASMJIT_NO_JIT +#undef ASMJIT_NO_LOGGING +#undef ASMJIT_NO_TEXT +#undef ASMJIT_NO_VALIDATION +#undef ASMJIT_NO_INTROSPECTION + +//! \} + +} // {asmjit} +#endif // _DOXYGEN + +// ASMJIT_NO_BUILDER implies ASMJIT_NO_COMPILER. +#if defined(ASMJIT_NO_BUILDER) && !defined(ASMJIT_NO_COMPILER) + #define ASMJIT_NO_COMPILER +#endif + +// Prevent compile-time errors caused by misconfiguration. 
+#if defined(ASMJIT_NO_TEXT) && !defined(ASMJIT_NO_LOGGING) + #pragma message("'ASMJIT_NO_TEXT' can only be defined when 'ASMJIT_NO_LOGGING' is defined.") + #undef ASMJIT_NO_TEXT +#endif + +#if defined(ASMJIT_NO_INTROSPECTION) && !defined(ASMJIT_NO_COMPILER) + #pragma message("'ASMJIT_NO_INTROSPECTION' can only be defined when 'ASMJIT_NO_COMPILER' is defined") + #undef ASMJIT_NO_INTROSPECTION +#endif + +// Build Mode +// ========== + +// Detect ASMJIT_BUILD_DEBUG and ASMJIT_BUILD_RELEASE if not defined. +#if !defined(ASMJIT_BUILD_DEBUG) && !defined(ASMJIT_BUILD_RELEASE) + #if !defined(NDEBUG) + #define ASMJIT_BUILD_DEBUG + #else + #define ASMJIT_BUILD_RELEASE + #endif +#endif + +// Target Architecture Detection +// ============================= + +//! \addtogroup asmjit_core +//! \{ + +//! \def ASMJIT_ARCH_X86 +//! +//! Defined to either 0, 32, or 64 depending on whether the target CPU is X86 (32) or X86_64 (64). + +//! \def ASMJIT_ARCH_ARM +//! +//! Defined to either 0, 32, or 64 depending on whether the target CPU is ARM (32) or AArch64 (64). + +//! \def ASMJIT_ARCH_MIPS +//! +//! Defined to either 0, 32, or 64 depending on whether the target CPU is MIPS (32) or MISP64 (64). + +//! \def ASMJIT_ARCH_RISCV +//! +//! Defined to either 0, 32, or 64 depending on whether the target CPU is RV32 (32) or RV64 (64). + +//! \def ASMJIT_ARCH_BITS +//! +//! Defined to either 32 or 64 depending on the target. + +//! \def ASMJIT_ARCH_LE +//! +//! Defined to 1 if the target architecture is little endian. + +//! \def ASMJIT_ARCH_BE +//! +//! Defined to 1 if the target architecture is big endian. + +//! \} + +//! 
\cond NONE + +#if defined(_M_X64) || defined(__x86_64__) + #define ASMJIT_ARCH_X86 64 +#elif defined(_M_IX86) || defined(__X86__) || defined(__i386__) + #define ASMJIT_ARCH_X86 32 +#else + #define ASMJIT_ARCH_X86 0 +#endif + +#if defined(_M_ARM64) || defined(__arm64__) || defined(__aarch64__) +# define ASMJIT_ARCH_ARM 64 +#elif defined(_M_ARM) || defined(_M_ARMT) || defined(__arm__) || defined(__thumb__) || defined(__thumb2__) + #define ASMJIT_ARCH_ARM 32 +#else + #define ASMJIT_ARCH_ARM 0 +#endif + +#if defined(_MIPS_ARCH_MIPS64) || defined(__mips64) + #define ASMJIT_ARCH_MIPS 64 +#elif defined(_MIPS_ARCH_MIPS32) || defined(_M_MRX000) || defined(__mips__) + #define ASMJIT_ARCH_MIPS 32 +#else + #define ASMJIT_ARCH_MIPS 0 +#endif + +// NOTE `__riscv` is the correct macro in this case as specified by "RISC-V Toolchain Conventions". +#if (defined(__riscv) || defined(__riscv__)) && defined(__riscv_xlen) + #define ASMJIT_ARCH_RISCV __riscv_xlen +#else + #define ASMJIT_ARCH_RISCV 0 +#endif + +#define ASMJIT_ARCH_BITS (ASMJIT_ARCH_X86 | ASMJIT_ARCH_ARM | ASMJIT_ARCH_MIPS | ASMJIT_ARCH_RISCV) +#if ASMJIT_ARCH_BITS == 0 + #undef ASMJIT_ARCH_BITS + #if defined (__LP64__) || defined(_LP64) + #define ASMJIT_ARCH_BITS 64 + #else + #define ASMJIT_ARCH_BITS 32 + #endif +#endif + +#if (defined(__ARMEB__)) || \ + (defined(__MIPSEB__)) || \ + (defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)) + #define ASMJIT_ARCH_LE 0 + #define ASMJIT_ARCH_BE 1 +#else + #define ASMJIT_ARCH_LE 1 + #define ASMJIT_ARCH_BE 0 +#endif + +#if defined(ASMJIT_NO_FOREIGN) + #if !ASMJIT_ARCH_X86 && !defined(ASMJIT_NO_X86) + #define ASMJIT_NO_X86 + #endif + + #if ASMJIT_ARCH_ARM != 64 && !defined(ASMJIT_NO_AARCH64) + #define ASMJIT_NO_AARCH64 + #endif +#endif + +//! 
\endcond + +// C++ Compiler and Features Detection +// =================================== + +#if defined(__GNUC__) && defined(__has_attribute) + #define ASMJIT_CXX_HAS_ATTRIBUTE(NAME, CHECK) (__has_attribute(NAME)) +#else + #define ASMJIT_CXX_HAS_ATTRIBUTE(NAME, CHECK) (!(!(CHECK))) +#endif // !ASMJIT_CXX_HAS_ATTRIBUTE + +// API Decorators & C++ Extensions +// =============================== + +//! \addtogroup asmjit_core +//! \{ + +//! \def ASMJIT_API +//! +//! A decorator that is used to decorate API that AsmJit exports when built as a shared library. + +//! \def ASMJIT_VIRTAPI +//! +//! This is basically a workaround. When using MSVC and marking class as DLL export everything gets exported, which +//! is unwanted in most projects. MSVC automatically exports typeinfo and vtable if at least one symbol of the class +//! is exported. However, GCC has some strange behavior that even if one or more symbol is exported it doesn't export +//! typeinfo unless the class itself is decorated with "visibility(default)" (i.e. ASMJIT_API). + +//! \def ASMJIT_FORCE_INLINE +//! +//! Decorator to force inlining of functions, uses either `__attribute__((__always_inline__))` or __forceinline, +//! depending on C++ compiler. + +//! \def ASMJIT_INLINE_NODEBUG +//! +//! Like \ref ASMJIT_FORCE_INLINE, but uses additionally `__nodebug__` or `__artificial__` attribute to make the +//! debugging of some AsmJit functions easier, especially getters and one-line abstractions where usually you don't +//! want to step in. + +//! \def ASMJIT_NOINLINE +//! +//! Decorator to avoid inlining of functions, uses either `__attribute__((__noinline__))` or `__declspec(noinline)` +//! depending on C++ compiler. + +//! \def ASMJIT_NORETURN +//! +//! Decorator that marks functions that should never return. Typically used to implement assertion handlers that +//! terminate, so the function never returns. + +//! \def ASMJIT_CDECL +//! +//! 
CDECL function attribute - either `__attribute__((__cdecl__))` or `__cdecl`. + +//! \def ASMJIT_STDCALL +//! +//! STDCALL function attribute - either `__attribute__((__stdcall__))` or `__stdcall`. +//! +//! \note This expands to nothing on non-x86 targets as STDCALL is X86 specific. + +//! \def ASMJIT_FASTCALL +//! +//! FASTCALL function attribute - either `__attribute__((__fastcall__))` or `__fastcall`. +//! +//! \note Expands to nothing on non-x86 targets as FASTCALL is X86 specific. + +//! \def ASMJIT_REGPARM(N) +//! +//! Expands to `__attribute__((__regparm__(N)))` when compiled by GCC or clang, nothing otherwise. + +//! \def ASMJIT_VECTORCALL +//! +//! VECTORCALL function attribute - either `__attribute__((__vectorcall__))` or `__vectorcall`. +//! +//! \note Expands to nothing on non-x86 targets as VECTORCALL is X86 specific. + +//! \} + +// API (Export / Import). +#if !defined(ASMJIT_STATIC) + #if defined(_WIN32) && (defined(_MSC_VER) || defined(__MINGW32__)) + #ifdef ASMJIT_EXPORTS + #define ASMJIT_API __declspec(dllexport) + #else + #define ASMJIT_API __declspec(dllimport) + #endif + #elif defined(_WIN32) && defined(__GNUC__) + #ifdef ASMJIT_EXPORTS + #define ASMJIT_API __attribute__((__dllexport__)) + #else + #define ASMJIT_API __attribute__((__dllimport__)) + #endif + #elif defined(__GNUC__) + #define ASMJIT_API __attribute__((__visibility__("default"))) + #endif +#endif + +#if !defined(ASMJIT_API) + #define ASMJIT_API +#endif + +#if !defined(ASMJIT_VARAPI) + #define ASMJIT_VARAPI extern ASMJIT_API +#endif + +#if defined(__GNUC__) && !defined(_WIN32) + #define ASMJIT_VIRTAPI ASMJIT_API +#else + #define ASMJIT_VIRTAPI +#endif + +// Function attributes. 
+#if !defined(ASMJIT_BUILD_DEBUG) && defined(__GNUC__) + #define ASMJIT_FORCE_INLINE inline __attribute__((__always_inline__)) +#elif !defined(ASMJIT_BUILD_DEBUG) && defined(_MSC_VER) + #define ASMJIT_FORCE_INLINE __forceinline +#else + #define ASMJIT_FORCE_INLINE inline +#endif + + +#if defined(__clang__) + #define ASMJIT_INLINE_NODEBUG inline __attribute__((__always_inline__, __nodebug__)) +#elif defined(__GNUC__) + #define ASMJIT_INLINE_NODEBUG inline __attribute__((__always_inline__, __artificial__)) +#else + #define ASMJIT_INLINE_NODEBUG inline +#endif + +#if defined(__GNUC__) + #define ASMJIT_NOINLINE __attribute__((__noinline__)) + #define ASMJIT_NORETURN __attribute__((__noreturn__)) +#elif defined(_MSC_VER) + #define ASMJIT_NOINLINE __declspec(noinline) + #define ASMJIT_NORETURN __declspec(noreturn) +#else + #define ASMJIT_NOINLINE + #define ASMJIT_NORETURN +#endif + +// Calling conventions. +#if ASMJIT_ARCH_X86 == 32 && defined(__GNUC__) + #define ASMJIT_CDECL __attribute__((__cdecl__)) + #define ASMJIT_STDCALL __attribute__((__stdcall__)) + #define ASMJIT_FASTCALL __attribute__((__fastcall__)) + #define ASMJIT_REGPARM(N) __attribute__((__regparm__(N))) +#elif ASMJIT_ARCH_X86 == 32 && defined(_MSC_VER) + #define ASMJIT_CDECL __cdecl + #define ASMJIT_STDCALL __stdcall + #define ASMJIT_FASTCALL __fastcall + #define ASMJIT_REGPARM(N) +#else + #define ASMJIT_CDECL + #define ASMJIT_STDCALL + #define ASMJIT_FASTCALL + #define ASMJIT_REGPARM(N) +#endif + +#if ASMJIT_ARCH_X86 && defined(_WIN32) && defined(_MSC_VER) + #define ASMJIT_VECTORCALL __vectorcall +#elif ASMJIT_ARCH_X86 && defined(_WIN32) + #define ASMJIT_VECTORCALL __attribute__((__vectorcall__)) +#else + #define ASMJIT_VECTORCALL +#endif + +// Type alignment (not allowed by C++11 'alignas' keyword). 
+#if defined(__GNUC__) + #define ASMJIT_ALIGN_TYPE(TYPE, N) __attribute__((__aligned__(N))) TYPE +#elif defined(_MSC_VER) + #define ASMJIT_ALIGN_TYPE(TYPE, N) __declspec(align(N)) TYPE +#else + #define ASMJIT_ALIGN_TYPE(TYPE, N) TYPE +#endif + +//! \def ASMJIT_MAY_ALIAS +//! +//! Expands to `__attribute__((__may_alias__))` if supported. +#if defined(__GNUC__) + #define ASMJIT_MAY_ALIAS __attribute__((__may_alias__)) +#else + #define ASMJIT_MAY_ALIAS +#endif + +//! \def ASMJIT_MAYBE_UNUSED +//! +//! Expands to `[[maybe_unused]]` if supported or a compiler attribute instead. +#if __cplusplus >= 201703L + #define ASMJIT_MAYBE_UNUSED [[maybe_unused]] +#elif defined(__GNUC__) + #define ASMJIT_MAYBE_UNUSED __attribute__((unused)) +#else + #define ASMJIT_MAYBE_UNUSED +#endif + +#if defined(__clang_major__) && __clang_major__ >= 4 && !defined(_DOXYGEN) + // NOTE: Clang allows to apply this attribute to function arguments, which is what we want. Once GCC decides to + // support this use, we will enable it for GCC as well. However, until that, it will be clang only, which is + // what we need for static analysis. + #define ASMJIT_NONNULL(FUNCTION_ARGUMENT) FUNCTION_ARGUMENT __attribute__((__nonnull__)) +#else + #define ASMJIT_NONNULL(FUNCTION_ARGUMENT) FUNCTION_ARGUMENT +#endif + +//! \def ASMJIT_NOEXCEPT_TYPE +//! +//! Defined to `noexcept` in C++17 mode or nothing otherwise. Used by function typedefs. +#if __cplusplus >= 201703L + #define ASMJIT_NOEXCEPT_TYPE noexcept +#else + #define ASMJIT_NOEXCEPT_TYPE +#endif + +//! \def ASMJIT_ASSUME(...) +//! +//! Macro that tells the C/C++ compiler that the expression `...` evaluates to true. +//! +//! This macro has two purposes: +//! +//! 1. Enable optimizations that would not be possible without the assumption. +//! 2. Hint static analysis tools that a certain condition is true to prevent false positives. +#if defined(__clang__) + #define ASMJIT_ASSUME(...) 
__builtin_assume(__VA_ARGS__) +#elif defined(__GNUC__) + #define ASMJIT_ASSUME(...) do { if (!(__VA_ARGS__)) __builtin_unreachable(); } while (0) +#elif defined(_MSC_VER) + #define ASMJIT_ASSUME(...) __assume(__VA_ARGS__) +#else + #define ASMJIT_ASSUME(...) (void)0 +#endif + +//! \def ASMJIT_LIKELY(...) +//! +//! Condition is likely to be taken (mostly error handling and edge cases). + +//! \def ASMJIT_UNLIKELY(...) +//! +//! Condition is unlikely to be taken (mostly error handling and edge cases). +#if defined(__GNUC__) + #define ASMJIT_LIKELY(...) __builtin_expect(!!(__VA_ARGS__), 1) + #define ASMJIT_UNLIKELY(...) __builtin_expect(!!(__VA_ARGS__), 0) +#else + #define ASMJIT_LIKELY(...) (__VA_ARGS__) + #define ASMJIT_UNLIKELY(...) (__VA_ARGS__) +#endif + +//! \def ASMJIT_FALLTHROUGH +//! +//! Portable [[fallthrough]] attribute. +#if defined(__clang__) && __cplusplus >= 201103L + #define ASMJIT_FALLTHROUGH [[clang::fallthrough]] +#elif defined(__GNUC__) && __GNUC__ >= 7 + #define ASMJIT_FALLTHROUGH __attribute__((__fallthrough__)) +#else + #define ASMJIT_FALLTHROUGH ((void)0) /* fallthrough */ +#endif + +//! \def ASMJIT_DEPRECATED +//! +//! Marks function, class, struct, enum, or anything else as deprecated. +#if defined(__GNUC__) + #define ASMJIT_DEPRECATED(MESSAGE) __attribute__((__deprecated__(MESSAGE))) +#elif defined(_MSC_VER) + #define ASMJIT_DEPRECATED(MESSAGE) __declspec(deprecated(MESSAGE)) +#else + #define ASMJIT_DEPRECATED(MESSAGE) +#endif + +// Utilities. 
+#define ASMJIT_OFFSET_OF(STRUCT, MEMBER) ((int)(intptr_t)((const char*)&((const STRUCT*)0x100)->MEMBER) - 0x100) +#define ASMJIT_ARRAY_SIZE(X) uint32_t(sizeof(X) / sizeof(X[0])) + +#if ASMJIT_CXX_HAS_ATTRIBUTE(no_sanitize, 0) + #define ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF __attribute__((__no_sanitize__("undefined"))) +#elif defined(__GNUC__) && __GNUC__ >= 5 + #define ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF __attribute__((__no_sanitize_undefined__)) +#else + #define ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF +#endif + +// Diagnostic Macros +// ====================================== + +#if !defined(__clang__) && !defined(__INTEL_COMPILER) && !defined(_DOXYGEN) + #if defined(__GNUC__) && __GNUC__ == 4 + // There is a bug in GCC 4.X that has been fixed in GCC 5+, so just silence the warning. + #define ASMJIT_BEGIN_DIAGNOSTIC_SCOPE \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wmissing-field-initializers\"") + #define ASMJIT_END_DIAGNOSTIC_SCOPE \ + _Pragma("GCC diagnostic pop") + #elif defined(_MSC_VER) + #define ASMJIT_BEGIN_DIAGNOSTIC_SCOPE \ + __pragma(warning(push)) \ + __pragma(warning(disable: 4127)) /* conditional expression is const */ \ + __pragma(warning(disable: 4201)) /* nameless struct/union */ + #define ASMJIT_END_DIAGNOSTIC_SCOPE \ + __pragma(warning(pop)) + #endif +#endif + +#if !defined(ASMJIT_BEGIN_DIAGNOSTIC_SCOPE) && !defined(ASMJIT_END_DIAGNOSTIC_SCOPE) + #define ASMJIT_BEGIN_DIAGNOSTIC_SCOPE + #define ASMJIT_END_DIAGNOSTIC_SCOPE +#endif + +// Begin-Namespace & End-Namespace Macros +// ====================================== + +#if !defined(ASMJIT_NO_ABI_NAMESPACE) && !defined(_DOXYGEN) + #define ASMJIT_BEGIN_NAMESPACE \ + ASMJIT_BEGIN_DIAGNOSTIC_SCOPE \ + namespace asmjit { \ + inline namespace ASMJIT_ABI_NAMESPACE { + #define ASMJIT_END_NAMESPACE \ + }} \ + ASMJIT_END_DIAGNOSTIC_SCOPE +#else + #define ASMJIT_BEGIN_NAMESPACE \ + ASMJIT_BEGIN_DIAGNOSTIC_SCOPE \ + namespace asmjit { + #define ASMJIT_END_NAMESPACE \ + } \ + 
ASMJIT_END_DIAGNOSTIC_SCOPE +#endif + +#define ASMJIT_BEGIN_SUB_NAMESPACE(NAMESPACE) ASMJIT_BEGIN_NAMESPACE namespace NAMESPACE { +#define ASMJIT_END_SUB_NAMESPACE } ASMJIT_END_NAMESPACE + +// C++ Utilities +// ============= + +#define ASMJIT_NONCOPYABLE(Type) \ + Type(const Type& other) = delete; \ + Type& operator=(const Type& other) = delete; + +#define ASMJIT_NONCONSTRUCTIBLE(Type) \ + Type() = delete; \ + Type(const Type& other) = delete; \ + Type& operator=(const Type& other) = delete; + +//! \def ASMJIT_DEFINE_ENUM_FLAGS(T) +//! +//! Defines bit operations for enumeration flags. +#ifdef _DOXYGEN + #define ASMJIT_DEFINE_ENUM_FLAGS(T) +#else + #define ASMJIT_DEFINE_ENUM_FLAGS(T) \ + static ASMJIT_INLINE_NODEBUG constexpr T operator~(T a) noexcept { \ + return T(~(std::underlying_type::type)(a)); \ + } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr T operator|(T a, T b) noexcept { \ + return T((std::underlying_type::type)(a) | \ + (std::underlying_type::type)(b)); \ + } \ + static ASMJIT_INLINE_NODEBUG constexpr T operator&(T a, T b) noexcept { \ + return T((std::underlying_type::type)(a) & \ + (std::underlying_type::type)(b)); \ + } \ + static ASMJIT_INLINE_NODEBUG constexpr T operator^(T a, T b) noexcept { \ + return T((std::underlying_type::type)(a) ^ \ + (std::underlying_type::type)(b)); \ + } \ + \ + static ASMJIT_INLINE_NODEBUG T& operator|=(T& a, T b) noexcept { \ + a = T((std::underlying_type::type)(a) | \ + (std::underlying_type::type)(b)); \ + return a; \ + } \ + static ASMJIT_INLINE_NODEBUG T& operator&=(T& a, T b) noexcept { \ + a = T((std::underlying_type::type)(a) & \ + (std::underlying_type::type)(b)); \ + return a; \ + } \ + static ASMJIT_INLINE_NODEBUG T& operator^=(T& a, T b) noexcept { \ + a = T((std::underlying_type::type)(a) ^ \ + (std::underlying_type::type)(b)); \ + return a; \ + } +#endif + +//! \def ASMJIT_DEFINE_ENUM_COMPARE(T) +//! +//! Defines comparison operations for enumeration flags. 
+#if defined(_DOXYGEN) || (defined(_MSC_VER) && _MSC_VER <= 1900) + #define ASMJIT_DEFINE_ENUM_COMPARE(T) +#else + #define ASMJIT_DEFINE_ENUM_COMPARE(T) \ + static ASMJIT_INLINE_NODEBUG bool operator<(T a, T b) noexcept { \ + return (std::underlying_type::type)(a) < (std::underlying_type::type)(b); \ + } \ + static ASMJIT_INLINE_NODEBUG bool operator<=(T a, T b) noexcept { \ + return (std::underlying_type::type)(a) <= (std::underlying_type::type)(b); \ + } \ + static ASMJIT_INLINE_NODEBUG bool operator>(T a, T b) noexcept { \ + return (std::underlying_type::type)(a) > (std::underlying_type::type)(b); \ + } \ + static ASMJIT_INLINE_NODEBUG bool operator>=(T a, T b) noexcept { \ + return (std::underlying_type::type)(a) >= (std::underlying_type::type)(b); \ + } +#endif + +#endif // ASMJIT_CORE_API_CONFIG_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/archcommons.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/archcommons.h new file mode 100644 index 0000000000000000000000000000000000000000..2b47d17c1b448fcbc5cfd451937327e1ad489593 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/archcommons.h @@ -0,0 +1,261 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ARCHCOMMONS_H_INCLUDED +#define ASMJIT_CORE_ARCHCOMMONS_H_INCLUDED + +// This file provides architecture-specific classes that are required in the core library. For example Imm operand +// allows to be created from arm::Shift in a const-expr way, so the arm::Shift must be provided. So this header file +// provides everything architecture-specific that is used by the Core API. + +#include "../core/globals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(arm) + +//! \addtogroup asmjit_arm +//! \{ + +//! Condition code (both AArch32 & AArch64). +//! +//! 
\note This enumeration doesn't match condition code that is used in AArch32/AArch64 opcodes. In general this +//! condition code is encoded as `(cc - 2) & 0xF` so that `kAL` condition code is zero and encoded as 0xE in opcode. +//! This makes it easier to use a condition code as an instruction modifier that defaults to 'al'. +enum class CondCode : uint8_t { + kAL = 0x00u, //!< (no condition code) (always) + kNA = 0x01u, //!< (not available) (special) + kEQ = 0x02u, //!< Z==1 (any_sign ==) + kNE = 0x03u, //!< Z==0 (any_sign !=) + kCS = 0x04u, //!< C==1 (unsigned >=) + kHS = 0x04u, //!< C==1 (unsigned >=) + kLO = 0x05u, //!< C==0 (unsigned < ) + kCC = 0x05u, //!< C==0 (unsigned < ) + kMI = 0x06u, //!< N==1 (is negative) + kPL = 0x07u, //!< N==0 (is positive or zero) + kVS = 0x08u, //!< V==1 (is overflow) + kVC = 0x09u, //!< V==0 (no overflow) + kHI = 0x0Au, //!< C==1 & Z==0 (unsigned > ) + kLS = 0x0Bu, //!< C==0 | Z==1 (unsigned <=) + kGE = 0x0Cu, //!< N==V (signed >=) + kLT = 0x0Du, //!< N!=V (signed < ) + kGT = 0x0Eu, //!< Z==0 & N==V (signed > ) + kLE = 0x0Fu, //!< Z==1 | N!=V (signed <=) + + kZero = kEQ, //!< Zero flag (alias to equal). + kNotZero = kNE, //!< Not zero (alias to Not Equal). + + kEqual = kEQ, //!< Equal `a == b`. + kNotEqual = kNE, //!< Not Equal `a != b`. + + kCarry = kCS, //!< Carry flag. + kNotCarry = kCC, //!< Not carry. + + kSign = kMI, //!< Sign flag. + kNotSign = kPL, //!< Not sign. + + kNegative = kMI, //!< Negative. + kPositive = kPL, //!< Positive or zero. + + kOverflow = kVS, //!< Signed overflow. + kNotOverflow = kVC, //!< Not signed overflow. + + kSignedLT = kLT, //!< Signed `a < b`. + kSignedLE = kLE, //!< Signed `a <= b`. + kSignedGT = kGT, //!< Signed `a > b`. + kSignedGE = kGE, //!< Signed `a >= b`. + + kUnsignedLT = kLO, //!< Unsigned `a < b`. + kUnsignedLE = kLS, //!< Unsigned `a <= b`. + kUnsignedGT = kHI, //!< Unsigned `a > b`. + kUnsignedGE = kHS, //!< Unsigned `a >= b`. + + kBTZero = kZero, //!< Tested bit is zero. 
+ kBTNotZero = kNotZero, //!< Tested bit is not zero. + + kAlways = kAL, //!< No condition code (always). + + kMaxValue = 0x0Fu //!< Maximum value of `CondCode`. +}; + + +//! \cond +static constexpr CondCode _reverseCondTable[] = { + CondCode::kAL, // AL <- AL + CondCode::kNA, // NA <- NA + CondCode::kEQ, // EQ <- EQ + CondCode::kNE, // NE <- NE + CondCode::kLS, // LS <- CS + CondCode::kHI, // HI <- LO + CondCode::kMI, // MI <- MI + CondCode::kPL, // PL <- PL + CondCode::kVS, // VS <- VS + CondCode::kVC, // VC <- VC + CondCode::kLO, // LO <- HI + CondCode::kCS, // CS <- LS + CondCode::kLE, // LE <- GE + CondCode::kGT, // GT <- LT + CondCode::kLT, // LT <- GT + CondCode::kGE // GE <- LE +}; +//! \endcond + +//! Reverses a condition code (reverses the corresponding operands of a comparison). +static ASMJIT_INLINE_NODEBUG constexpr CondCode reverseCond(CondCode cond) noexcept { return _reverseCondTable[uint8_t(cond)]; } +//! Negates a condition code. +static ASMJIT_INLINE_NODEBUG constexpr CondCode negateCond(CondCode cond) noexcept { return CondCode(uint8_t(cond) ^ uint8_t(1)); } + +//! Memory offset mode. +//! +//! Describes either fixed, pre-index, or post-index offset modes. +enum class OffsetMode : uint32_t { + //! Fixed offset mode (either no index at all or a regular index without a write-back). + kFixed = 0u, + //! Pre-index "[BASE, #Offset {, }]!" with write-back. + kPreIndex = 1u, + //! Post-index "[BASE], #Offset {, }" with write-back. + kPostIndex = 2u +}; + +//! Shift operation predicate (ARM) describes either SHIFT or EXTEND operation. +//! +//! \note The constants are AsmJit specific. The first 5 values describe real constants on ARM32 and AArch64 hardware, +//! however, the addition constants that describe extend modes are specific to AsmJit and would be translated to the +//! AArch64 specific constants by the assembler. +enum class ShiftOp : uint32_t { + //! Shift left logical operation (default). + //! + //! Available to all ARM architectures. 
+ kLSL = 0x00u, + + //! Shift right logical operation. + //! + //! Available to all ARM architectures. + kLSR = 0x01u, + + //! Shift right arithmetic operation. + //! + //! Available to all ARM architectures. + kASR = 0x02u, + + //! Rotate right operation (AArch32 only). + kROR = 0x03u, + + //! Rotate right with carry operation (encoded as `ShiftOp::kROR` with zero) (AArch32 only). + kRRX = 0x04u, + + //! Shift left by filling low order bits with ones. + kMSL = 0x05u, + + //! UXTN extend register operation (AArch64 only). + kUXTB = 0x06u, + //! UXTH extend register operation (AArch64 only). + kUXTH = 0x07u, + //! UXTW extend register operation (AArch64 only). + kUXTW = 0x08u, + //! UXTX extend register operation (AArch64 only). + kUXTX = 0x09u, + + //! SXTB extend register operation (AArch64 only). + kSXTB = 0x0Au, + //! SXTH extend register operation (AArch64 only). + kSXTH = 0x0Bu, + //! SXTW extend register operation (AArch64 only). + kSXTW = 0x0Cu, + //! SXTX extend register operation (AArch64 only). + kSXTX = 0x0Du + + // NOTE: 0xE and 0xF are used by memory operand to specify POST|PRE offset mode. +}; + +//! Represents ARM immediate shift operation type and value. +class Shift { +public: + //! Shift operation. + ShiftOp _op; + //! Shift Value. + uint32_t _value; + + //! Default constructed Shift is not initialized. + ASMJIT_INLINE_NODEBUG Shift() noexcept = default; + + //! Copy constructor (default) + ASMJIT_INLINE_NODEBUG constexpr Shift(const Shift& other) noexcept = default; + + //! Constructs Shift from operation `op` and shift `value`. + ASMJIT_INLINE_NODEBUG constexpr Shift(ShiftOp op, uint32_t value) noexcept + : _op(op), + _value(value) {} + + //! Returns the shift operation. + ASMJIT_INLINE_NODEBUG constexpr ShiftOp op() const noexcept { return _op; } + //! Sets shift operation to `op`. + ASMJIT_INLINE_NODEBUG void setOp(ShiftOp op) noexcept { _op = op; } + + //! Returns the shift amount. 
+ ASMJIT_INLINE_NODEBUG constexpr uint32_t value() const noexcept { return _value; } + //! Sets shift amount to `value`. + ASMJIT_INLINE_NODEBUG void setValue(uint32_t value) noexcept { _value = value; } +}; + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +ASMJIT_BEGIN_SUB_NAMESPACE(a32) + +using namespace arm; + +//! Data type that can be encoded with AArch32 instruction identifier. +//! +//! \note Data types are frequently used with AArch32 SIMD instructions. For example `VMAX` instruction can +//! use almost all datatypes in a form `VMAX.F32`, `VMAX.S16`, `VMAX.U32`, etc... Emitter automatically adds +//! the required data type at emit level. +enum class DataType : uint32_t { + //! No data type specified (default for all general purpose instructions). + kNone = 0, + //! 8-bit signed integer, specified as `.s8` in assembly. + kS8 = 1, + //! 16-bit signed integer, specified as `.s16` in assembly. + kS16 = 2, + //! 32-bit signed integer, specified as `.s32` in assembly. + kS32 = 3, + //! 64-bit signed integer, specified as `.s64` in assembly. + kS64 = 4, + //! 8-bit unsigned integer, specified as `.u8` in assembly. + kU8 = 5, + //! 16-bit unsigned integer, specified as `.u16` in assembly. + kU16 = 6, + //! 32-bit unsigned integer, specified as `.u32` in assembly. + kU32 = 7, + //! 64-bit unsigned integer, specified as `.u64` in assembly. + kU64 = 8, + //! 16-bit floating point (half precision), specified as `.f16` in assembly. + kF16 = 10, + //! 32-bit floating point (single precision), specified as `.f32` in assembly. + kF32 = 11, + //! 64-bit floating point (double precision), specified as `.f64` in assembly. + kF64 = 12, + //! 8-bit polynomial. + kP8 = 13, + //! 16-bit BF16 floating point. + kBF16 = 14, + //! 64-bit polynomial. + kP64 = 15, + + //! Maximum value of `DataType`. 
+ kMaxValue = 15 +}; + +static ASMJIT_INLINE_NODEBUG uint32_t dataTypeSize(DataType dt) noexcept { + static constexpr uint8_t table[] = { 0, 1, 2, 4, 8, 1, 2, 4, 8, 2, 4, 8, 1, 2, 8 }; + return table[size_t(dt)]; +} + +ASMJIT_END_SUB_NAMESPACE + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) +using namespace arm; +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_CORE_ARCHCOMMONS_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/archtraits.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/archtraits.h new file mode 100644 index 0000000000000000000000000000000000000000..9f08dea523f8d72f9366d5d5e81dc70491a5d8d5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/archtraits.h @@ -0,0 +1,293 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ARCHTRAITS_H_INCLUDED +#define ASMJIT_CORE_ARCHTRAITS_H_INCLUDED + +#include "../core/operand.h" +#include "../core/support.h" +#include "../core/type.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Instruction set architecture (ISA). +enum class Arch : uint8_t { + //! Unknown or uninitialized ISA. + kUnknown = 0, + + //! 32-bit X86 ISA. + kX86 = 1, + //! 64-bit X86 ISA also known as X64, X86_64, and AMD64. + kX64 = 2, + + //! 32-bit RISC-V ISA. + kRISCV32 = 3, + //! 64-bit RISC-V ISA. + kRISCV64 = 4, + + //! 32-bit ARM ISA (little endian). + kARM = 5, + //! 64-bit ARM ISA in (little endian). + kAArch64 = 6, + //! 32-bit ARM ISA in Thumb mode (little endian). + kThumb = 7, + + // 8 is not used at the moment, even numbers are 64-bit architectures. + + //! 32-bit MIPS ISA in (little endian). + kMIPS32_LE = 9, + //! 64-bit MIPS ISA in (little endian). + kMIPS64_LE = 10, + + //! 32-bit ARM ISA (big endian). + kARM_BE = 11, + //! 64-bit ARM ISA in (big endian). + kAArch64_BE = 12, + //! 32-bit ARM ISA in Thumb mode (big endian). 
+ kThumb_BE = 13, + + // 14 is not used at the moment, even numbers are 64-bit architectures. + + //! 32-bit MIPS ISA in (big endian). + kMIPS32_BE = 15, + //! 64-bit MIPS ISA in (big endian). + kMIPS64_BE = 16, + + //! Maximum value of `Arch`. + kMaxValue = kMIPS64_BE, + + //! Mask used by 32-bit ISAs (odd are 32-bit, even are 64-bit). + k32BitMask = 0x01, + //! First big-endian architecture. + kBigEndian = kARM_BE, + + //! ISA detected at compile-time (ISA of the host). + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#else + ASMJIT_ARCH_X86 == 32 ? kX86 : + ASMJIT_ARCH_X86 == 64 ? kX64 : + + ASMJIT_ARCH_RISCV == 32 ? kRISCV32 : + ASMJIT_ARCH_RISCV == 64 ? kRISCV64 : + + ASMJIT_ARCH_ARM == 32 && ASMJIT_ARCH_LE ? kARM : + ASMJIT_ARCH_ARM == 32 && ASMJIT_ARCH_BE ? kARM_BE : + ASMJIT_ARCH_ARM == 64 && ASMJIT_ARCH_LE ? kAArch64 : + ASMJIT_ARCH_ARM == 64 && ASMJIT_ARCH_BE ? kAArch64_BE : + + ASMJIT_ARCH_MIPS == 32 && ASMJIT_ARCH_LE ? kMIPS32_LE : + ASMJIT_ARCH_MIPS == 32 && ASMJIT_ARCH_BE ? kMIPS32_BE : + ASMJIT_ARCH_MIPS == 64 && ASMJIT_ARCH_LE ? kMIPS64_LE : + ASMJIT_ARCH_MIPS == 64 && ASMJIT_ARCH_BE ? kMIPS64_BE : + + kUnknown +#endif +}; + +//! Sub-architecture. +enum class SubArch : uint8_t { + //! Unknown or uninitialized architecture sub-type. + kUnknown = 0, + + //! Maximum value of `SubArch`. + kMaxValue = kUnknown, + + //! Sub-architecture detected at compile-time (sub-architecture of the host). + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#else + kUnknown +#endif +}; + +//! Identifier used to represent names of different data types across architectures. +enum class ArchTypeNameId : uint8_t { + //! Describes 'db' (X86/X86_64 convention, always 8-bit quantity). + kDB = 0, + //! Describes 'dw' (X86/X86_64 convention, always 16-bit word). + kDW, + //! Describes 'dd' (X86/X86_64 convention, always 32-bit word). + kDD, + //! Describes 'dq' (X86/X86_64 convention, always 64-bit word). + kDQ, + //! 
Describes 'byte' (always 8-bit quantity). + kByte, + //! Describes 'half' (most likely 16-bit word). + kHalf, + //! Describes 'word' (either 16-bit or 32-bit word). + kWord, + //! Describes 'hword' (most likely 16-bit word). + kHWord, + //! Describes 'dword' (either 32-bit or 64-bit word). + kDWord, + //! Describes 'qword' (64-bit word). + kQWord, + //! Describes 'xword' (64-bit word). + kXWord, + //! Describes 'short' (always 16-bit word). + kShort, + //! Describes 'long' (most likely 32-bit word). + kLong, + //! Describes 'quad' (64-bit word). + kQuad, + + //! Maximum value of `ArchTypeNameId`. + kMaxValue = kQuad +}; + +//! Instruction feature hints for each register group provided by \ref ArchTraits. +//! +//! Instruction feature hints describe miscellaneous instructions provided by the architecture that can be used by +//! register allocator to make certain things simpler - like register swaps or emitting register push/pop sequences. +//! +//! \remarks Instruction feature hints are only defined for register groups that can be used with \ref +//! asmjit_compiler infrastructure. Register groups that are not managed by Compiler are not provided by +//! \ref ArchTraits and cannot be queried. +enum class InstHints : uint8_t { + //! No feature hints. + kNoHints = 0, + + //! Architecture supports a register swap by using a single instruction. + kRegSwap = 0x01u, + //! Architecture provides push/pop instructions. + kPushPop = 0x02u +}; +ASMJIT_DEFINE_ENUM_FLAGS(InstHints) + +//! Architecture traits used by Function API and Compiler's register allocator. +struct ArchTraits { + //! \name Members + //! \{ + + //! Stack pointer register id. + uint8_t _spRegId; + //! Frame pointer register id. + uint8_t _fpRegId; + //! Link register id. + uint8_t _linkRegId; + //! Instruction pointer (or program counter) register id, if accessible. + uint8_t _ipRegId; + + // Reserved. + uint8_t _reserved[3]; + //! Hardware stack alignment requirement. + uint8_t _hwStackAlignment; + + //! 
Minimum addressable offset on stack guaranteed for all instructions. + uint32_t _minStackOffset; + //! Maximum addressable offset on stack depending on specific instruction. + uint32_t _maxStackOffset; + + //! Flags for each virtual register group. + Support::Array _instHints; + + //! Maps register type into a signature, that provides group, size and can be used to construct register operands. + Support::Array _regSignature; + //! Maps a register to type-id, see \ref TypeId. + Support::Array _regTypeToTypeId; + //! Maps scalar TypeId values (from TypeId::_kIdBaseStart) to register types, see \ref TypeId. + Support::Array _typeIdToRegType; + + //! Word name identifiers of 8-bit, 16-bit, 32-biit, and 64-bit quantities that appear in formatted text. + ArchTypeNameId _typeNameIdTable[4]; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns stack pointer register id. + ASMJIT_INLINE_NODEBUG uint32_t spRegId() const noexcept { return _spRegId; } + //! Returns stack frame register id. + ASMJIT_INLINE_NODEBUG uint32_t fpRegId() const noexcept { return _fpRegId; } + //! Returns link register id, if the architecture provides it. + ASMJIT_INLINE_NODEBUG uint32_t linkRegId() const noexcept { return _linkRegId; } + //! Returns instruction pointer register id, if the architecture provides it. + ASMJIT_INLINE_NODEBUG uint32_t ipRegId() const noexcept { return _ipRegId; } + + //! Returns a hardware stack alignment requirement. + //! + //! \note This is a hardware constraint. Architectures that don't constrain it would return the lowest alignment + //! (1), however, some architectures may constrain the alignment, for example AArch64 requires 16-byte alignment. + ASMJIT_INLINE_NODEBUG uint32_t hwStackAlignment() const noexcept { return _hwStackAlignment; } + + //! Tests whether the architecture provides link register, which is used across function calls. If the link + //! register is not provided then a function call pushes the return address on stack (X86/X64). 
+ ASMJIT_INLINE_NODEBUG bool hasLinkReg() const noexcept { return _linkRegId != BaseReg::kIdBad; } + + //! Returns minimum addressable offset on stack guaranteed for all instructions. + ASMJIT_INLINE_NODEBUG uint32_t minStackOffset() const noexcept { return _minStackOffset; } + //! Returns maximum addressable offset on stack depending on specific instruction. + ASMJIT_INLINE_NODEBUG uint32_t maxStackOffset() const noexcept { return _maxStackOffset; } + + //! Returns ISA flags of the given register `group`. + ASMJIT_INLINE_NODEBUG InstHints instFeatureHints(RegGroup group) const noexcept { return _instHints[group]; } + //! Tests whether the given register `group` has the given `flag` set. + ASMJIT_INLINE_NODEBUG bool hasInstHint(RegGroup group, InstHints feature) const noexcept { return Support::test(_instHints[group], feature); } + //! Tests whether the ISA provides register swap instruction for the given register `group`. + ASMJIT_INLINE_NODEBUG bool hasInstRegSwap(RegGroup group) const noexcept { return hasInstHint(group, InstHints::kRegSwap); } + //! Tests whether the ISA provides push/pop instructions for the given register `group`. + ASMJIT_INLINE_NODEBUG bool hasInstPushPop(RegGroup group) const noexcept { return hasInstHint(group, InstHints::kPushPop); } + + ASMJIT_INLINE_NODEBUG bool hasRegType(RegType type) const noexcept { + return type <= RegType::kMaxValue && _regSignature[type].isValid(); + } + + //! Returns an operand signature from the given register `type` of this architecture. + ASMJIT_INLINE_NODEBUG OperandSignature regTypeToSignature(RegType type) const noexcept { return _regSignature[type]; } + //! Returns a register from the given register `type` of this architecture. + ASMJIT_INLINE_NODEBUG RegGroup regTypeToGroup(RegType type) const noexcept { return _regSignature[type].regGroup(); } + //! Returns a register size the given register `type` of this architecture. 
+ ASMJIT_INLINE_NODEBUG uint32_t regTypeToSize(RegType type) const noexcept { return _regSignature[type].size(); } + //! Returns a corresponding `TypeId` from the given register `type` of this architecture. + ASMJIT_INLINE_NODEBUG TypeId regTypeToTypeId(RegType type) const noexcept { return _regTypeToTypeId[type]; } + + //! Returns a table of ISA word names that appear in formatted text. Word names are ISA dependent. + //! + //! The index of this table is log2 of the size: + //! - [0] 8-bits + //! - [1] 16-bits + //! - [2] 32-bits + //! - [3] 64-bits + ASMJIT_INLINE_NODEBUG const ArchTypeNameId* typeNameIdTable() const noexcept { return _typeNameIdTable; } + + //! Returns an ISA word name identifier of the given `index`, see \ref typeNameIdTable() for more details. + ASMJIT_INLINE_NODEBUG ArchTypeNameId typeNameIdByIndex(uint32_t index) const noexcept { return _typeNameIdTable[index]; } + + //! \} + + //! \name Statics + //! \{ + + //! Returns a const reference to `ArchTraits` for the given architecture `arch`. + static ASMJIT_INLINE_NODEBUG const ArchTraits& byArch(Arch arch) noexcept; + + //! \} +}; + +ASMJIT_VARAPI const ArchTraits _archTraits[uint32_t(Arch::kMaxValue) + 1]; + +//! \cond +ASMJIT_INLINE_NODEBUG const ArchTraits& ArchTraits::byArch(Arch arch) noexcept { return _archTraits[uint32_t(arch)]; } +//! \endcond + +//! Architecture utilities. +namespace ArchUtils { + +ASMJIT_API Error typeIdToRegSignature(Arch arch, TypeId typeId, TypeId* typeIdOut, OperandSignature* regSignatureOut) noexcept; + +} // {ArchUtils} + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ARCHTRAITS_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/assembler.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/assembler.h new file mode 100644 index 0000000000000000000000000000000000000000..d53d9e53764f13e58b4a43893c790da88f4fadbb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/assembler.h @@ -0,0 +1,130 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ASSEMBLER_H_INCLUDED +#define ASMJIT_CORE_ASSEMBLER_H_INCLUDED + +#include "../core/codeholder.h" +#include "../core/emitter.h" +#include "../core/operand.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_assembler +//! \{ + +//! Base assembler. +//! +//! This is a base class that provides interface used by architecture specific +//! assembler implementations. Assembler doesn't hold any data, instead it's +//! attached to \ref CodeHolder, which provides all the data that Assembler +//! needs and which can be altered by it. +//! +//! Check out architecture specific assemblers for more details and examples: +//! +//! - \ref x86::Assembler - X86/X64 assembler implementation. +//! - \ref a64::Assembler - AArch64 assembler implementation. +class ASMJIT_VIRTAPI BaseAssembler : public BaseEmitter { +public: + ASMJIT_NONCOPYABLE(BaseAssembler) + typedef BaseEmitter Base; + + //! Current section where the assembling happens. + Section* _section = nullptr; + //! Start of the CodeBuffer of the current section. + uint8_t* _bufferData = nullptr; + //! End (first invalid byte) of the current section. + uint8_t* _bufferEnd = nullptr; + //! Pointer in the CodeBuffer of the current section. + uint8_t* _bufferPtr = nullptr; + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `BaseAssembler` instance. + ASMJIT_API BaseAssembler() noexcept; + //! 
Destroys the `BaseAssembler` instance. + ASMJIT_API ~BaseAssembler() noexcept override; + + //! \} + + //! \name Code-Buffer Management + //! \{ + + //! Returns the capacity of the current CodeBuffer. + ASMJIT_INLINE_NODEBUG size_t bufferCapacity() const noexcept { return (size_t)(_bufferEnd - _bufferData); } + //! Returns the number of remaining bytes in the current CodeBuffer. + ASMJIT_INLINE_NODEBUG size_t remainingSpace() const noexcept { return (size_t)(_bufferEnd - _bufferPtr); } + + //! Returns the current position in the CodeBuffer. + ASMJIT_INLINE_NODEBUG size_t offset() const noexcept { return (size_t)(_bufferPtr - _bufferData); } + + //! Sets the current position in the CodeBuffer to `offset`. + //! + //! \note The `offset` cannot be greater than buffer size even if it's + //! within the buffer's capacity. + ASMJIT_API Error setOffset(size_t offset); + + //! Returns the start of the CodeBuffer in the current section. + ASMJIT_INLINE_NODEBUG uint8_t* bufferData() const noexcept { return _bufferData; } + //! Returns the end (first invalid byte) in the current section. + ASMJIT_INLINE_NODEBUG uint8_t* bufferEnd() const noexcept { return _bufferEnd; } + //! Returns the current pointer in the CodeBuffer in the current section. + ASMJIT_INLINE_NODEBUG uint8_t* bufferPtr() const noexcept { return _bufferPtr; } + + //! \} + + //! \name Section Management + //! \{ + + //! Returns the current section. + ASMJIT_INLINE_NODEBUG Section* currentSection() const noexcept { return _section; } + + ASMJIT_API Error section(Section* section) override; + + //! \} + + //! \name Label Management + //! \{ + + ASMJIT_API Label newLabel() override; + ASMJIT_API Label newNamedLabel(const char* name, size_t nameSize = SIZE_MAX, LabelType type = LabelType::kGlobal, uint32_t parentId = Globals::kInvalidId) override; + ASMJIT_API Error bind(const Label& label) override; + + //! \} + + //! \name Embed + //! 
\{ + + ASMJIT_API Error embed(const void* data, size_t dataSize) override; + ASMJIT_API Error embedDataArray(TypeId typeId, const void* data, size_t itemCount, size_t repeatCount = 1) override; + ASMJIT_API Error embedConstPool(const Label& label, const ConstPool& pool) override; + + ASMJIT_API Error embedLabel(const Label& label, size_t dataSize = 0) override; + ASMJIT_API Error embedLabelDelta(const Label& label, const Label& base, size_t dataSize = 0) override; + + //! \} + + //! \name Comment + //! \{ + + ASMJIT_API Error comment(const char* data, size_t size = SIZE_MAX) override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ASSEMBLER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/builder.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/builder.h new file mode 100644 index 0000000000000000000000000000000000000000..0de19234cb6ebdff4014af6bf2fd34e330503bb7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/builder.h @@ -0,0 +1,1499 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_BUILDER_H_INCLUDED +#define ASMJIT_CORE_BUILDER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_BUILDER + +#include "../core/assembler.h" +#include "../core/codeholder.h" +#include "../core/constpool.h" +#include "../core/formatter.h" +#include "../core/inst.h" +#include "../core/operand.h" +#include "../core/string.h" +#include "../core/support.h" +#include "../core/type.h" +#include "../core/zone.h" +#include "../core/zonevector.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_builder +//! 
\{ + +class BaseBuilder; +class Pass; + +class BaseNode; +class InstNode; +class SectionNode; +class LabelNode; +class AlignNode; +class EmbedDataNode; +class EmbedLabelNode; +class ConstPoolNode; +class CommentNode; +class SentinelNode; +class LabelDeltaNode; + +//! Type of node used by \ref BaseBuilder and \ref BaseCompiler. +enum class NodeType : uint8_t { + //! Invalid node (internal, don't use). + kNone = 0, + + // [BaseBuilder] + + //! Node is \ref InstNode. + kInst = 1, + //! Node is \ref SectionNode. + kSection = 2, + //! Node is \ref LabelNode. + kLabel = 3, + //! Node is \ref AlignNode. + kAlign = 4, + //! Node is \ref EmbedDataNode. + kEmbedData = 5, + //! Node is \ref EmbedLabelNode. + kEmbedLabel = 6, + //! Node is \ref EmbedLabelDeltaNode. + kEmbedLabelDelta = 7, + //! Node is \ref ConstPoolNode. + kConstPool = 8, + //! Node is \ref CommentNode. + kComment = 9, + //! Node is \ref SentinelNode. + kSentinel = 10, + + // [BaseCompiler] + + //! Node is \ref JumpNode (acts as InstNode). + kJump = 15, + //! Node is \ref FuncNode (acts as LabelNode). + kFunc = 16, + //! Node is \ref FuncRetNode (acts as InstNode). + kFuncRet = 17, + //! Node is \ref InvokeNode (acts as InstNode). + kInvoke = 18, + + // [UserDefined] + + //! First id of a user-defined node. + kUser = 32 +}; + +//! Node flags, specify what the node is and/or does. +enum class NodeFlags : uint8_t { + //! No flags. + kNone = 0, + //! Node is code that can be executed (instruction, label, align, etc...). + kIsCode = 0x01u, + //! Node is data that cannot be executed (data, const-pool, etc...). + kIsData = 0x02u, + //! Node is informative, can be removed and ignored. + kIsInformative = 0x04u, + //! Node can be safely removed if unreachable. + kIsRemovable = 0x08u, + //! Node does nothing when executed (label, align, explicit nop). + kHasNoEffect = 0x10u, + //! Node is an instruction or acts as it. + kActsAsInst = 0x20u, + //! Node is a label or acts as it. + kActsAsLabel = 0x40u, + //! 
Node is active (part of the code). + kIsActive = 0x80u +}; +ASMJIT_DEFINE_ENUM_FLAGS(NodeFlags) + +//! Type of the sentinel (purely informative purpose). +enum class SentinelType : uint8_t { + //! Type of the sentinel is not known. + kUnknown = 0u, + //! This is a sentinel used at the end of \ref FuncNode. + kFuncEnd = 1u +}; + +//! Node list. +//! +//! A double-linked list of pointers to \ref BaseNode, managed by \ref BaseBuilder or \ref BaseCompiler. +//! +//! \note At the moment NodeList is just a view, but it's planned that it will get more functionality in the future. +class NodeList { +public: + //! \name Members + //! \{ + + //! First node in the list or nullptr if there are no nodes in the list. + BaseNode* _first = nullptr; + //! Last node in the list or nullptr if there are no nodes in the list. + BaseNode* _last = nullptr; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG NodeList() noexcept {} + + ASMJIT_INLINE_NODEBUG NodeList(BaseNode* first, BaseNode* last) noexcept + : _first(first), + _last(last) {} + + //! \} + + //! \name Reset + //! \{ + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _first = nullptr; + _last = nullptr; + } + + ASMJIT_INLINE_NODEBUG void reset(BaseNode* first, BaseNode* last) noexcept { + _first = first; + _last = last; + } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _first == nullptr; } + + ASMJIT_INLINE_NODEBUG BaseNode* first() const noexcept { return _first; } + ASMJIT_INLINE_NODEBUG BaseNode* last() const noexcept { return _last; } + + //! \} +}; + +//! Builder interface. +//! +//! `BaseBuilder` interface was designed to be used as a \ref BaseAssembler replacement in case pre-processing or +//! post-processing of the generated code is required. The code can be modified during or after code generation. +//! Pre processing or post processing can be done manually or through a \ref Pass object. \ref BaseBuilder stores +//! 
the emitted code as a double-linked list of nodes, which allows O(1) insertion and removal during processing. +//! +//! Check out architecture specific builders for more details and examples: +//! +//! - \ref x86::Builder - X86/X64 builder implementation. +//! - \ref a64::Builder - AArch64 builder implementation. +class ASMJIT_VIRTAPI BaseBuilder : public BaseEmitter { +public: + ASMJIT_NONCOPYABLE(BaseBuilder) + typedef BaseEmitter Base; + + //! \name Members + //! \{ + + //! Base zone used to allocate nodes and passes. + Zone _codeZone; + //! Data zone used to allocate data and names. + Zone _dataZone; + //! Pass zone, passed to `Pass::run()`. + Zone _passZone; + //! Allocator that uses `_codeZone`. + ZoneAllocator _allocator; + + //! Array of `Pass` objects. + ZoneVector _passes {}; + //! Maps section indexes to `LabelNode` nodes. + ZoneVector _sectionNodes {}; + //! Maps label indexes to `LabelNode` nodes. + ZoneVector _labelNodes {}; + + //! Current node (cursor). + BaseNode* _cursor = nullptr; + //! First and last nodes. + NodeList _nodeList; + + //! Flags assigned to each new node. + NodeFlags _nodeFlags = NodeFlags::kNone; + //! The sections links are dirty (used internally). + bool _dirtySectionLinks = false; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `BaseBuilder` instance. + ASMJIT_API BaseBuilder() noexcept; + //! Destroys the `BaseBuilder` instance. + ASMJIT_API ~BaseBuilder() noexcept override; + + //! \} + + //! \name Node Management + //! \{ + + ASMJIT_INLINE_NODEBUG NodeList nodeList() const noexcept { return _nodeList; } + + //! Returns the first node. + ASMJIT_INLINE_NODEBUG BaseNode* firstNode() const noexcept { return _nodeList.first(); } + //! Returns the last node. + ASMJIT_INLINE_NODEBUG BaseNode* lastNode() const noexcept { return _nodeList.last(); } + + //! Allocates and instantiates a new node of type `T` and returns its instance. If the allocation fails `nullptr` + //! is returned. + //! + //! 
The template argument `T` must be a type that is extends \ref BaseNode. + //! + //! \remarks The pointer returned (if non-null) is owned by the Builder or Compiler. When the Builder/Compiler + //! is destroyed it destroys all nodes it created so no manual memory management is required. + template + inline Error _newNodeT(T** ASMJIT_NONNULL(out), Args&&... args) { + *out = _allocator.newT(this, std::forward(args)...); + if (ASMJIT_UNLIKELY(!*out)) + return reportError(DebugUtils::errored(kErrorOutOfMemory)); + return kErrorOk; + } + + //! Creates a new \ref InstNode. + ASMJIT_API Error newInstNode(InstNode** ASMJIT_NONNULL(out), InstId instId, InstOptions instOptions, uint32_t opCount); + //! Creates a new \ref LabelNode. + ASMJIT_API Error newLabelNode(LabelNode** ASMJIT_NONNULL(out)); + //! Creates a new \ref AlignNode. + ASMJIT_API Error newAlignNode(AlignNode** ASMJIT_NONNULL(out), AlignMode alignMode, uint32_t alignment); + //! Creates a new \ref EmbedDataNode. + ASMJIT_API Error newEmbedDataNode(EmbedDataNode** ASMJIT_NONNULL(out), TypeId typeId, const void* data, size_t itemCount, size_t repeatCount = 1); + //! Creates a new \ref ConstPoolNode. + ASMJIT_API Error newConstPoolNode(ConstPoolNode** ASMJIT_NONNULL(out)); + //! Creates a new \ref CommentNode. + ASMJIT_API Error newCommentNode(CommentNode** ASMJIT_NONNULL(out), const char* data, size_t size); + + //! Adds `node` after the current and sets the current node to the given `node`. + ASMJIT_API BaseNode* addNode(BaseNode* ASMJIT_NONNULL(node)) noexcept; + //! Inserts the given `node` after `ref`. + ASMJIT_API BaseNode* addAfter(BaseNode* ASMJIT_NONNULL(node), BaseNode* ASMJIT_NONNULL(ref)) noexcept; + //! Inserts the given `node` before `ref`. + ASMJIT_API BaseNode* addBefore(BaseNode* ASMJIT_NONNULL(node), BaseNode* ASMJIT_NONNULL(ref)) noexcept; + //! Removes the given `node`. + ASMJIT_API BaseNode* removeNode(BaseNode* ASMJIT_NONNULL(node)) noexcept; + //! Removes multiple nodes. 
+ ASMJIT_API void removeNodes(BaseNode* first, BaseNode* last) noexcept; + + //! Returns the cursor. + //! + //! When the Builder/Compiler is created it automatically creates a '.text' \ref SectionNode, which will be the + //! initial one. When instructions are added they are always added after the cursor and the cursor is changed + //! to be that newly added node. Use `setCursor()` to change where new nodes are inserted. + ASMJIT_INLINE_NODEBUG BaseNode* cursor() const noexcept { return _cursor; } + + //! Sets the current node to `node` and return the previous one. + ASMJIT_API BaseNode* setCursor(BaseNode* node) noexcept; + + //! Sets the current node without returning the previous node. + //! + //! Only use this function if you are concerned about performance and want this inlined (for example if you set + //! the cursor in a loop, etc...). + ASMJIT_INLINE_NODEBUG void _setCursor(BaseNode* node) noexcept { _cursor = node; } + + //! \} + + //! \name Section Management + //! \{ + + //! Returns a vector of SectionNode objects. + //! + //! \note If a section of some id is not associated with the Builder/Compiler it would be null, so always check + //! for nulls if you iterate over the vector. + ASMJIT_INLINE_NODEBUG const ZoneVector& sectionNodes() const noexcept { + return _sectionNodes; + } + + //! Tests whether the `SectionNode` of the given `sectionId` was registered. + ASMJIT_INLINE_NODEBUG bool hasRegisteredSectionNode(uint32_t sectionId) const noexcept { + return sectionId < _sectionNodes.size() && _sectionNodes[sectionId] != nullptr; + } + + //! Returns or creates a `SectionNode` that matches the given `sectionId`. + //! + //! \remarks This function will either get the existing `SectionNode` or create it in case it wasn't created before. + //! You can check whether a section has a registered `SectionNode` by using `BaseBuilder::hasRegisteredSectionNode()`. 
+ ASMJIT_API Error sectionNodeOf(SectionNode** ASMJIT_NONNULL(out), uint32_t sectionId); + + ASMJIT_API Error section(Section* ASMJIT_NONNULL(section)) override; + + //! Returns whether the section links of active section nodes are dirty. You can update these links by calling + //! `updateSectionLinks()` in such case. + ASMJIT_INLINE_NODEBUG bool hasDirtySectionLinks() const noexcept { return _dirtySectionLinks; } + + //! Updates links of all active section nodes. + ASMJIT_API void updateSectionLinks() noexcept; + + //! \} + + //! \name Label Management + //! \{ + + //! Returns a vector of \ref LabelNode nodes. + //! + //! \note If a label of some id is not associated with the Builder/Compiler it would be null, so always check for + //! nulls if you iterate over the vector. + ASMJIT_INLINE_NODEBUG const ZoneVector& labelNodes() const noexcept { return _labelNodes; } + + //! Tests whether the `LabelNode` of the given `labelId` was registered. + ASMJIT_INLINE_NODEBUG bool hasRegisteredLabelNode(uint32_t labelId) const noexcept { + return labelId < _labelNodes.size() && _labelNodes[labelId] != nullptr; + } + + //! \overload + ASMJIT_INLINE_NODEBUG bool hasRegisteredLabelNode(const Label& label) const noexcept { + return hasRegisteredLabelNode(label.id()); + } + + //! Gets or creates a \ref LabelNode that matches the given `labelId`. + //! + //! \remarks This function will either get the existing `LabelNode` or create it in case it wasn't created before. + //! You can check whether a label has a registered `LabelNode` by calling \ref BaseBuilder::hasRegisteredLabelNode(). + ASMJIT_API Error labelNodeOf(LabelNode** ASMJIT_NONNULL(out), uint32_t labelId); + + //! \overload + ASMJIT_INLINE_NODEBUG Error labelNodeOf(LabelNode** ASMJIT_NONNULL(out), const Label& label) { + return labelNodeOf(out, label.id()); + } + + //! Registers this \ref LabelNode (internal). + //! + //! 
This function is used internally to register a newly created `LabelNode` with this instance of Builder/Compiler. + //! Use \ref labelNodeOf() functions to get back \ref LabelNode from a label or its identifier. + ASMJIT_API Error registerLabelNode(LabelNode* ASMJIT_NONNULL(node)); + + ASMJIT_API Label newLabel() override; + ASMJIT_API Label newNamedLabel(const char* name, size_t nameSize = SIZE_MAX, LabelType type = LabelType::kGlobal, uint32_t parentId = Globals::kInvalidId) override; + ASMJIT_API Error bind(const Label& label) override; + + //! \} + + //! \name Passes + //! \{ + + //! Returns a vector of `Pass` instances that will be executed by `runPasses()`. + ASMJIT_INLINE_NODEBUG const ZoneVector& passes() const noexcept { return _passes; } + + //! Allocates and instantiates a new pass of type `T` and returns its instance. If the allocation fails `nullptr` is + //! returned. + //! + //! The template argument `T` must be a type that is extends \ref Pass. + //! + //! \remarks The pointer returned (if non-null) is owned by the Builder or Compiler. When the Builder/Compiler is + //! destroyed it destroys all passes it created so no manual memory management is required. + template + inline T* newPassT() noexcept { return _codeZone.newT(); } + + //! \overload + template + inline T* newPassT(Args&&... args) noexcept { return _codeZone.newT(std::forward(args)...); } + + template + inline Error addPassT() { return addPass(newPassT()); } + + template + inline Error addPassT(Args&&... args) { return addPass(newPassT(std::forward(args)...)); } + + //! Returns `Pass` by name. + //! + //! If the pass having the given `name` doesn't exist `nullptr` is returned. + ASMJIT_API Pass* passByName(const char* name) const noexcept; + //! Adds `pass` to the list of passes. + ASMJIT_API Error addPass(Pass* pass) noexcept; + //! Removes `pass` from the list of passes and delete it. + ASMJIT_API Error deletePass(Pass* pass) noexcept; + + //! Runs all passes in order. 
+ ASMJIT_API Error runPasses(); + + //! \} + + //! \name Emit + //! \{ + + ASMJIT_API Error _emit(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_* opExt) override; + + //! \} + + //! \name Align + //! \{ + + ASMJIT_API Error align(AlignMode alignMode, uint32_t alignment) override; + + //! \} + + //! \name Embed + //! \{ + + ASMJIT_API Error embed(const void* data, size_t dataSize) override; + ASMJIT_API Error embedDataArray(TypeId typeId, const void* data, size_t count, size_t repeat = 1) override; + ASMJIT_API Error embedConstPool(const Label& label, const ConstPool& pool) override; + + ASMJIT_API Error embedLabel(const Label& label, size_t dataSize = 0) override; + ASMJIT_API Error embedLabelDelta(const Label& label, const Label& base, size_t dataSize = 0) override; + + //! \} + + //! \name Comment + //! \{ + + ASMJIT_API Error comment(const char* data, size_t size = SIZE_MAX) override; + + //! \} + + //! \name Serialization + //! \{ + + //! Serializes everything the given emitter `dst`. + //! + //! Although not explicitly required the emitter will most probably be of Assembler type. The reason is that + //! there is no known use of serializing nodes held by Builder/Compiler into another Builder-like emitter. + ASMJIT_API Error serializeTo(BaseEmitter* dst); + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! Base node. +//! +//! Every node represents a building-block used by \ref BaseBuilder. It can be instruction, data, label, comment, +//! directive, or any other high-level representation that can be transformed to the building blocks mentioned. +//! Every class that inherits \ref BaseBuilder can define its own high-level nodes that can be later lowered to +//! basic nodes like instructions. +class BaseNode { +public: + ASMJIT_NONCOPYABLE(BaseNode) + + //! \name Members + //! 
\{ + + union { + struct { + //! Previous node. + BaseNode* _prev; + //! Next node. + BaseNode* _next; + }; + //! Links (an alternative view to previous and next nodes). + BaseNode* _links[2]; + }; + + //! Data shared between all types of nodes. + struct AnyData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Not used by BaseNode. + uint8_t _reserved0; + //! Not used by BaseNode. + uint8_t _reserved1; + }; + + //! Data used by \ref AlignNode. + struct AlignData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Align mode. + AlignMode _alignMode; + //! Not used by AlignNode. + uint8_t _reserved; + }; + + //! Data used by \ref InstNode. + struct InstData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Instruction operands count (used). + uint8_t _opCount; + //! Instruction operands capacity (allocated). + uint8_t _opCapacity; + }; + + //! Data used by \ref EmbedDataNode. + struct EmbedData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Type id. + TypeId _typeId; + //! Size of `_typeId`. + uint8_t _typeSize; + }; + + //! Data used by \ref SentinelNode. + struct SentinelData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Sentinel type. + SentinelType _sentinelType; + //! Not used by BaseNode. + uint8_t _reserved1; + }; + + //! Data that can have different meaning depending on \ref NodeType. + union { + //! Data useful by any node type. + AnyData _any; + //! Data specific to \ref AlignNode. + AlignData _alignData; + //! Data specific to \ref InstNode. + InstData _inst; + //! Data specific to \ref EmbedDataNode. + EmbedData _embed; + //! Data specific to \ref SentinelNode. + SentinelData _sentinel; + }; + + //! Node position in code (should be unique). + uint32_t _position; + + //! Value reserved for AsmJit users never touched by AsmJit itself. + union { + //! 
User data as 64-bit integer. + uint64_t _userDataU64; + //! User data as pointer. + void* _userDataPtr; + }; + + //! Data used exclusively by the current `Pass`. + void* _passData; + + //! Inline comment/annotation or nullptr if not used. + const char* _inlineComment; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `BaseNode` - always use `BaseBuilder` to allocate nodes. + ASMJIT_INLINE_NODEBUG BaseNode(BaseBuilder* cb, NodeType nodeType, NodeFlags nodeFlags = NodeFlags::kNone) noexcept { + _prev = nullptr; + _next = nullptr; + _any._nodeType = nodeType; + _any._nodeFlags = nodeFlags | cb->_nodeFlags; + _any._reserved0 = 0; + _any._reserved1 = 0; + _position = 0; + _userDataU64 = 0; + _passData = nullptr; + _inlineComment = nullptr; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Casts this node to `T*`. + template + ASMJIT_INLINE_NODEBUG T* as() noexcept { return static_cast(this); } + //! Casts this node to `const T*`. + template + ASMJIT_INLINE_NODEBUG const T* as() const noexcept { return static_cast(this); } + + //! Returns previous node or `nullptr` if this node is either first or not + //! part of Builder/Compiler node-list. + ASMJIT_INLINE_NODEBUG BaseNode* prev() const noexcept { return _prev; } + //! Returns next node or `nullptr` if this node is either last or not part + //! of Builder/Compiler node-list. + ASMJIT_INLINE_NODEBUG BaseNode* next() const noexcept { return _next; } + + //! Returns the type of the node, see `NodeType`. + ASMJIT_INLINE_NODEBUG NodeType type() const noexcept { return _any._nodeType; } + + //! Sets the type of the node, see `NodeType` (internal). + //! + //! \remarks You should never set a type of a node to anything else than the initial value. This function is only + //! provided for users that use custom nodes and need to change the type either during construction or later. + ASMJIT_INLINE_NODEBUG void setType(NodeType type) noexcept { _any._nodeType = type; } + + //! 
Tests whether this node is either `InstNode` or extends it. + ASMJIT_INLINE_NODEBUG bool isInst() const noexcept { return hasFlag(NodeFlags::kActsAsInst); } + //! Tests whether this node is `SectionNode`. + ASMJIT_INLINE_NODEBUG bool isSection() const noexcept { return type() == NodeType::kSection; } + //! Tests whether this node is either `LabelNode` or extends it. + ASMJIT_INLINE_NODEBUG bool isLabel() const noexcept { return hasFlag(NodeFlags::kActsAsLabel); } + //! Tests whether this node is `AlignNode`. + ASMJIT_INLINE_NODEBUG bool isAlign() const noexcept { return type() == NodeType::kAlign; } + //! Tests whether this node is `EmbedDataNode`. + ASMJIT_INLINE_NODEBUG bool isEmbedData() const noexcept { return type() == NodeType::kEmbedData; } + //! Tests whether this node is `EmbedLabelNode`. + ASMJIT_INLINE_NODEBUG bool isEmbedLabel() const noexcept { return type() == NodeType::kEmbedLabel; } + //! Tests whether this node is `EmbedLabelDeltaNode`. + ASMJIT_INLINE_NODEBUG bool isEmbedLabelDelta() const noexcept { return type() == NodeType::kEmbedLabelDelta; } + //! Tests whether this node is `ConstPoolNode`. + ASMJIT_INLINE_NODEBUG bool isConstPool() const noexcept { return type() == NodeType::kConstPool; } + //! Tests whether this node is `CommentNode`. + ASMJIT_INLINE_NODEBUG bool isComment() const noexcept { return type() == NodeType::kComment; } + //! Tests whether this node is `SentinelNode`. + ASMJIT_INLINE_NODEBUG bool isSentinel() const noexcept { return type() == NodeType::kSentinel; } + + //! Tests whether this node is `FuncNode`. + ASMJIT_INLINE_NODEBUG bool isFunc() const noexcept { return type() == NodeType::kFunc; } + //! Tests whether this node is `FuncRetNode`. + ASMJIT_INLINE_NODEBUG bool isFuncRet() const noexcept { return type() == NodeType::kFuncRet; } + //! Tests whether this node is `InvokeNode`. + ASMJIT_INLINE_NODEBUG bool isInvoke() const noexcept { return type() == NodeType::kInvoke; } + + //! Returns the node flags. 
+ ASMJIT_INLINE_NODEBUG NodeFlags flags() const noexcept { return _any._nodeFlags; } + //! Tests whether the node has the given `flag` set. + ASMJIT_INLINE_NODEBUG bool hasFlag(NodeFlags flag) const noexcept { return Support::test(_any._nodeFlags, flag); } + //! Replaces node flags with `flags`. + ASMJIT_INLINE_NODEBUG void setFlags(NodeFlags flags) noexcept { _any._nodeFlags = flags; } + //! Adds the given `flags` to node flags. + ASMJIT_INLINE_NODEBUG void addFlags(NodeFlags flags) noexcept { _any._nodeFlags |= flags; } + //! Clears the given `flags` from node flags. + ASMJIT_INLINE_NODEBUG void clearFlags(NodeFlags flags) noexcept { _any._nodeFlags &= ~flags; } + + //! Tests whether the node is code that can be executed. + ASMJIT_INLINE_NODEBUG bool isCode() const noexcept { return hasFlag(NodeFlags::kIsCode); } + //! Tests whether the node is data that cannot be executed. + ASMJIT_INLINE_NODEBUG bool isData() const noexcept { return hasFlag(NodeFlags::kIsData); } + //! Tests whether the node is informative only (is never encoded like comment, etc...). + ASMJIT_INLINE_NODEBUG bool isInformative() const noexcept { return hasFlag(NodeFlags::kIsInformative); } + //! Tests whether the node is removable if it's in an unreachable code block. + ASMJIT_INLINE_NODEBUG bool isRemovable() const noexcept { return hasFlag(NodeFlags::kIsRemovable); } + //! Tests whether the node has no effect when executed (label, .align, nop, ...). + ASMJIT_INLINE_NODEBUG bool hasNoEffect() const noexcept { return hasFlag(NodeFlags::kHasNoEffect); } + //! Tests whether the node is part of the code. + ASMJIT_INLINE_NODEBUG bool isActive() const noexcept { return hasFlag(NodeFlags::kIsActive); } + + //! Tests whether the node has a position assigned. + //! + //! \remarks Returns `true` if node position is non-zero. + ASMJIT_INLINE_NODEBUG bool hasPosition() const noexcept { return _position != 0; } + //! Returns node position. 
+ ASMJIT_INLINE_NODEBUG uint32_t position() const noexcept { return _position; } + //! Sets node position. + //! + //! Node position is a 32-bit unsigned integer that is used by Compiler to track where the node is relatively to + //! the start of the function. It doesn't describe a byte position in a binary, instead it's just a pseudo position + //! used by liveness analysis and other tools around Compiler. + //! + //! If you don't use Compiler then you may use `position()` and `setPosition()` freely for your own purposes if + //! the 32-bit value limit is okay for you. + ASMJIT_INLINE_NODEBUG void setPosition(uint32_t position) noexcept { _position = position; } + + //! Returns user data casted to `T*`. + //! + //! User data is dedicated to be used only by AsmJit users and not touched by the library. The data is of a pointer + //! size so you can either store a pointer or `int64_t` value through `setUserDataAsPtr()`, `setUserDataAsInt64()` + //! and `setUserDataAsUInt64()`. + template + ASMJIT_INLINE_NODEBUG T* userDataAsPtr() const noexcept { return static_cast(_userDataPtr); } + //! Returns user data casted to `int64_t`. + ASMJIT_INLINE_NODEBUG int64_t userDataAsInt64() const noexcept { return int64_t(_userDataU64); } + //! Returns user data casted to `uint64_t`. + ASMJIT_INLINE_NODEBUG uint64_t userDataAsUInt64() const noexcept { return _userDataU64; } + + //! Sets user data to `data`. + template + ASMJIT_INLINE_NODEBUG void setUserDataAsPtr(T* data) noexcept { _userDataPtr = static_cast(data); } + //! Sets used data to the given 64-bit signed `value`. + ASMJIT_INLINE_NODEBUG void setUserDataAsInt64(int64_t value) noexcept { _userDataU64 = uint64_t(value); } + //! Sets used data to the given 64-bit unsigned `value`. + ASMJIT_INLINE_NODEBUG void setUserDataAsUInt64(uint64_t value) noexcept { _userDataU64 = value; } + + //! Resets user data to zero / nullptr. + ASMJIT_INLINE_NODEBUG void resetUserData() noexcept { _userDataU64 = 0; } + + //! 
Tests whether the node has an associated pass data. + ASMJIT_INLINE_NODEBUG bool hasPassData() const noexcept { return _passData != nullptr; } + //! Returns the node pass data - data used during processing & transformations. + template + ASMJIT_INLINE_NODEBUG T* passData() const noexcept { return (T*)_passData; } + //! Sets the node pass data to `data`. + template + ASMJIT_INLINE_NODEBUG void setPassData(T* data) noexcept { _passData = (void*)data; } + //! Resets the node pass data to nullptr. + ASMJIT_INLINE_NODEBUG void resetPassData() noexcept { _passData = nullptr; } + + //! Tests whether the node has an inline comment/annotation. + ASMJIT_INLINE_NODEBUG bool hasInlineComment() const noexcept { return _inlineComment != nullptr; } + //! Returns an inline comment/annotation string. + ASMJIT_INLINE_NODEBUG const char* inlineComment() const noexcept { return _inlineComment; } + //! Sets an inline comment/annotation string to `s`. + ASMJIT_INLINE_NODEBUG void setInlineComment(const char* s) noexcept { _inlineComment = s; } + //! Resets an inline comment/annotation string to nullptr. + ASMJIT_INLINE_NODEBUG void resetInlineComment() noexcept { _inlineComment = nullptr; } + + //! \} +}; + +//! Instruction node. +//! +//! Wraps an instruction with its options and operands. +class InstNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(InstNode) + + //! \name Constants + //! \{ + + //! The number of embedded operands for a default \ref InstNode instance that are always allocated as a part of + //! the instruction itself. Minimum embedded operands is 4, but in 32-bit more pointers are smaller and we can + //! embed 5. The rest (up to 6 operands) is considered extended. + //! + //! The number of operands InstNode holds is decided when \ref InstNode is created. + static constexpr uint32_t kBaseOpCapacity = uint32_t((128 - sizeof(BaseNode) - sizeof(BaseInst)) / sizeof(Operand_)); + + //! Count of maximum number of operands \ref InstNode can hold. 
+ static constexpr uint32_t kFullOpCapacity = Globals::kMaxOpCount; + + //! \} + + //! \name Members + //! \{ + + //! Base instruction data. + BaseInst _baseInst; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `InstNode` instance. + ASMJIT_INLINE_NODEBUG InstNode(BaseBuilder* cb, InstId instId, InstOptions options, uint32_t opCount, uint32_t opCapacity = kBaseOpCapacity) noexcept + : BaseNode(cb, NodeType::kInst, NodeFlags::kIsCode | NodeFlags::kIsRemovable | NodeFlags::kActsAsInst), + _baseInst(instId, options) { + _inst._opCapacity = uint8_t(opCapacity); + _inst._opCount = uint8_t(opCount); + } + + //! \cond INTERNAL + //! Reset all built-in operands, including `extraReg`. + ASMJIT_INLINE_NODEBUG void _resetOps() noexcept { + _baseInst.resetExtraReg(); + resetOpRange(0, opCapacity()); + } + //! \endcond + + //! \} + + //! \name Instruction Object + //! \{ + + ASMJIT_INLINE_NODEBUG BaseInst& baseInst() noexcept { return _baseInst; } + ASMJIT_INLINE_NODEBUG const BaseInst& baseInst() const noexcept { return _baseInst; } + + //! \} + + //! \name Instruction Id + //! \{ + + //! Returns the instruction id, see `BaseInst::Id`. + ASMJIT_INLINE_NODEBUG InstId id() const noexcept { return _baseInst.id(); } + //! Returns the instruction real id, see `BaseInst::Id`. + ASMJIT_INLINE_NODEBUG InstId realId() const noexcept { return _baseInst.realId(); } + + //! Sets the instruction id to `id`, see `BaseInst::Id`. + ASMJIT_INLINE_NODEBUG void setId(InstId id) noexcept { _baseInst.setId(id); } + + //! \} + + //! \name Instruction Options + //! \{ + + //! Returns instruction options, see \ref InstOptions for more details. + ASMJIT_INLINE_NODEBUG InstOptions options() const noexcept { return _baseInst.options(); } + //! Tests whether instruction has the given \option` set/enabled. + ASMJIT_INLINE_NODEBUG bool hasOption(InstOptions option) const noexcept { return _baseInst.hasOption(option); } + //! 
Sets instruction `options` to the provided value, resetting all others. + ASMJIT_INLINE_NODEBUG void setOptions(InstOptions options) noexcept { _baseInst.setOptions(options); } + //! Adds instruction `options` to the instruction. + ASMJIT_INLINE_NODEBUG void addOptions(InstOptions options) noexcept { _baseInst.addOptions(options); } + //! Clears instruction `options` of the instruction (disables the given options). + ASMJIT_INLINE_NODEBUG void clearOptions(InstOptions options) noexcept { _baseInst.clearOptions(options); } + //! Resets instruction options to none - disabling all instruction options. + ASMJIT_INLINE_NODEBUG void resetOptions() noexcept { _baseInst.resetOptions(); } + + //! \} + + //! \name Extra Register + //! \{ + + //! Tests whether the node has an extra register operand. + ASMJIT_INLINE_NODEBUG bool hasExtraReg() const noexcept { return _baseInst.hasExtraReg(); } + //! Returns extra register operand. + ASMJIT_INLINE_NODEBUG RegOnly& extraReg() noexcept { return _baseInst.extraReg(); } + //! \overload + ASMJIT_INLINE_NODEBUG const RegOnly& extraReg() const noexcept { return _baseInst.extraReg(); } + //! Sets extra register operand to `reg`. + ASMJIT_INLINE_NODEBUG void setExtraReg(const BaseReg& reg) noexcept { _baseInst.setExtraReg(reg); } + //! Sets extra register operand to `reg`. + ASMJIT_INLINE_NODEBUG void setExtraReg(const RegOnly& reg) noexcept { _baseInst.setExtraReg(reg); } + //! Resets extra register operand. + ASMJIT_INLINE_NODEBUG void resetExtraReg() noexcept { _baseInst.resetExtraReg(); } + + //! \} + + //! \name Instruction Operands + //! \{ + + //! Returns operand count. + ASMJIT_INLINE_NODEBUG uint32_t opCount() const noexcept { return _inst._opCount; } + //! Returns operand capacity. + ASMJIT_INLINE_NODEBUG uint32_t opCapacity() const noexcept { return _inst._opCapacity; } + + //! Sets operand count. + ASMJIT_INLINE_NODEBUG void setOpCount(uint32_t opCount) noexcept { _inst._opCount = uint8_t(opCount); } + + //! 
Returns operands array. + ASMJIT_INLINE_NODEBUG Operand* operands() noexcept { + return reinterpret_cast(reinterpret_cast(this) + sizeof(InstNode)); + } + + //! Returns operands array (const). + ASMJIT_INLINE_NODEBUG const Operand* operands() const noexcept { + return reinterpret_cast(reinterpret_cast(this) + sizeof(InstNode)); + } + + //! Returns operand at the given `index`. + inline Operand& op(uint32_t index) noexcept { + ASMJIT_ASSERT(index < opCapacity()); + + Operand* ops = operands(); + return ops[index].as(); + } + + //! Returns operand at the given `index` (const). + inline const Operand& op(uint32_t index) const noexcept { + ASMJIT_ASSERT(index < opCapacity()); + + const Operand* ops = operands(); + return ops[index].as(); + } + + //! Sets operand at the given `index` to `op`. + inline void setOp(uint32_t index, const Operand_& op) noexcept { + ASMJIT_ASSERT(index < opCapacity()); + + Operand* ops = operands(); + ops[index].copyFrom(op); + } + + //! Resets operand at the given `index` to none. + inline void resetOp(uint32_t index) noexcept { + ASMJIT_ASSERT(index < opCapacity()); + + Operand* ops = operands(); + ops[index].reset(); + } + + //! Resets operands at `[start, end)` range. + inline void resetOpRange(uint32_t start, uint32_t end) noexcept { + Operand* ops = operands(); + for (uint32_t i = start; i < end; i++) + ops[i].reset(); + } + + //! \} + + //! \name Utilities + //! \{ + + //! Tests whether the given operand type `opType` is used by the instruction. + inline bool hasOpType(OperandType opType) const noexcept { + const Operand* ops = operands(); + for (uint32_t i = 0, count = opCount(); i < count; i++) + if (ops[i].opType() == opType) + return true; + return false; + } + + //! Tests whether the instruction uses at least one register operand. + inline bool hasRegOp() const noexcept { return hasOpType(OperandType::kReg); } + //! Tests whether the instruction uses at least one memory operand. 
+ inline bool hasMemOp() const noexcept { return hasOpType(OperandType::kMem); } + //! Tests whether the instruction uses at least one immediate operand. + inline bool hasImmOp() const noexcept { return hasOpType(OperandType::kImm); } + //! Tests whether the instruction uses at least one label operand. + inline bool hasLabelOp() const noexcept { return hasOpType(OperandType::kLabel); } + + //! Returns the index of the given operand type `opType`. + //! + //! \note If the operand type wa found, the value returned represents its index in \ref operands() + //! array, otherwise \ref Globals::kNotFound is returned to signalize that the operand was not found. + inline uint32_t indexOfOpType(OperandType opType) const noexcept { + uint32_t i = 0; + uint32_t count = opCount(); + const Operand* ops = operands(); + + while (i < count) { + if (ops[i].opType() == opType) + return i; + i++; + } + + return Globals::kNotFound; + } + + //! A shortcut that calls `indexOfOpType(OperandType::kMem)`. + inline uint32_t indexOfMemOp() const noexcept { return indexOfOpType(OperandType::kMem); } + //! A shortcut that calls `indexOfOpType(OperandType::kImm)`. + inline uint32_t indexOfImmOp() const noexcept { return indexOfOpType(OperandType::kImm); } + //! A shortcut that calls `indexOfOpType(OperandType::kLabel)`. + inline uint32_t indexOfLabelOp() const noexcept { return indexOfOpType(OperandType::kLabel); } + + //! \} + + //! \name Rewriting + //! \{ + + //! \cond INTERNAL + + //! Returns uint32_t[] view that represents BaseInst::RegOnly and instruction operands. + ASMJIT_INLINE_NODEBUG uint32_t* _getRewriteArray() noexcept { return &_baseInst._extraReg._id; } + //! \overload + ASMJIT_INLINE_NODEBUG const uint32_t* _getRewriteArray() const noexcept { return &_baseInst._extraReg._id; } + + //! Maximum value of rewrite id - 6 operands each having 4 slots is 24, one RegOnly having 2 slots => 26. + static constexpr uint32_t kMaxRewriteId = 26 - 1; + + //! 
Returns a rewrite index of the given pointer to `id`. + //! + //! This function returns a value that can be then passed to `\ref rewriteIdAtIndex() function. It can address + //! any id from any operand that is used by the instruction in addition to \ref BaseInst::regOnly field, which + //! can also be used by the register allocator. + inline uint32_t getRewriteIndex(const uint32_t* id) const noexcept { + const uint32_t* array = _getRewriteArray(); + ASMJIT_ASSERT(array <= id); + + size_t index = (size_t)(id - array); + ASMJIT_ASSERT(index <= kMaxRewriteId); + + return uint32_t(index); + } + + //! Rewrites the given `index` to the provided identifier `id`. + //! + //! \note This is an internal function that is used by a \ref BaseCompiler implementation to rewrite virtual + //! registers to physical registers. The rewriter in this case sees all operands as array of uint32 values + //! and the given `index` describes a position in this array. For example a single \ref Operand would be + //! decomposed to 4 uint32_t values, where the first at index 0 would be operand signature, next would be + //! base id, etc... This is a comfortable way of patching operands without having to check for their types. + inline void rewriteIdAtIndex(uint32_t index, uint32_t id) noexcept { + ASMJIT_ASSERT(index <= kMaxRewriteId); + + uint32_t* array = _getRewriteArray(); + array[index] = id; + } + //! \endcond + + //! \} + + //! \name Static Functions + //! \{ + + //! \cond INTERNAL + + //! Returns the capacity required for the given operands count `opCount`. + //! + //! There are only two capacities used - \ref kBaseOpCapacity and \ref kFullOpCapacity, so this function + //! is used to decide between these two. The general rule is that instructions that can be represented with + //! \ref kBaseOpCapacity would use this value, and all others would take \ref kFullOpCapacity. 
+ static ASMJIT_INLINE_NODEBUG constexpr uint32_t capacityOfOpCount(uint32_t opCount) noexcept { + return opCount <= kBaseOpCapacity ? kBaseOpCapacity : kFullOpCapacity; + } + + //! Calculates the size of \ref InstNode required to hold at most `opCapacity` operands. + //! + //! This function is used internally to allocate \ref InstNode. + static ASMJIT_INLINE_NODEBUG constexpr size_t nodeSizeOfOpCapacity(uint32_t opCapacity) noexcept { + return sizeof(InstNode) + opCapacity * sizeof(Operand); + } + //! \endcond + + //! \} +}; + +//! Instruction node with embedded operands following \ref InstNode layout. +//! +//! \note This is used to make tools such as static analysis and compilers happy about the layout. There were two +//! instruction nodes in the past, having the second extend the operand array of the first, but that has caused +//! undefined behavior and made recent tools unhappy about that. +template +class InstNodeWithOperands : public InstNode { +public: + Operand_ _operands[kN]; + + //! Creates a new `InstNodeWithOperands` instance. + ASMJIT_INLINE_NODEBUG InstNodeWithOperands(BaseBuilder* cb, InstId instId, InstOptions options, uint32_t opCount) noexcept + : InstNode(cb, instId, options, opCount, kN) {} +}; + +//! Section node. +class SectionNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(SectionNode) + + //! \name Members + //! \{ + + //! Section id. + uint32_t _id; + + //! Next section node that follows this section. + //! + //! This link is only valid when the section is active (is part of the code) and when `Builder::hasDirtySectionLinks()` + //! returns `false`. If you intend to use this field you should always call `Builder::updateSectionLinks()` before you + //! do so. + SectionNode* _nextSection; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `SectionNode` instance. 
+ ASMJIT_INLINE_NODEBUG SectionNode(BaseBuilder* cb, uint32_t sectionId = 0) noexcept + : BaseNode(cb, NodeType::kSection, NodeFlags::kHasNoEffect), + _id(sectionId), + _nextSection(nullptr) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the section id. + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _id; } + + //! \} +}; + +//! Label node. +class LabelNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(LabelNode) + + //! \name Members + //! \{ + + //! Label identifier. + uint32_t _labelId; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `LabelNode` instance. + ASMJIT_INLINE_NODEBUG LabelNode(BaseBuilder* cb, uint32_t labelId = 0) noexcept + : BaseNode(cb, NodeType::kLabel, NodeFlags::kHasNoEffect | NodeFlags::kActsAsLabel), + _labelId(labelId) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns \ref Label representation of the \ref LabelNode. + ASMJIT_INLINE_NODEBUG Label label() const noexcept { return Label(_labelId); } + //! Returns the id of the label. + ASMJIT_INLINE_NODEBUG uint32_t labelId() const noexcept { return _labelId; } + + //! \} +}; + +//! Align directive (BaseBuilder). +//! +//! Wraps `.align` directive. +class AlignNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(AlignNode) + + //! \name Members + //! \{ + + //! Alignment (in bytes). + uint32_t _alignment; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `AlignNode` instance. + ASMJIT_INLINE_NODEBUG AlignNode(BaseBuilder* cb, AlignMode alignMode, uint32_t alignment) noexcept + : BaseNode(cb, NodeType::kAlign, NodeFlags::kIsCode | NodeFlags::kHasNoEffect) { + + _alignData._alignMode = alignMode; + _alignment = alignment; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns align mode. + ASMJIT_INLINE_NODEBUG AlignMode alignMode() const noexcept { return _alignData._alignMode; } + //! Sets align mode to `alignMode`. 
+ ASMJIT_INLINE_NODEBUG void setAlignMode(AlignMode alignMode) noexcept { _alignData._alignMode = alignMode; } + + //! Returns align offset in bytes. + ASMJIT_INLINE_NODEBUG uint32_t alignment() const noexcept { return _alignment; } + //! Sets align offset in bytes to `offset`. + ASMJIT_INLINE_NODEBUG void setAlignment(uint32_t alignment) noexcept { _alignment = alignment; } + + //! \} +}; + +//! Embed data node. +//! +//! Wraps `.data` directive. The node contains data that will be placed at the node's position in the assembler +//! stream. The data is considered to be RAW; no analysis nor byte-order conversion is performed on RAW data. +class EmbedDataNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(EmbedDataNode) + + //! \cond INTERNAL + enum : uint32_t { + kInlineBufferSize = 128 - (sizeof(BaseNode) + sizeof(size_t) * 2) + }; + //! \endcond + + //! \name Members + //! \{ + + size_t _itemCount; + size_t _repeatCount; + + union { + uint8_t* _externalData; + uint8_t _inlineData[kInlineBufferSize]; + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `EmbedDataNode` instance. + ASMJIT_INLINE_NODEBUG EmbedDataNode(BaseBuilder* cb) noexcept + : BaseNode(cb, NodeType::kEmbedData, NodeFlags::kIsData), + _itemCount(0), + _repeatCount(0) { + _embed._typeId = TypeId::kUInt8; + _embed._typeSize = uint8_t(1); + memset(_inlineData, 0, kInlineBufferSize); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns data type as \ref TypeId. + ASMJIT_INLINE_NODEBUG TypeId typeId() const noexcept { return _embed._typeId; } + //! Returns the size of a single data element. + ASMJIT_INLINE_NODEBUG uint32_t typeSize() const noexcept { return _embed._typeSize; } + + //! Returns a pointer to the data casted to `uint8_t`. + ASMJIT_INLINE_NODEBUG uint8_t* data() const noexcept { + return dataSize() <= kInlineBufferSize ? const_cast(_inlineData) : _externalData; + } + + //! Returns a pointer to the data casted to `T`. 
+ template + ASMJIT_INLINE_NODEBUG T* dataAs() const noexcept { return reinterpret_cast(data()); } + + //! Returns the number of (typed) items in the array. + ASMJIT_INLINE_NODEBUG size_t itemCount() const noexcept { return _itemCount; } + + //! Returns how many times the data is repeated (default 1). + //! + //! Repeated data is useful when defining constants for SIMD, for example. + ASMJIT_INLINE_NODEBUG size_t repeatCount() const noexcept { return _repeatCount; } + + //! Returns the size of the data, not considering the number of times it repeats. + //! + //! \note The returned value is the same as `typeSize() * itemCount()`. + ASMJIT_INLINE_NODEBUG size_t dataSize() const noexcept { return typeSize() * _itemCount; } + + //! \} +}; + +//! Label data node. +class EmbedLabelNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(EmbedLabelNode) + + //! \name Members + //! \{ + + uint32_t _labelId; + uint32_t _dataSize; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `EmbedLabelNode` instance. + ASMJIT_INLINE_NODEBUG EmbedLabelNode(BaseBuilder* cb, uint32_t labelId = 0, uint32_t dataSize = 0) noexcept + : BaseNode(cb, NodeType::kEmbedLabel, NodeFlags::kIsData), + _labelId(labelId), + _dataSize(dataSize) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the label to embed as \ref Label operand. + ASMJIT_INLINE_NODEBUG Label label() const noexcept { return Label(_labelId); } + //! Returns the id of the label. + ASMJIT_INLINE_NODEBUG uint32_t labelId() const noexcept { return _labelId; } + + //! Sets the label id from `label` operand. + ASMJIT_INLINE_NODEBUG void setLabel(const Label& label) noexcept { setLabelId(label.id()); } + //! Sets the label id (use with caution, improper use can break a lot of things). + ASMJIT_INLINE_NODEBUG void setLabelId(uint32_t labelId) noexcept { _labelId = labelId; } + + //! Returns the data size. + ASMJIT_INLINE_NODEBUG uint32_t dataSize() const noexcept { return _dataSize; } + //! 
Sets the data size. + ASMJIT_INLINE_NODEBUG void setDataSize(uint32_t dataSize) noexcept { _dataSize = dataSize; } + + //! \} +}; + +//! Label data node. +class EmbedLabelDeltaNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(EmbedLabelDeltaNode) + + //! \name Members + //! \{ + + uint32_t _labelId; + uint32_t _baseLabelId; + uint32_t _dataSize; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `EmbedLabelDeltaNode` instance. + ASMJIT_INLINE_NODEBUG EmbedLabelDeltaNode(BaseBuilder* cb, uint32_t labelId = 0, uint32_t baseLabelId = 0, uint32_t dataSize = 0) noexcept + : BaseNode(cb, NodeType::kEmbedLabelDelta, NodeFlags::kIsData), + _labelId(labelId), + _baseLabelId(baseLabelId), + _dataSize(dataSize) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the label as `Label` operand. + ASMJIT_INLINE_NODEBUG Label label() const noexcept { return Label(_labelId); } + //! Returns the id of the label. + ASMJIT_INLINE_NODEBUG uint32_t labelId() const noexcept { return _labelId; } + + //! Sets the label id from `label` operand. + ASMJIT_INLINE_NODEBUG void setLabel(const Label& label) noexcept { setLabelId(label.id()); } + //! Sets the label id. + ASMJIT_INLINE_NODEBUG void setLabelId(uint32_t labelId) noexcept { _labelId = labelId; } + + //! Returns the base label as `Label` operand. + ASMJIT_INLINE_NODEBUG Label baseLabel() const noexcept { return Label(_baseLabelId); } + //! Returns the id of the base label. + ASMJIT_INLINE_NODEBUG uint32_t baseLabelId() const noexcept { return _baseLabelId; } + + //! Sets the base label id from `label` operand. + ASMJIT_INLINE_NODEBUG void setBaseLabel(const Label& baseLabel) noexcept { setBaseLabelId(baseLabel.id()); } + //! Sets the base label id. + ASMJIT_INLINE_NODEBUG void setBaseLabelId(uint32_t baseLabelId) noexcept { _baseLabelId = baseLabelId; } + + //! Returns the size of the embedded label address. 
+ ASMJIT_INLINE_NODEBUG uint32_t dataSize() const noexcept { return _dataSize; } + //! Sets the size of the embedded label address. + ASMJIT_INLINE_NODEBUG void setDataSize(uint32_t dataSize) noexcept { _dataSize = dataSize; } + + //! \} +}; + +//! A node that wraps `ConstPool`. +class ConstPoolNode : public LabelNode { +public: + ASMJIT_NONCOPYABLE(ConstPoolNode) + + //! \name Members + //! \{ + + ConstPool _constPool; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `ConstPoolNode` instance. + ASMJIT_INLINE_NODEBUG ConstPoolNode(BaseBuilder* cb, uint32_t id = 0) noexcept + : LabelNode(cb, id), + _constPool(&cb->_codeZone) { + + setType(NodeType::kConstPool); + addFlags(NodeFlags::kIsData); + clearFlags(NodeFlags::kIsCode | NodeFlags::kHasNoEffect); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the constant-pool is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _constPool.empty(); } + //! Returns the size of the constant-pool in bytes. + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _constPool.size(); } + //! Returns minimum alignment. + ASMJIT_INLINE_NODEBUG size_t alignment() const noexcept { return _constPool.alignment(); } + + //! Returns the wrapped `ConstPool` instance. + ASMJIT_INLINE_NODEBUG ConstPool& constPool() noexcept { return _constPool; } + //! Returns the wrapped `ConstPool` instance (const). + ASMJIT_INLINE_NODEBUG const ConstPool& constPool() const noexcept { return _constPool; } + + //! \} + + //! \name Utilities + //! \{ + + //! See `ConstPool::add()`. + ASMJIT_INLINE_NODEBUG Error add(const void* data, size_t size, size_t& dstOffset) noexcept { + return _constPool.add(data, size, dstOffset); + } + + //! \} +}; + +//! Comment node. +class CommentNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(CommentNode) + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `CommentNode` instance. 
+ ASMJIT_INLINE_NODEBUG CommentNode(BaseBuilder* cb, const char* comment) noexcept + : BaseNode(cb, NodeType::kComment, NodeFlags::kIsInformative | NodeFlags::kHasNoEffect | NodeFlags::kIsRemovable) { + _inlineComment = comment; + } + + //! \} +}; + +//! Sentinel node. +//! +//! Sentinel is a marker that is completely ignored by the code builder. It's used to remember a position in a code +//! as it never gets removed by any pass. +class SentinelNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(SentinelNode) + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `SentinelNode` instance. + ASMJIT_INLINE_NODEBUG SentinelNode(BaseBuilder* cb, SentinelType sentinelType = SentinelType::kUnknown) noexcept + : BaseNode(cb, NodeType::kSentinel, NodeFlags::kIsInformative | NodeFlags::kHasNoEffect) { + + _sentinel._sentinelType = sentinelType; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the type of the sentinel. + ASMJIT_INLINE_NODEBUG SentinelType sentinelType() const noexcept { + return _sentinel._sentinelType; + } + + //! Sets the type of the sentinel. + ASMJIT_INLINE_NODEBUG void setSentinelType(SentinelType type) noexcept { + _sentinel._sentinelType = type; + } + + //! \} +}; + +//! Pass can be used to implement code transformations, analysis, and lowering. +class ASMJIT_VIRTAPI Pass { +public: + ASMJIT_BASE_CLASS(Pass) + ASMJIT_NONCOPYABLE(Pass) + + //! \name Members + //! \{ + + //! BaseBuilder this pass is assigned to. + BaseBuilder* _cb = nullptr; + //! Name of the pass. + const char* _name = nullptr; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API Pass(const char* name) noexcept; + ASMJIT_API virtual ~Pass() noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns \ref BaseBuilder associated with the pass. + ASMJIT_INLINE_NODEBUG const BaseBuilder* cb() const noexcept { return _cb; } + //! Returns the name of the pass. 
+ ASMJIT_INLINE_NODEBUG const char* name() const noexcept { return _name; } + + //! \} + + //! \name Pass Interface + //! \{ + + //! Processes the code stored in Builder or Compiler. + //! + //! This is the only function that is called by the `BaseBuilder` to process the code. It passes `zone`, + //! which will be reset after the `run()` finishes. + ASMJIT_API virtual Error run(Zone* zone, Logger* logger); + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // !ASMJIT_NO_BUILDER +#endif // ASMJIT_CORE_BUILDER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/codebuffer.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/codebuffer.h new file mode 100644 index 0000000000000000000000000000000000000000..d4b7cebbcef40f0df8dd8705348526f8cb5b7ca7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/codebuffer.h @@ -0,0 +1,113 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_CODEBUFFER_H_INCLUDED +#define ASMJIT_CORE_CODEBUFFER_H_INCLUDED + +#include "../core/globals.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Flags used by \ref CodeBuffer. +enum class CodeBufferFlags : uint32_t { + //! No flags. + kNone = 0, + //! Buffer is external (not allocated by asmjit). + kIsExternal = 0x00000001u, + //! Buffer is fixed (cannot be reallocated). + kIsFixed = 0x00000002u +}; +ASMJIT_DEFINE_ENUM_FLAGS(CodeBufferFlags) + +//! Code or data buffer. +struct CodeBuffer { + //! \name Members + //! \{ + + //! The content of the buffer (data). + uint8_t* _data; + //! Number of bytes of `data` used. + size_t _size; + //! Buffer capacity (in bytes). + size_t _capacity; + //! Buffer flags. + CodeBufferFlags _flags; + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Returns a reference to the byte at the given `index`. 
+ inline uint8_t& operator[](size_t index) noexcept { + ASMJIT_ASSERT(index < _size); + return _data[index]; + } + //! \overload + inline const uint8_t& operator[](size_t index) const noexcept { + ASMJIT_ASSERT(index < _size); + return _data[index]; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns code buffer flags. + ASMJIT_INLINE_NODEBUG CodeBufferFlags flags() const noexcept { return _flags; } + //! Tests whether the code buffer has the given `flag` set. + ASMJIT_INLINE_NODEBUG bool hasFlag(CodeBufferFlags flag) const noexcept { return Support::test(_flags, flag); } + + //! Tests whether this code buffer has a fixed size. + //! + //! Fixed size means that the code buffer is fixed and cannot grow. + ASMJIT_INLINE_NODEBUG bool isFixed() const noexcept { return hasFlag(CodeBufferFlags::kIsFixed); } + + //! Tests whether the data in this code buffer is external. + //! + //! External data can only be provided by users, it's never used by AsmJit. + ASMJIT_INLINE_NODEBUG bool isExternal() const noexcept { return hasFlag(CodeBufferFlags::kIsExternal); } + + //! Tests whether the data in this code buffer is allocated (non-null). + ASMJIT_INLINE_NODEBUG bool isAllocated() const noexcept { return _data != nullptr; } + + //! Tests whether the code buffer is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return !_size; } + + //! Returns the size of the data. + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + //! Returns the capacity of the data. + ASMJIT_INLINE_NODEBUG size_t capacity() const noexcept { return _capacity; } + + //! Returns the pointer to the data the buffer references. + ASMJIT_INLINE_NODEBUG uint8_t* data() noexcept { return _data; } + //! \overload + ASMJIT_INLINE_NODEBUG const uint8_t* data() const noexcept { return _data; } + + //! \} + + //! \name Iterators + //! 
\{ + + ASMJIT_INLINE_NODEBUG uint8_t* begin() noexcept { return _data; } + ASMJIT_INLINE_NODEBUG const uint8_t* begin() const noexcept { return _data; } + + ASMJIT_INLINE_NODEBUG uint8_t* end() noexcept { return _data + _size; } + ASMJIT_INLINE_NODEBUG const uint8_t* end() const noexcept { return _data + _size; } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_CODEBUFFER_H_INCLUDED + diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/codeholder.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/codeholder.h new file mode 100644 index 0000000000000000000000000000000000000000..3f2d1d7074521b340fda00a7f46b446e3d0a648c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/codeholder.h @@ -0,0 +1,1123 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_CODEHOLDER_H_INCLUDED +#define ASMJIT_CORE_CODEHOLDER_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/codebuffer.h" +#include "../core/errorhandler.h" +#include "../core/operand.h" +#include "../core/string.h" +#include "../core/support.h" +#include "../core/target.h" +#include "../core/zone.h" +#include "../core/zonehash.h" +#include "../core/zonestring.h" +#include "../core/zonetree.h" +#include "../core/zonevector.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +class BaseEmitter; +class CodeHolder; +class LabelEntry; +class Logger; + +//! Operator type that can be used within an \ref Expression. +enum class ExpressionOpType : uint8_t { + //! Addition. + kAdd = 0, + //! Subtraction. + kSub = 1, + //! Multiplication + kMul = 2, + //! Logical left shift. + kSll = 3, + //! Logical right shift. + kSrl = 4, + //! Arithmetic right shift. + kSra = 5 +}; + +//! Value type that can be used within an \ref Expression. +enum class ExpressionValueType : uint8_t { + //! 
No value or invalid. + kNone = 0, + //! Value is 64-bit unsigned integer (constant). + kConstant = 1, + //! Value is \ref LabelEntry, which references a \ref Label. + kLabel = 2, + //! Value is \ref Expression + kExpression = 3 +}; + +//! Expression node that can reference constants, labels, and another expressions. +struct Expression { + //! Expression value. + union Value { + //! Constant. + uint64_t constant; + //! Pointer to another expression. + Expression* expression; + //! Pointer to \ref LabelEntry. + LabelEntry* label; + }; + + //! \name Members + //! \{ + + //! Operation type. + ExpressionOpType opType; + //! Value types of \ref value. + ExpressionValueType valueType[2]; + //! Reserved for future use, should be initialized to zero. + uint8_t reserved[5]; + //! Expression left and right values. + Value value[2]; + + //! \} + + //! \name Accessors + //! \{ + + //! Resets the whole expression. + //! + //! Changes both values to \ref ExpressionValueType::kNone. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = Expression{}; } + + //! Sets the value type at `index` to \ref ExpressionValueType::kConstant and its content to `constant`. + ASMJIT_INLINE_NODEBUG void setValueAsConstant(size_t index, uint64_t constant) noexcept { + valueType[index] = ExpressionValueType::kConstant; + value[index].constant = constant; + } + + //! Sets the value type at `index` to \ref ExpressionValueType::kLabel and its content to `labelEntry`. + ASMJIT_INLINE_NODEBUG void setValueAsLabel(size_t index, LabelEntry* labelEntry) noexcept { + valueType[index] = ExpressionValueType::kLabel; + value[index].label = labelEntry; + } + + //! Sets the value type at `index` to \ref ExpressionValueType::kExpression and its content to `expression`. + ASMJIT_INLINE_NODEBUG void setValueAsExpression(size_t index, Expression* expression) noexcept { + valueType[index] = ExpressionValueType::kExpression; + value[index].expression = expression; + } + + //! \} +}; + +//! 
Section flags, used by \ref Section. +enum class SectionFlags : uint32_t { + //! No flags. + kNone = 0, + //! Executable (.text sections). + kExecutable = 0x00000001u, + //! Read-only (.text and .data sections). + kReadOnly = 0x00000002u, + //! Zero initialized by the loader (BSS). + kZeroInitialized = 0x00000004u, + //! Info / comment flag. + kComment = 0x00000008u, + //! Section created implicitly, can be deleted by \ref Target. + kImplicit = 0x80000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(SectionFlags) + +//! Flags that can be used with \ref CodeHolder::copySectionData() and \ref CodeHolder::copyFlattenedData(). +enum class CopySectionFlags : uint32_t { + //! No flags. + kNone = 0, + + //! If virtual size of a section is greater than the size of its \ref CodeBuffer then all bytes between the buffer + //! size and virtual size will be zeroed. If this option is not set then those bytes would be left as is, which + //! means that if the user didn't initialize them they would have a previous content, which may be unwanted. + kPadSectionBuffer = 0x00000001u, + + //! Clears the target buffer if the flattened data is less than the destination size. This option works + //! only with \ref CodeHolder::copyFlattenedData() as it processes multiple sections. It is ignored by + //! \ref CodeHolder::copySectionData(). + kPadTargetBuffer = 0x00000002u +}; +ASMJIT_DEFINE_ENUM_FLAGS(CopySectionFlags) + +//! Section entry. +class Section { +public: + //! \name Members + //! \{ + + //! Section id. + uint32_t _id; + //! Section flags. + SectionFlags _flags; + //! Section alignment requirements (0 if no requirements). + uint32_t _alignment; + //! Order (lower value means higher priority). + int32_t _order; + //! Offset of this section from base-address. + uint64_t _offset; + //! Virtual size of the section (zero initialized sections). + uint64_t _virtualSize; + //! Section name (max 35 characters, PE allows max 8). + FixedString _name; + //! Code or data buffer. 
+ CodeBuffer _buffer; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the section id. + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _id; } + //! Returns the section name, as a null terminated string. + ASMJIT_INLINE_NODEBUG const char* name() const noexcept { return _name.str; } + + //! Returns the section data. + ASMJIT_INLINE_NODEBUG uint8_t* data() noexcept { return _buffer.data(); } + //! \overload + ASMJIT_INLINE_NODEBUG const uint8_t* data() const noexcept { return _buffer.data(); } + + //! Returns the section flags. + ASMJIT_INLINE_NODEBUG SectionFlags flags() const noexcept { return _flags; } + //! Tests whether the section has the given `flag`. + ASMJIT_INLINE_NODEBUG bool hasFlag(SectionFlags flag) const noexcept { return Support::test(_flags, flag); } + //! Adds `flags` to the section flags. + ASMJIT_INLINE_NODEBUG void addFlags(SectionFlags flags) noexcept { _flags |= flags; } + //! Removes `flags` from the section flags. + ASMJIT_INLINE_NODEBUG void clearFlags(SectionFlags flags) noexcept { _flags &= ~flags; } + + //! Returns the minimum section alignment + ASMJIT_INLINE_NODEBUG uint32_t alignment() const noexcept { return _alignment; } + //! Sets the minimum section alignment + ASMJIT_INLINE_NODEBUG void setAlignment(uint32_t alignment) noexcept { _alignment = alignment; } + + //! Returns the section order, which has a higher priority than section id. + ASMJIT_INLINE_NODEBUG int32_t order() const noexcept { return _order; } + + //! Returns the section offset, relative to base. + ASMJIT_INLINE_NODEBUG uint64_t offset() const noexcept { return _offset; } + //! Set the section offset. + ASMJIT_INLINE_NODEBUG void setOffset(uint64_t offset) noexcept { _offset = offset; } + + //! Returns the virtual size of the section. + //! + //! Virtual size is initially zero and is never changed by AsmJit. It's normal if virtual size is smaller than + //! 
size returned by `bufferSize()` as the buffer stores real data emitted by assemblers or appended by users. + //! + //! Use `realSize()` to get the real and final size of this section. + ASMJIT_INLINE_NODEBUG uint64_t virtualSize() const noexcept { return _virtualSize; } + //! Sets the virtual size of the section. + ASMJIT_INLINE_NODEBUG void setVirtualSize(uint64_t virtualSize) noexcept { _virtualSize = virtualSize; } + + //! Returns the buffer size of the section. + ASMJIT_INLINE_NODEBUG size_t bufferSize() const noexcept { return _buffer.size(); } + //! Returns the real size of the section calculated from virtual and buffer sizes. + ASMJIT_INLINE_NODEBUG uint64_t realSize() const noexcept { return Support::max(virtualSize(), bufferSize()); } + + //! Returns the `CodeBuffer` used by this section. + ASMJIT_INLINE_NODEBUG CodeBuffer& buffer() noexcept { return _buffer; } + //! Returns the `CodeBuffer` used by this section (const). + ASMJIT_INLINE_NODEBUG const CodeBuffer& buffer() const noexcept { return _buffer; } + + //! \} +}; + +//! Entry in an address table. +class AddressTableEntry : public ZoneTreeNodeT { +public: + ASMJIT_NONCOPYABLE(AddressTableEntry) + + //! \name Members + //! \{ + + //! Address. + uint64_t _address; + //! Slot. + uint32_t _slot; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG explicit AddressTableEntry(uint64_t address) noexcept + : _address(address), + _slot(0xFFFFFFFFu) {} + + //! \} + + //! \name Accessors + //! 
\{ + + ASMJIT_INLINE_NODEBUG uint64_t address() const noexcept { return _address; } + ASMJIT_INLINE_NODEBUG uint32_t slot() const noexcept { return _slot; } + + ASMJIT_INLINE_NODEBUG bool hasAssignedSlot() const noexcept { return _slot != 0xFFFFFFFFu; } + + ASMJIT_INLINE_NODEBUG bool operator<(const AddressTableEntry& other) const noexcept { return _address < other._address; } + ASMJIT_INLINE_NODEBUG bool operator>(const AddressTableEntry& other) const noexcept { return _address > other._address; } + + ASMJIT_INLINE_NODEBUG bool operator<(uint64_t queryAddress) const noexcept { return _address < queryAddress; } + ASMJIT_INLINE_NODEBUG bool operator>(uint64_t queryAddress) const noexcept { return _address > queryAddress; } + + //! \} +}; + +//! Offset format type, used by \ref OffsetFormat. +enum class OffsetType : uint8_t { + // Common Offset Formats + // --------------------- + + //! A value having `_immBitCount` bits and shifted by `_immBitShift`. + //! + //! This offset type is sufficient for many targets that store offset as a continuous set bits within an + //! instruction word / sequence of bytes. + kSignedOffset, + + //! An unsigned value having `_immBitCount` bits and shifted by `_immBitShift`. + kUnsignedOffset, + + // AArch64 Specific Offset Formats + // ------------------------------- + + //! AArch64 ADR format of `[.|immlo:2|.....|immhi:19|.....]`. + kAArch64_ADR, + + //! AArch64 ADRP format of `[.|immlo:2|.....|immhi:19|.....]` (4kB pages). + kAArch64_ADRP, + + // AArch32 Specific Offset Formats (T16 & T32) + // ------------------------------------------- + + //! AArch32 THUMBv2 immediate encoding of 'ADR' instruction (12-bit payload and sign bit): + //! + //! `|.....|imm:1|..N.N|......|imm:3|....|imm:8|` + //! + //! Where `N` is one if the offset is negative. The immediate is encoded as absolute value of the offset if negative. + kThumb32_ADR, + + //! 
AArch32 THUMBv2 immediate encoding of 'BLX' instruction (23-bit immediate payload, multiplied by 4): + //! + //! `|.....|imm[22]|imm[19:10]|..|ja|1|jb|imm[9:0]|0` + //! + //! Where: + //! + //! - `ja` is calculated as imm[22] ^ imm[21] ^ 1. + //! - `jb` is calculated as imm[22] ^ imm[20] ^ 1. + kThumb32_BLX, + + //! AArch32 THUMBv2 immediate encoding of 'B' instruction without `` (24-bit immediate payload, multiplied by 2): + //! + //! `|.....|imm[23]|imm[20:11]|..|ja|1|jb|imm[10:0]` + //! + //! Where: + //! + //! - `ja` is calculated as imm[23] ^ imm[22] ^ 1. + //! - `jb` is calculated as imm[23] ^ imm[21] ^ 1. + kThumb32_B, + + //! AArch32 THUMBv2 immediate encoding of 'B' instruction with `` (20-bit immediate payload, multiplied by 2). + //! + //! `|.....|imm[19]|....|imm[16:11]|..|ja|1|jb|imm[10:0]` + //! + //! Where: + //! + //! - `ja` is calculated as imm[19] ^ imm[18] ^ 1. + //! - `jb` is calculated as imm[19] ^ imm[17] ^ 1. + kThumb32_BCond, + + // AArch32 Specific Offset Formats (A32) + // ------------------------------------- + + //! AArch32 ADR instruction, which uses a standard 12-bit immediate encoding that is used by other ARM instructions. + kAArch32_ADR, + + //! AArch32 signed offset that is similar to `kSignedOffset`, however it uses absolute value of the offset and its + //! sign is encoded in 23rd bit of the opcode. + //! + //! `|........|U.......|........|........|` + //! + kAArch32_U23_SignedOffset, + + //! AArch32 offset format that encodes 8-bit offset as: + //! + //! `|........|U.......|....|imm[7:4]|....|imm[3:0]|` + //! + //! in a 32-bit word, where U is a sign of the displacement and the displacement itself is encoded as its absolute + //! value. + kAArch32_U23_0To3At0_4To7At8, + + //! AArch32 offset format that encodes a signed 25-bit offset as: + //! + //! `|.......|imm[0]|imm[24:1]|` + //! + //! in a 32-bit word. + kAArch32_1To24At0_0At24, + + //! Maximum value of `OffsetFormatType`. + kMaxValue = kAArch32_1To24At0_0At24 +}; + +//! 
Provides information about formatting offsets, absolute addresses, or their parts. Offset format is used by both +//! \ref RelocEntry and \ref LabelLink. The illustration below describes the relation of region size and offset size. +//! Region size is the size of the whole unit whereas offset size is the size of the unit that will be patched. +//! +//! ``` +//! +-> Code buffer | The subject of the relocation (region) | +//! | | (Word-Offset) (Word-Size) | +//! |xxxxxxxxxxxxxxx|................|*PATCHED*|................|xxxxxxxxxxxx-> +//! | | +//! [Word Offset points here]----+ +--- [WordOffset + WordSize] +//! ``` +//! +//! Once the offset word has been located it can be patched like this: +//! +//! ``` +//! |ImmDiscardLSB (discard LSB bits). +//! |.. +//! [0000000000000iiiiiiiiiiiiiiiiiDD] - Offset value (32-bit) +//! [000000000000000iiiiiiiiiiiiiiiii] - Offset value after discard LSB. +//! [00000000000iiiiiiiiiiiiiiiii0000] - Offset value shifted by ImmBitShift. +//! [xxxxxxxxxxxiiiiiiiiiiiiiiiiixxxx] - Patched word (32-bit) +//! |...............| +//! (ImmBitCount) +- ImmBitShift +//! ``` +struct OffsetFormat { + //! \name Members + //! \{ + + //! Type of the offset. + OffsetType _type; + //! Encoding flags. + uint8_t _flags; + //! Size of the region (in bytes) containing the offset value, if the offset value is part of an instruction, + //! otherwise it would be the same as `_valueSize`. + uint8_t _regionSize; + //! Size of the offset value, in bytes (1, 2, 4, or 8). + uint8_t _valueSize; + //! Offset of the offset value, in bytes, relative to the start of the region or data. Value offset would be + //! zero if both region size and value size are equal. + uint8_t _valueOffset; + //! Size of the offset immediate value in bits. + uint8_t _immBitCount; + //! Shift of the offset immediate value in bits in the target word. + uint8_t _immBitShift; + //! Number of least significant bits to discard before writing the immediate to the destination. All discarded + //! 
bits must be zero otherwise the value is invalid. + uint8_t _immDiscardLsb; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the type of the offset. + ASMJIT_INLINE_NODEBUG OffsetType type() const noexcept { return _type; } + + //! Returns whether the offset is encoded as an absolute value of the offset with additional field(s) that represent + //! the sign (AArch32 U/N fields in the opcode). + //! + //! If true, the offset itself is always positive and a separate U/N field is used to indicate the sign of the offset + //! (usually `U==1` means ADD, but sometimes `N==1` means negative offset, which implies SUB). + ASMJIT_INLINE_NODEBUG bool hasSignBit() const noexcept { + return _type == OffsetType::kThumb32_ADR || + _type == OffsetType::kAArch32_ADR || + _type == OffsetType::kAArch32_U23_SignedOffset || + _type == OffsetType::kAArch32_U23_0To3At0_4To7At8; + } + + //! Returns flags. + ASMJIT_INLINE_NODEBUG uint32_t flags() const noexcept { return _flags; } + //! Returns the size of the region/instruction where the offset is encoded. + ASMJIT_INLINE_NODEBUG uint32_t regionSize() const noexcept { return _regionSize; } + //! Returns the offset of the word relative to the start of the region where the offset is. + ASMJIT_INLINE_NODEBUG uint32_t valueOffset() const noexcept { return _valueOffset; } + //! Returns the size of the data-type (word) that contains the offset, in bytes. + ASMJIT_INLINE_NODEBUG uint32_t valueSize() const noexcept { return _valueSize; } + //! Returns the count of bits of the offset value in the data it's stored in. + ASMJIT_INLINE_NODEBUG uint32_t immBitCount() const noexcept { return _immBitCount; } + //! Returns the bit-shift of the offset value in the data it's stored in. + ASMJIT_INLINE_NODEBUG uint32_t immBitShift() const noexcept { return _immBitShift; } + //! Returns the number of least significant bits of the offset value, that must be zero and that are not part of + //! the encoded data. 
+ ASMJIT_INLINE_NODEBUG uint32_t immDiscardLsb() const noexcept { return _immDiscardLsb; } + + //! Resets this offset format to a simple data value of `dataSize` bytes. + //! + //! The region will be the same size as data and immediate bits would correspond to `dataSize * 8`. There will be + //! no immediate bit shift or discarded bits. + inline void resetToSimpleValue(OffsetType type, size_t valueSize) noexcept { + ASMJIT_ASSERT(valueSize <= 8u); + + _type = type; + _flags = uint8_t(0); + _regionSize = uint8_t(valueSize); + _valueSize = uint8_t(valueSize); + _valueOffset = uint8_t(0); + _immBitCount = uint8_t(valueSize * 8u); + _immBitShift = uint8_t(0); + _immDiscardLsb = uint8_t(0); + } + + inline void resetToImmValue(OffsetType type, size_t valueSize, uint32_t immBitShift, uint32_t immBitCount, uint32_t immDiscardLsb) noexcept { + ASMJIT_ASSERT(valueSize <= 8u); + ASMJIT_ASSERT(immBitShift < valueSize * 8u); + ASMJIT_ASSERT(immBitCount <= 64u); + ASMJIT_ASSERT(immDiscardLsb <= 64u); + + _type = type; + _flags = uint8_t(0); + _regionSize = uint8_t(valueSize); + _valueSize = uint8_t(valueSize); + _valueOffset = uint8_t(0); + _immBitCount = uint8_t(immBitCount); + _immBitShift = uint8_t(immBitShift); + _immDiscardLsb = uint8_t(immDiscardLsb); + } + + inline void setRegion(size_t regionSize, size_t valueOffset) noexcept { + _regionSize = uint8_t(regionSize); + _valueOffset = uint8_t(valueOffset); + } + + inline void setLeadingAndTrailingSize(size_t leadingSize, size_t trailingSize) noexcept { + _regionSize = uint8_t(leadingSize + trailingSize + _valueSize); + _valueOffset = uint8_t(leadingSize); + } + + //! \} +}; + +//! Relocation type. +enum class RelocType : uint32_t { + //! None/deleted (no relocation). + kNone = 0, + //! Expression evaluation, `_payload` is pointer to `Expression`. + kExpression = 1, + //! Relocate absolute to absolute. + kAbsToAbs = 2, + //! Relocate relative to absolute. + kRelToAbs = 3, + //! Relocate absolute to relative. 
+ kAbsToRel = 4, + //! Relocate absolute to relative or use trampoline. + kX64AddressEntry = 5 +}; + +//! Relocation entry. +struct RelocEntry { + //! \name Members + //! \{ + + //! Relocation id. + uint32_t _id; + //! Type of the relocation. + RelocType _relocType; + //! Format of the relocated value. + OffsetFormat _format; + //! Source section id. + uint32_t _sourceSectionId; + //! Target section id. + uint32_t _targetSectionId; + //! Source offset (relative to start of the section). + uint64_t _sourceOffset; + //! Payload (target offset, target address, expression, etc). + uint64_t _payload; + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _id; } + + ASMJIT_INLINE_NODEBUG RelocType relocType() const noexcept { return _relocType; } + ASMJIT_INLINE_NODEBUG const OffsetFormat& format() const noexcept { return _format; } + + ASMJIT_INLINE_NODEBUG uint32_t sourceSectionId() const noexcept { return _sourceSectionId; } + ASMJIT_INLINE_NODEBUG uint32_t targetSectionId() const noexcept { return _targetSectionId; } + + ASMJIT_INLINE_NODEBUG uint64_t sourceOffset() const noexcept { return _sourceOffset; } + ASMJIT_INLINE_NODEBUG uint64_t payload() const noexcept { return _payload; } + + ASMJIT_INLINE_NODEBUG Expression* payloadAsExpression() const noexcept { + return reinterpret_cast(uintptr_t(_payload)); + } + + //! \} +}; + +//! Type of the \ref Label. +enum class LabelType : uint8_t { + //! Anonymous label that can optionally have a name, which is only used for debugging purposes. + kAnonymous = 0, + //! Local label (always has parentId). + kLocal = 1, + //! Global label (never has parentId). + kGlobal = 2, + //! External label (references an external symbol). + kExternal = 3, + + //! Maximum value of `LabelType`. + kMaxValue = kExternal +}; + +//! Data structure used to link either unbound labels or cross-section links. +struct LabelLink { + //! Next link (single-linked list). + LabelLink* next; + //! 
Section id where the label is bound. + uint32_t sectionId; + //! Relocation id or Globals::kInvalidId. + uint32_t relocId; + //! Label offset relative to the start of the section. + size_t offset; + //! Inlined rel8/rel32. + intptr_t rel; + //! Offset format information. + OffsetFormat format; +}; + +//! Label entry. +//! +//! Contains the following properties: +//! - Label id - This is the only thing that is set to the `Label` operand. +//! - Label name - Optional, used mostly to create executables and libraries. +//! - Label type - Type of the label, default `LabelType::kAnonymous`. +//! - Label parent id - Derived from many assemblers that allow to define a local label that falls under a global +//! label. This allows to define many labels of the same name that have different parent (global) label. +//! - Offset - offset of the label bound by `Assembler`. +//! - Links - single-linked list that contains locations of code that has to be patched when the label gets bound. +//! Every use of unbound label adds one link to `_links` list. +//! - HVal - Hash value of label's name and optionally parentId. +//! - HashNext - Hash-table implementation detail. +class LabelEntry : public ZoneHashNode { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + //! SSO size of \ref _name. + //! + //! \cond INTERNAL + //! Let's round the size of `LabelEntry` to 64 bytes (as `ZoneAllocator` has granularity of 32 bytes anyway). This + //! gives `_name` the remaining space, which is should be 16 bytes on 64-bit and 28 bytes on 32-bit architectures. + //! \endcond + kStaticNameSize = 64 - (sizeof(ZoneHashNode) + 8 + sizeof(Section*) + sizeof(size_t) + sizeof(LabelLink*)) + }; + + //! \} + + //! \name Members + //! \{ + + //! Type of the label. + LabelType _type; + //! Must be zero. + uint8_t _reserved[3]; + //! Label parent id or zero. + uint32_t _parentId; + //! Label offset relative to the start of the `_section`. + uint64_t _offset; + //! Section where the label was bound. 
+ Section* _section; + //! Label links. + LabelLink* _links; + //! Label name. + ZoneString _name; + + //! \} + + //! \name Accessors + //! \{ + + // NOTE: Label id is stored in `_customData`, which is provided by ZoneHashNode to fill a padding that a C++ + // compiler targeting 64-bit CPU will add to align the structure to 64-bits. + + //! Returns label id. + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _customData; } + //! Sets label id (internal, used only by `CodeHolder`). + ASMJIT_INLINE_NODEBUG void _setId(uint32_t id) noexcept { _customData = id; } + + //! Returns label type. + ASMJIT_INLINE_NODEBUG LabelType type() const noexcept { return _type; } + + //! Tests whether the label has a parent label. + ASMJIT_INLINE_NODEBUG bool hasParent() const noexcept { return _parentId != Globals::kInvalidId; } + //! Returns label's parent id. + ASMJIT_INLINE_NODEBUG uint32_t parentId() const noexcept { return _parentId; } + + //! Returns the section where the label was bound. + //! + //! If the label was not yet bound the return value is `nullptr`. + ASMJIT_INLINE_NODEBUG Section* section() const noexcept { return _section; } + + //! Tests whether the label has name. + ASMJIT_INLINE_NODEBUG bool hasName() const noexcept { return !_name.empty(); } + + //! Returns the label's name. + //! + //! \note Local labels will return their local name without their parent part, for example ".L1". + ASMJIT_INLINE_NODEBUG const char* name() const noexcept { return _name.data(); } + + //! Returns size of label's name. + //! + //! \note Label name is always null terminated, so you can use `strlen()` to get it, however, it's also cached in + //! `LabelEntry` itself, so if you want to know the size the fastest way is to call `LabelEntry::nameSize()`. + ASMJIT_INLINE_NODEBUG uint32_t nameSize() const noexcept { return _name.size(); } + + //! Returns links associated with this label. + ASMJIT_INLINE_NODEBUG LabelLink* links() const noexcept { return _links; } + + //! 
Tests whether the label is bound. + ASMJIT_INLINE_NODEBUG bool isBound() const noexcept { return _section != nullptr; } + //! Tests whether the label is bound to a the given `sectionId`. + ASMJIT_INLINE_NODEBUG bool isBoundTo(Section* section) const noexcept { return _section == section; } + + //! Returns the label offset (only useful if the label is bound). + ASMJIT_INLINE_NODEBUG uint64_t offset() const noexcept { return _offset; } + + //! Returns the hash-value of label's name and its parent label (if any). + //! + //! Label hash is calculated as `HASH(Name) ^ ParentId`. The hash function is implemented in `Support::hashString()` + //! and `Support::hashRound()`. + ASMJIT_INLINE_NODEBUG uint32_t hashCode() const noexcept { return _hashCode; } + + //! \} +}; + +//! Holds assembled code and data (including sections, labels, and relocation information). +//! +//! CodeHolder connects emitters with their targets. It provides them interface that can be used to query information +//! about the target environment (architecture, etc...) and API to create labels, sections, relocations, and to write +//! data to a \ref CodeBuffer, which is always part of \ref Section. More than one emitter can be attached to a single +//! CodeHolder instance at a time, which is used in practice +//! +//! CodeHolder provides interface for all emitter types. Assemblers use CodeHolder to write into \ref CodeBuffer, and +//! higher level emitters like Builder and Compiler use CodeHolder to manage labels and sections so higher level code +//! can be serialized to Assembler by \ref BaseEmitter::finalize() and \ref BaseBuilder::serializeTo(). +//! +//! In order to use CodeHolder, it must be first initialized by \ref init(). After the CodeHolder has been successfully +//! initialized it can be used to hold assembled code, sections, labels, relocations, and to attach / detach code +//! emitters. 
After the end of code generation it can be used to query physical locations of labels and to relocate +//! the assembled code into the right address. +//! +//! \note \ref CodeHolder has an ability to attach an \ref ErrorHandler, however, the error handler is not triggered +//! by \ref CodeHolder itself, it's instead propagated to all emitters that attach to it. +class CodeHolder { +public: + ASMJIT_NONCOPYABLE(CodeHolder) + + //! \name Members + //! \{ + + //! Environment information. + Environment _environment; + //! CPU features of the target architecture. + CpuFeatures _cpuFeatures; + //! Base address or \ref Globals::kNoBaseAddress. + uint64_t _baseAddress; + + //! Attached `Logger`, used by all consumers. + Logger* _logger; + //! Attached `ErrorHandler`. + ErrorHandler* _errorHandler; + + //! Code zone (used to allocate core structures). + Zone _zone; + //! Zone allocator, used to manage internal containers. + ZoneAllocator _allocator; + + //! Attached emitters. + ZoneVector _emitters; + //! Section entries. + ZoneVector _sections; + //! Section entries sorted by section order and then section id. + ZoneVector _sectionsByOrder; + //! Label entries. + ZoneVector _labelEntries; + //! Relocation entries. + ZoneVector _relocations; + //! Label name -> LabelEntry (only named labels). + ZoneHash _namedLabels; + + //! Count of label links, which are not resolved. + size_t _unresolvedLinkCount; + //! Pointer to an address table section (or null if this section doesn't exist). + Section* _addressTableSection; + //! Address table entries. + ZoneTree _addressTableEntries; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates an uninitialized CodeHolder (you must init() it before it can be used). + //! + //! An optional `temporary` argument can be used to initialize the first block of \ref Zone that the CodeHolder + //! uses into a temporary memory provided by the user. 
+ ASMJIT_API explicit CodeHolder(const Support::Temporary* temporary = nullptr) noexcept; + + //! \overload + ASMJIT_INLINE_NODEBUG explicit CodeHolder(const Support::Temporary& temporary) noexcept + : CodeHolder(&temporary) {} + + //! Destroys the CodeHolder and frees all resources it has allocated. + ASMJIT_API ~CodeHolder() noexcept; + + //! Tests whether the `CodeHolder` has been initialized. + //! + //! Emitters can be only attached to initialized `CodeHolder` instances. + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _environment.isInitialized(); } + + //! Initializes CodeHolder to hold code described by the given `environment` and `baseAddress`. + ASMJIT_API Error init(const Environment& environment, uint64_t baseAddress = Globals::kNoBaseAddress) noexcept; + //! Initializes CodeHolder to hold code described by the given `environment`, `cpuFeatures`, and `baseAddress`. + ASMJIT_API Error init(const Environment& environment, const CpuFeatures& cpuFeatures, uint64_t baseAddress = Globals::kNoBaseAddress) noexcept; + //! Detaches all code-generators attached and resets the `CodeHolder`. + ASMJIT_API void reset(ResetPolicy resetPolicy = ResetPolicy::kSoft) noexcept; + + //! \} + + //! \name Attach & Detach + //! \{ + + //! Attaches an emitter to this `CodeHolder`. + ASMJIT_API Error attach(BaseEmitter* emitter) noexcept; + //! Detaches an emitter from this `CodeHolder`. + ASMJIT_API Error detach(BaseEmitter* emitter) noexcept; + + //! \} + + //! \name Allocators + //! \{ + + //! Returns the allocator that the `CodeHolder` uses. + //! + //! \note This should be only used for AsmJit's purposes. Code holder uses arena allocator to allocate everything, + //! so anything allocated through this allocator will be invalidated by \ref CodeHolder::reset() or by CodeHolder's + //! destructor. + ASMJIT_INLINE_NODEBUG ZoneAllocator* allocator() const noexcept { return const_cast(&_allocator); } + + //! \} + + //! \name Code & Architecture + //! 
\{ + + //! Returns the target environment information. + ASMJIT_INLINE_NODEBUG const Environment& environment() const noexcept { return _environment; } + + //! Returns the target architecture. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return environment().arch(); } + //! Returns the target sub-architecture. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return environment().subArch(); } + + //! Returns the minimum CPU features of the target architecture. + ASMJIT_INLINE_NODEBUG const CpuFeatures& cpuFeatures() const noexcept { return _cpuFeatures; } + + //! Tests whether a static base-address is set. + ASMJIT_INLINE_NODEBUG bool hasBaseAddress() const noexcept { return _baseAddress != Globals::kNoBaseAddress; } + //! Returns a static base-address or \ref Globals::kNoBaseAddress, if not set. + ASMJIT_INLINE_NODEBUG uint64_t baseAddress() const noexcept { return _baseAddress; } + + //! \} + + //! \name Emitters + //! \{ + + //! Returns a vector of attached emitters. + ASMJIT_INLINE_NODEBUG const ZoneVector& emitters() const noexcept { return _emitters; } + + //! \} + + //! \name Logging + //! \{ + + //! Returns the attached logger. + ASMJIT_INLINE_NODEBUG Logger* logger() const noexcept { return _logger; } + //! Attaches a `logger` to CodeHolder and propagates it to all attached emitters. + ASMJIT_API void setLogger(Logger* logger) noexcept; + //! Resets the logger to none. + ASMJIT_INLINE_NODEBUG void resetLogger() noexcept { setLogger(nullptr); } + + //! \name Error Handling + //! \{ + + //! Tests whether the CodeHolder has an attached error handler, see \ref ErrorHandler. + ASMJIT_INLINE_NODEBUG bool hasErrorHandler() const noexcept { return _errorHandler != nullptr; } + //! Returns the attached error handler. + ASMJIT_INLINE_NODEBUG ErrorHandler* errorHandler() const noexcept { return _errorHandler; } + //! Attach an error handler to this `CodeHolder`. + ASMJIT_API void setErrorHandler(ErrorHandler* errorHandler) noexcept; + //! 
Resets the error handler to none. + ASMJIT_INLINE_NODEBUG void resetErrorHandler() noexcept { setErrorHandler(nullptr); } + + //! \} + + //! \name Code Buffer + //! \{ + + //! Makes sure that at least `n` bytes can be added to CodeHolder's buffer `cb`. + //! + //! \note The buffer `cb` must be managed by `CodeHolder` - otherwise the behavior of the function is undefined. + ASMJIT_API Error growBuffer(CodeBuffer* cb, size_t n) noexcept; + + //! Reserves the size of `cb` to at least `n` bytes. + //! + //! \note The buffer `cb` must be managed by `CodeHolder` - otherwise the behavior of the function is undefined. + ASMJIT_API Error reserveBuffer(CodeBuffer* cb, size_t n) noexcept; + + //! \} + + //! \name Sections + //! \{ + + //! Returns an array of `Section*` records. + ASMJIT_INLINE_NODEBUG const ZoneVector& sections() const noexcept { return _sections; } + //! Returns an array of `Section*` records sorted according to section order first, then section id. + ASMJIT_INLINE_NODEBUG const ZoneVector& sectionsByOrder() const noexcept { return _sectionsByOrder; } + //! Returns the number of sections. + ASMJIT_INLINE_NODEBUG uint32_t sectionCount() const noexcept { return _sections.size(); } + + //! Tests whether the given `sectionId` is valid. + ASMJIT_INLINE_NODEBUG bool isSectionValid(uint32_t sectionId) const noexcept { return sectionId < _sections.size(); } + + //! Creates a new section and return its pointer in `sectionOut`. + //! + //! Returns `Error`, does not report a possible error to `ErrorHandler`. + ASMJIT_API Error newSection(Section** sectionOut, const char* name, size_t nameSize = SIZE_MAX, SectionFlags flags = SectionFlags::kNone, uint32_t alignment = 1, int32_t order = 0) noexcept; + + //! Returns a section entry of the given index. + ASMJIT_INLINE_NODEBUG Section* sectionById(uint32_t sectionId) const noexcept { return _sections[sectionId]; } + + //! Returns section-id that matches the given `name`. + //! + //! 
If there is no such section `Section::kInvalidId` is returned. + ASMJIT_API Section* sectionByName(const char* name, size_t nameSize = SIZE_MAX) const noexcept; + + //! Returns '.text' section (section that commonly represents code). + //! + //! \note Text section is always the first section in \ref CodeHolder::sections() array. + ASMJIT_INLINE_NODEBUG Section* textSection() const noexcept { return _sections[0]; } + + //! Tests whether '.addrtab' section exists. + ASMJIT_INLINE_NODEBUG bool hasAddressTable() const noexcept { return _addressTableSection != nullptr; } + + //! Returns '.addrtab' section. + //! + //! This section is used exclusively by AsmJit to store absolute 64-bit + //! addresses that cannot be encoded in instructions like 'jmp' or 'call'. + //! + //! \note This section is created on demand, the returned pointer can be null. + ASMJIT_INLINE_NODEBUG Section* addressTableSection() const noexcept { return _addressTableSection; } + + //! Ensures that '.addrtab' section exists (creates it if it doesn't) and + //! returns it. Can return `nullptr` on out of memory condition. + ASMJIT_API Section* ensureAddressTableSection() noexcept; + + //! Used to add an address to an address table. + //! + //! This implicitly calls `ensureAddressTableSection()` and then creates `AddressTableEntry` that is inserted + //! to `_addressTableEntries`. If the address already exists this operation does nothing as the same addresses + //! use the same slot. + //! + //! This function should be considered internal as it's used by assemblers to insert an absolute address into the + //! address table. Inserting address into address table without creating a particular relocation entry makes no sense. + ASMJIT_API Error addAddressToAddressTable(uint64_t address) noexcept; + + //! \} + + //! \name Labels & Symbols + //! \{ + + //! Returns array of `LabelEntry*` records. + ASMJIT_INLINE_NODEBUG const ZoneVector& labelEntries() const noexcept { return _labelEntries; } + + //! 
Returns number of labels created. + ASMJIT_INLINE_NODEBUG uint32_t labelCount() const noexcept { return _labelEntries.size(); } + + //! Tests whether the label having `id` is valid (i.e. created by `newLabelEntry()`). + ASMJIT_INLINE_NODEBUG bool isLabelValid(uint32_t labelId) const noexcept { + return labelId < _labelEntries.size(); + } + + //! Tests whether the `label` is valid (i.e. created by `newLabelEntry()`). + ASMJIT_INLINE_NODEBUG bool isLabelValid(const Label& label) const noexcept { + return label.id() < _labelEntries.size(); + } + + //! \overload + ASMJIT_INLINE_NODEBUG bool isLabelBound(uint32_t labelId) const noexcept { + return isLabelValid(labelId) && _labelEntries[labelId]->isBound(); + } + + //! Tests whether the `label` is already bound. + //! + //! Returns `false` if the `label` is not valid. + ASMJIT_INLINE_NODEBUG bool isLabelBound(const Label& label) const noexcept { + return isLabelBound(label.id()); + } + + //! Returns LabelEntry of the given label `id`. + ASMJIT_INLINE_NODEBUG LabelEntry* labelEntry(uint32_t labelId) const noexcept { + return isLabelValid(labelId) ? _labelEntries[labelId] : static_cast(nullptr); + } + + //! Returns LabelEntry of the given `label`. + ASMJIT_INLINE_NODEBUG LabelEntry* labelEntry(const Label& label) const noexcept { + return labelEntry(label.id()); + } + + //! Returns offset of a `Label` by its `labelId`. + //! + //! The offset returned is relative to the start of the section. Zero offset is returned for unbound labels, + //! which is their initial offset value. + ASMJIT_INLINE_NODEBUG uint64_t labelOffset(uint32_t labelId) const noexcept { + ASMJIT_ASSERT(isLabelValid(labelId)); + return _labelEntries[labelId]->offset(); + } + + //! \overload + ASMJIT_INLINE_NODEBUG uint64_t labelOffset(const Label& label) const noexcept { + return labelOffset(label.id()); + } + + //! Returns offset of a label by it's `labelId` relative to the base offset. + //! + //! 
\remarks The offset of the section where the label is bound must be valid in order to use this function, + //! otherwise the value returned will not be reliable. + inline uint64_t labelOffsetFromBase(uint32_t labelId) const noexcept { + ASMJIT_ASSERT(isLabelValid(labelId)); + const LabelEntry* le = _labelEntries[labelId]; + return (le->isBound() ? le->section()->offset() : uint64_t(0)) + le->offset(); + } + + //! \overload + inline uint64_t labelOffsetFromBase(const Label& label) const noexcept { + return labelOffsetFromBase(label.id()); + } + + //! Creates a new anonymous label and return its id in `idOut`. + //! + //! Returns `Error`, does not report error to `ErrorHandler`. + ASMJIT_API Error newLabelEntry(LabelEntry** entryOut) noexcept; + + //! Creates a new named \ref LabelEntry of the given label `type`. + //! + //! \param entryOut Where to store the created \ref LabelEntry. + //! \param name The name of the label. + //! \param nameSize The length of `name` argument, or `SIZE_MAX` if `name` is a null terminated string, which + //! means that the `CodeHolder` will use `strlen()` to determine the length. + //! \param type The type of the label to create, see \ref LabelType. + //! \param parentId Parent id of a local label, otherwise it must be \ref Globals::kInvalidId. + //! \retval Always returns \ref Error, does not report a possible error to the attached \ref ErrorHandler. + //! + //! AsmJit has a support for local labels (\ref LabelType::kLocal) which require a parent label id (parentId). + //! The names of local labels can conflict with names of other local labels that have a different parent. In + //! addition, AsmJit supports named anonymous labels, which are useful only for debugging purposes as the + //! anonymous name will have a name, which will be formatted, but the label itself cannot be queried by such + //! name. 
+ ASMJIT_API Error newNamedLabelEntry(LabelEntry** entryOut, const char* name, size_t nameSize, LabelType type, uint32_t parentId = Globals::kInvalidId) noexcept; + + //! Returns a label by name. + //! + //! If the named label doesn't a default constructed \ref Label is returned, + //! which has its id set to \ref Globals::kInvalidId. + ASMJIT_INLINE_NODEBUG Label labelByName(const char* name, size_t nameSize = SIZE_MAX, uint32_t parentId = Globals::kInvalidId) noexcept { + return Label(labelIdByName(name, nameSize, parentId)); + } + + //! Returns a label id by name. + //! + //! If the named label doesn't exist \ref Globals::kInvalidId is returned. + ASMJIT_API uint32_t labelIdByName(const char* name, size_t nameSize = SIZE_MAX, uint32_t parentId = Globals::kInvalidId) noexcept; + + //! Tests whether there are any unresolved label links. + ASMJIT_INLINE_NODEBUG bool hasUnresolvedLinks() const noexcept { return _unresolvedLinkCount != 0; } + //! Returns the number of label links, which are unresolved. + ASMJIT_INLINE_NODEBUG size_t unresolvedLinkCount() const noexcept { return _unresolvedLinkCount; } + + //! Creates a new label-link used to store information about yet unbound labels. + //! + //! Returns `null` if the allocation failed. + ASMJIT_API LabelLink* newLabelLink(LabelEntry* le, uint32_t sectionId, size_t offset, intptr_t rel, const OffsetFormat& format) noexcept; + + //! Resolves cross-section links (`LabelLink`) associated with each label that was used as a destination in code + //! of a different section. It's only useful to people that use multiple sections as it will do nothing if the code + //! only contains a single section in which cross-section links are not possible. + ASMJIT_API Error resolveUnresolvedLinks() noexcept; + + //! Binds a label to a given `sectionId` and `offset` (relative to start of the section). + //! + //! This function is generally used by `BaseAssembler::bind()` to do the heavy lifting. 
+ ASMJIT_API Error bindLabel(const Label& label, uint32_t sectionId, uint64_t offset) noexcept; + + //! \} + + //! \name Relocations + //! \{ + + //! Tests whether the code contains relocation entries. + ASMJIT_INLINE_NODEBUG bool hasRelocEntries() const noexcept { return !_relocations.empty(); } + //! Returns array of `RelocEntry*` records. + ASMJIT_INLINE_NODEBUG const ZoneVector& relocEntries() const noexcept { return _relocations; } + + //! Returns a RelocEntry of the given `id`. + ASMJIT_INLINE_NODEBUG RelocEntry* relocEntry(uint32_t id) const noexcept { return _relocations[id]; } + + //! Creates a new relocation entry of type `relocType`. + //! + //! Additional fields can be set after the relocation entry was created. + ASMJIT_API Error newRelocEntry(RelocEntry** dst, RelocType relocType) noexcept; + + //! \} + + //! \name Utilities + //! \{ + + //! Flattens all sections by recalculating their offsets, starting at 0. + //! + //! \note This should never be called more than once. + ASMJIT_API Error flatten() noexcept; + + //! Returns computed the size of code & data of all sections. + //! + //! \note All sections will be iterated over and the code size returned would represent the minimum code size of + //! all combined sections after applying minimum alignment. Code size may decrease after calling `flatten()` and + //! `relocateToBase()`. + ASMJIT_API size_t codeSize() const noexcept; + + //! Relocates the code to the given `baseAddress`. + //! + //! \param baseAddress Absolute base address where the code will be relocated to. Please note that nothing is + //! copied to such base address, it's just an absolute value used by the relocation code to resolve all stored + //! relocations. + //! + //! \note This should never be called more than once. + ASMJIT_API Error relocateToBase(uint64_t baseAddress) noexcept; + + //! Copies a single section into `dst`. 
+ ASMJIT_API Error copySectionData(void* dst, size_t dstSize, uint32_t sectionId, CopySectionFlags copyFlags = CopySectionFlags::kNone) noexcept; + + //! Copies all sections into `dst`. + //! + //! This should only be used if the data was flattened and there are no gaps between the sections. The `dstSize` + //! is always checked and the copy will never write anything outside the provided buffer. + ASMJIT_API Error copyFlattenedData(void* dst, size_t dstSize, CopySectionFlags copyFlags = CopySectionFlags::kNone) noexcept; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_CODEHOLDER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/compiler.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..7d4b47c45971a87922c66f22508dffc99ec07793 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/compiler.h @@ -0,0 +1,741 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_COMPILER_H_INCLUDED +#define ASMJIT_CORE_COMPILER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_COMPILER + +#include "../core/assembler.h" +#include "../core/builder.h" +#include "../core/constpool.h" +#include "../core/compilerdefs.h" +#include "../core/func.h" +#include "../core/inst.h" +#include "../core/operand.h" +#include "../core/support.h" +#include "../core/zone.h" +#include "../core/zonevector.h" + +ASMJIT_BEGIN_NAMESPACE + +class JumpAnnotation; +class JumpNode; +class FuncNode; +class FuncRetNode; +class InvokeNode; + +//! \addtogroup asmjit_compiler +//! \{ + +//! Code emitter that uses virtual registers and performs register allocation. +//! +//! Compiler is a high-level code-generation tool that provides register allocation and automatic handling of function +//! 
calling conventions. It was primarily designed for merging multiple parts of code into a function without worrying +//! about registers and function calling conventions. +//! +//! BaseCompiler can be used, with a minimum effort, to handle 32-bit and 64-bit code generation within a single code +//! base. +//! +//! BaseCompiler is based on BaseBuilder and contains all the features it provides. It means that the code it stores +//! can be modified (removed, added, injected) and analyzed. When the code is finalized the compiler can emit the code +//! into an Assembler to translate the abstract representation into a machine code. +//! +//! Check out architecture specific compilers for more details and examples: +//! +//! - \ref x86::Compiler - X86/X64 compiler implementation. +//! - \ref a64::Compiler - AArch64 compiler implementation. +class ASMJIT_VIRTAPI BaseCompiler : public BaseBuilder { +public: + ASMJIT_NONCOPYABLE(BaseCompiler) + typedef BaseBuilder Base; + + //! \name Members + //! \{ + + //! Current function. + FuncNode* _func; + //! Allocates `VirtReg` objects. + Zone _vRegZone; + //! Stores array of `VirtReg` pointers. + ZoneVector _vRegArray; + //! Stores jump annotations. + ZoneVector _jumpAnnotations; + + //! Local and global constant pools. + //! + //! Local constant pool is flushed with each function, global constant pool is flushed only by \ref finalize(). + ConstPoolNode* _constPools[2]; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `BaseCompiler` instance. + ASMJIT_API BaseCompiler() noexcept; + //! Destroys the `BaseCompiler` instance. + ASMJIT_API ~BaseCompiler() noexcept override; + + //! \} + + //! \name Function Management + //! \{ + + //! Creates a new \ref FuncNode. + ASMJIT_API Error newFuncNode(FuncNode** ASMJIT_NONNULL(out), const FuncSignature& signature); + //! Creates a new \ref FuncNode adds it to the instruction stream. 
+ ASMJIT_API Error addFuncNode(FuncNode** ASMJIT_NONNULL(out), const FuncSignature& signature); + + //! Creates a new \ref FuncRetNode. + ASMJIT_API Error newFuncRetNode(FuncRetNode** ASMJIT_NONNULL(out), const Operand_& o0, const Operand_& o1); + //! Creates a new \ref FuncRetNode and adds it to the instruction stream. + ASMJIT_API Error addFuncRetNode(FuncRetNode** ASMJIT_NONNULL(out), const Operand_& o0, const Operand_& o1); + + //! Returns the current function. + ASMJIT_INLINE_NODEBUG FuncNode* func() const noexcept { return _func; } + + //! Creates a new \ref FuncNode with the given `signature` and returns it. + inline FuncNode* newFunc(const FuncSignature& signature) { + FuncNode* node; + newFuncNode(&node, signature); + return node; + } + + //! Creates a new \ref FuncNode with the given `signature`, adds it to the instruction stream by using + //! the \ref addFunc(FuncNode*) overload, and returns it. + inline FuncNode* addFunc(const FuncSignature& signature) { + FuncNode* node; + addFuncNode(&node, signature); + return node; + } + + //! Adds a function `node` to the instruction stream. + ASMJIT_API FuncNode* addFunc(FuncNode* ASMJIT_NONNULL(func)); + //! Emits a sentinel that marks the end of the current function. + ASMJIT_API Error endFunc(); + +#if !defined(ASMJIT_NO_DEPRECATED) + inline Error _setArg(size_t argIndex, size_t valueIndex, const BaseReg& reg); + + //! Sets a function argument at `argIndex` to `reg`. + ASMJIT_DEPRECATED("Setting arguments through Compiler is deprecated, use FuncNode->setArg() instead") + inline Error setArg(size_t argIndex, const BaseReg& reg) { return _setArg(argIndex, 0, reg); } + + //! Sets a function argument at `argIndex` at `valueIndex` to `reg`. 
+ ASMJIT_DEPRECATED("Setting arguments through Compiler is deprecated, use FuncNode->setArg() instead") + inline Error setArg(size_t argIndex, size_t valueIndex, const BaseReg& reg) { return _setArg(argIndex, valueIndex, reg); } +#endif + + inline Error addRet(const Operand_& o0, const Operand_& o1) { + FuncRetNode* node; + return addFuncRetNode(&node, o0, o1); + } + + //! \} + + //! \name Function Invocation + //! \{ + + //! Creates a new \ref InvokeNode. + ASMJIT_API Error newInvokeNode(InvokeNode** ASMJIT_NONNULL(out), InstId instId, const Operand_& o0, const FuncSignature& signature); + //! Creates a new \ref InvokeNode and adds it to the instruction stream. + ASMJIT_API Error addInvokeNode(InvokeNode** ASMJIT_NONNULL(out), InstId instId, const Operand_& o0, const FuncSignature& signature); + + //! \} + + //! \name Virtual Registers + //! \{ + + //! Creates a new virtual register representing the given `typeId` and `signature`. + //! + //! \note This function is public, but it's not generally recommended to be used by AsmJit users, use architecture + //! specific `newReg()` functionality instead or functions like \ref _newReg() and \ref _newRegFmt(). + ASMJIT_API Error newVirtReg(VirtReg** ASMJIT_NONNULL(out), TypeId typeId, OperandSignature signature, const char* name); + + //! Creates a new virtual register of the given `typeId` and stores it to `out` operand. + ASMJIT_API Error _newReg(BaseReg* ASMJIT_NONNULL(out), TypeId typeId, const char* name = nullptr); + + //! Creates a new virtual register of the given `typeId` and stores it to `out` operand. + //! + //! \note This version accepts a snprintf() format `fmt` followed by a variadic arguments. + ASMJIT_API Error _newRegFmt(BaseReg* ASMJIT_NONNULL(out), TypeId typeId, const char* fmt, ...); + //! \overload + inline Error _newRegFmt(BaseReg* ASMJIT_NONNULL(out), TypeId typeId) { return _newRegFmt(out, typeId, nullptr); } + + //! 
Creates a new virtual register compatible with the provided reference register `ref`. + ASMJIT_API Error _newReg(BaseReg* ASMJIT_NONNULL(out), const BaseReg& ref, const char* name = nullptr); + + //! Creates a new virtual register compatible with the provided reference register `ref`. + //! + //! \note This version accepts a snprintf() format `fmt` followed by a variadic arguments. + ASMJIT_API Error _newRegFmt(BaseReg* ASMJIT_NONNULL(out), const BaseReg& ref, const char* fmt, ...); + + //! Tests whether the given `id` is a valid virtual register id. + ASMJIT_INLINE_NODEBUG bool isVirtIdValid(uint32_t id) const noexcept { + uint32_t index = Operand::virtIdToIndex(id); + return index < _vRegArray.size(); + } + //! Tests whether the given `reg` is a virtual register having a valid id. + ASMJIT_INLINE_NODEBUG bool isVirtRegValid(const BaseReg& reg) const noexcept { + return isVirtIdValid(reg.id()); + } + + //! Returns \ref VirtReg associated with the given `id`. + inline VirtReg* virtRegById(uint32_t id) const noexcept { + ASMJIT_ASSERT(isVirtIdValid(id)); + return _vRegArray[Operand::virtIdToIndex(id)]; + } + + //! Returns \ref VirtReg associated with the given `reg`. + ASMJIT_INLINE_NODEBUG VirtReg* virtRegByReg(const BaseReg& reg) const noexcept { return virtRegById(reg.id()); } + + //! Returns \ref VirtReg associated with the given virtual register `index`. + //! + //! \note This is not the same as virtual register id. The conversion between id and its index is implemented + //! by \ref Operand_::virtIdToIndex() and \ref Operand_::indexToVirtId() functions. + ASMJIT_INLINE_NODEBUG VirtReg* virtRegByIndex(uint32_t index) const noexcept { return _vRegArray[index]; } + + //! Returns an array of all virtual registers managed by the Compiler. + ASMJIT_INLINE_NODEBUG const ZoneVector& virtRegs() const noexcept { return _vRegArray; } + + //! \name Stack + //! \{ + + //! Creates a new stack of the given `size` and `alignment` and stores it to `out`. + //! + //! 
\note `name` can be used to give the stack a name, for debugging purposes. + ASMJIT_API Error _newStack(BaseMem* ASMJIT_NONNULL(out), uint32_t size, uint32_t alignment, const char* name = nullptr); + + //! Updates the stack size of a stack created by `_newStack()` by its `virtId`. + ASMJIT_API Error setStackSize(uint32_t virtId, uint32_t newSize, uint32_t newAlignment = 0); + + //! Updates the stack size of a stack created by `_newStack()`. + ASMJIT_INLINE_NODEBUG Error setStackSize(const BaseMem& mem, uint32_t newSize, uint32_t newAlignment = 0) { + return setStackSize(mem.id(), newSize, newAlignment); + } + + //! \} + + //! \name Constants + //! \{ + + //! Creates a new constant of the given `scope` (see \ref ConstPoolScope). + //! + //! This function adds a constant of the given `size` to the built-in \ref ConstPool and stores the reference to that + //! constant to the `out` operand. + ASMJIT_API Error _newConst(BaseMem* ASMJIT_NONNULL(out), ConstPoolScope scope, const void* data, size_t size); + + //! \} + + //! \name Miscellaneous + //! \{ + + //! Rename the given virtual register `reg` to a formatted string `fmt`. + ASMJIT_API void rename(const BaseReg& reg, const char* fmt, ...); + + //! \} + + //! \name Jump Annotations + //! \{ + + ASMJIT_INLINE_NODEBUG const ZoneVector& jumpAnnotations() const noexcept { + return _jumpAnnotations; + } + + ASMJIT_API Error newJumpNode(JumpNode** ASMJIT_NONNULL(out), InstId instId, InstOptions instOptions, const Operand_& o0, JumpAnnotation* annotation); + ASMJIT_API Error emitAnnotatedJump(InstId instId, const Operand_& o0, JumpAnnotation* annotation); + + //! Returns a new `JumpAnnotation` instance, which can be used to aggregate possible targets of a jump where the + //! target is not a label, for example to implement jump tables. + ASMJIT_API JumpAnnotation* newJumpAnnotation(); + + //! \} + + //! \name Events + //! 
\{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! Jump annotation used to annotate jumps. +//! +//! \ref BaseCompiler allows to emit jumps where the target is either register or memory operand. Such jumps cannot be +//! trivially inspected, so instead of doing heuristics AsmJit allows to annotate such jumps with possible targets. +//! Register allocator then uses the annotation to construct control-flow, which is then used by liveness analysis and +//! other tools to prepare ground for register allocation. +class JumpAnnotation { +public: + ASMJIT_NONCOPYABLE(JumpAnnotation) + + //! \name Members + //! \{ + + //! Compiler that owns this JumpAnnotation. + BaseCompiler* _compiler; + //! Annotation identifier. + uint32_t _annotationId; + //! Vector of label identifiers, see \ref labelIds(). + ZoneVector _labelIds; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG JumpAnnotation(BaseCompiler* ASMJIT_NONNULL(compiler), uint32_t annotationId) noexcept + : _compiler(compiler), + _annotationId(annotationId) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the compiler that owns this JumpAnnotation. + ASMJIT_INLINE_NODEBUG BaseCompiler* compiler() const noexcept { return _compiler; } + //! Returns the annotation id. + ASMJIT_INLINE_NODEBUG uint32_t annotationId() const noexcept { return _annotationId; } + //! Returns a vector of label identifiers that lists all targets of the jump. + ASMJIT_INLINE_NODEBUG const ZoneVector& labelIds() const noexcept { return _labelIds; } + + //! Tests whether the given `label` is a target of this JumpAnnotation. + ASMJIT_INLINE_NODEBUG bool hasLabel(const Label& label) const noexcept { return hasLabelId(label.id()); } + //! Tests whether the given `labelId` is a target of this JumpAnnotation. 
+ ASMJIT_INLINE_NODEBUG bool hasLabelId(uint32_t labelId) const noexcept { return _labelIds.contains(labelId); } + + //! \} + + //! \name Annotation Building API + //! \{ + + //! Adds the `label` to the list of targets of this JumpAnnotation. + ASMJIT_INLINE_NODEBUG Error addLabel(const Label& label) noexcept { return addLabelId(label.id()); } + //! Adds the `labelId` to the list of targets of this JumpAnnotation. + ASMJIT_INLINE_NODEBUG Error addLabelId(uint32_t labelId) noexcept { return _labelIds.append(&_compiler->_allocator, labelId); } + + //! \} +}; + +//! Jump instruction with \ref JumpAnnotation. +//! +//! \note This node should be only used to represent jump where the jump target cannot be deduced by examining +//! instruction operands. For example if the jump target is register or memory location. This pattern is often +//! used to perform indirect jumps that use jump table, e.g. to implement `switch{}` statement. +class JumpNode : public InstNodeWithOperands { +public: + ASMJIT_NONCOPYABLE(JumpNode) + + //! \name Members + //! \{ + + JumpAnnotation* _annotation; + + //! \} + + //! \name Construction & Destruction + //! \{ + + inline JumpNode(BaseCompiler* ASMJIT_NONNULL(cc), InstId instId, InstOptions options, uint32_t opCount, JumpAnnotation* annotation) noexcept + : InstNodeWithOperands(cc, instId, options, opCount), + _annotation(annotation) { + setType(NodeType::kJump); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether this JumpNode has associated a \ref JumpAnnotation. + ASMJIT_INLINE_NODEBUG bool hasAnnotation() const noexcept { return _annotation != nullptr; } + //! Returns the \ref JumpAnnotation associated with this jump, or `nullptr`. + ASMJIT_INLINE_NODEBUG JumpAnnotation* annotation() const noexcept { return _annotation; } + //! Sets the \ref JumpAnnotation associated with this jump to `annotation`. + ASMJIT_INLINE_NODEBUG void setAnnotation(JumpAnnotation* annotation) noexcept { _annotation = annotation; } + + //! 
\} +}; + +//! Function node represents a function used by \ref BaseCompiler. +//! +//! A function is composed of the following: +//! +//! - Function entry, \ref FuncNode acts as a label, so the entry is implicit. To get the entry, simply use +//! \ref FuncNode::label(), which is the same as \ref LabelNode::label(). +//! +//! - Function exit, which is represented by \ref FuncNode::exitNode(). A helper function +//! \ref FuncNode::exitLabel() exists and returns an exit label instead of node. +//! +//! - Function \ref FuncNode::endNode() sentinel. This node marks the end of a function - there should be no +//! code that belongs to the function after this node, but the Compiler doesn't enforce that at the moment. +//! +//! - Function detail, see \ref FuncNode::detail(). +//! +//! - Function frame, see \ref FuncNode::frame(). +//! +//! - Function arguments mapped to virtual registers, see \ref FuncNode::argPacks(). +//! +//! In a node list, the function and its body looks like the following: +//! +//! \code{.unparsed} +//! [...] - Anything before the function. +//! +//! [FuncNode] - Entry point of the function, acts as a label as well. +//! - Prolog inserted by the register allocator. +//! {...} - Function body - user code basically. +//! [ExitLabel] - Exit label +//! - Epilog inserted by the register allocator. +//! - Return inserted by the register allocator. +//! {...} - Can contain data or user code (error handling, special cases, ...). +//! [FuncEnd] - End sentinel +//! +//! [...] - Anything after the function. +//! \endcode +//! +//! When a function is added to the instruction stream by \ref BaseCompiler::addFunc() it actually inserts 3 nodes +//! (FuncNode, ExitLabel, and FuncEnd) and sets the current cursor to be FuncNode. When \ref BaseCompiler::endFunc() +//! is called the cursor is set to FuncEnd. This guarantees that user can use ExitLabel as a marker after additional +//! code or data can be placed, which is a common practice. 
+class FuncNode : public LabelNode { +public: + ASMJIT_NONCOPYABLE(FuncNode) + + //! Arguments pack. + struct ArgPack { + RegOnly _data[Globals::kMaxValuePack]; + + inline void reset() noexcept { + for (size_t valueIndex = 0; valueIndex < Globals::kMaxValuePack; valueIndex++) + _data[valueIndex].reset(); + } + + inline RegOnly& operator[](size_t valueIndex) noexcept { return _data[valueIndex]; } + inline const RegOnly& operator[](size_t valueIndex) const noexcept { return _data[valueIndex]; } + }; + + //! \name Members + //! \{ + + //! Function detail. + FuncDetail _funcDetail; + //! Function frame. + FuncFrame _frame; + //! Function exit label. + LabelNode* _exitNode; + //! Function end (sentinel). + SentinelNode* _end; + //! Argument packs. + ArgPack* _args; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `FuncNode` instance. + //! + //! Always use `BaseCompiler::addFunc()` to create a new `FuncNode`. + inline FuncNode(BaseBuilder* ASMJIT_NONNULL(cb)) noexcept + : LabelNode(cb), + _funcDetail(), + _frame(), + _exitNode(nullptr), + _end(nullptr), + _args(nullptr) { + setType(NodeType::kFunc); + } + + //! \} + + //! \{ + //! \name Accessors + + //! Returns function exit `LabelNode`. + ASMJIT_INLINE_NODEBUG LabelNode* exitNode() const noexcept { return _exitNode; } + //! Returns function exit label. + ASMJIT_INLINE_NODEBUG Label exitLabel() const noexcept { return _exitNode->label(); } + + //! Returns "End of Func" sentinel node. + ASMJIT_INLINE_NODEBUG SentinelNode* endNode() const noexcept { return _end; } + + //! Returns function detail. + ASMJIT_INLINE_NODEBUG FuncDetail& detail() noexcept { return _funcDetail; } + //! Returns function detail. + ASMJIT_INLINE_NODEBUG const FuncDetail& detail() const noexcept { return _funcDetail; } + + //! Returns function frame. + ASMJIT_INLINE_NODEBUG FuncFrame& frame() noexcept { return _frame; } + //! Returns function frame. 
+ ASMJIT_INLINE_NODEBUG const FuncFrame& frame() const noexcept { return _frame; } + + //! Returns function attributes. + ASMJIT_INLINE_NODEBUG FuncAttributes attributes() const noexcept { return _frame.attributes(); } + //! Adds `attrs` to the function attributes. + ASMJIT_INLINE_NODEBUG void addAttributes(FuncAttributes attrs) noexcept { _frame.addAttributes(attrs); } + + //! Returns arguments count. + ASMJIT_INLINE_NODEBUG uint32_t argCount() const noexcept { return _funcDetail.argCount(); } + //! Returns argument packs. + ASMJIT_INLINE_NODEBUG ArgPack* argPacks() const noexcept { return _args; } + + //! Tests whether the function has a return value. + ASMJIT_INLINE_NODEBUG bool hasRet() const noexcept { return _funcDetail.hasRet(); } + + //! Returns argument pack at `argIndex`. + inline ArgPack& argPack(size_t argIndex) const noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex]; + } + + //! Sets argument at `argIndex`. + inline void setArg(size_t argIndex, const BaseReg& vReg) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][0].init(vReg); + } + + //! \overload + inline void setArg(size_t argIndex, const RegOnly& vReg) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][0].init(vReg); + } + + //! Sets argument at `argIndex` and `valueIndex`. + inline void setArg(size_t argIndex, size_t valueIndex, const BaseReg& vReg) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][valueIndex].init(vReg); + } + + //! \overload + inline void setArg(size_t argIndex, size_t valueIndex, const RegOnly& vReg) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][valueIndex].init(vReg); + } + + //! Resets argument pack at `argIndex`. + inline void resetArg(size_t argIndex) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex].reset(); + } + + //! Resets argument pack at `argIndex`. 
+ inline void resetArg(size_t argIndex, size_t valueIndex) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][valueIndex].reset(); + } + + //! \} +}; + +//! Function return, used by \ref BaseCompiler. +class FuncRetNode : public InstNodeWithOperands { +public: + ASMJIT_NONCOPYABLE(FuncRetNode) + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `FuncRetNode` instance. + inline FuncRetNode(BaseBuilder* ASMJIT_NONNULL(cb)) noexcept + : InstNodeWithOperands(cb, BaseInst::kIdAbstract, InstOptions::kNone, 0) { + _any._nodeType = NodeType::kFuncRet; + } + + //! \} +}; + +//! Function invocation, used by \ref BaseCompiler. +class InvokeNode : public InstNodeWithOperands { +public: + ASMJIT_NONCOPYABLE(InvokeNode) + + //! Operand pack provides multiple operands that can be associated with a single return value of function + //! argument. Sometimes this is necessary to express an argument or return value that requires multiple + //! registers, for example 64-bit value in 32-bit mode or passing / returning homogeneous data structures. + struct OperandPack { + //! Operands. + Operand_ _data[Globals::kMaxValuePack]; + + //! Reset the pack by resetting all operands in the pack. + inline void reset() noexcept { + for (size_t valueIndex = 0; valueIndex < Globals::kMaxValuePack; valueIndex++) + _data[valueIndex].reset(); + } + + //! Returns an operand at the given `valueIndex`. + inline Operand& operator[](size_t valueIndex) noexcept { + ASMJIT_ASSERT(valueIndex < Globals::kMaxValuePack); + return _data[valueIndex].as(); + } + + //! Returns an operand at the given `valueIndex` (const). + const inline Operand& operator[](size_t valueIndex) const noexcept { + ASMJIT_ASSERT(valueIndex < Globals::kMaxValuePack); + return _data[valueIndex].as(); + } + }; + + //! \name Members + //! \{ + + //! Function detail. + FuncDetail _funcDetail; + //! Function return value(s). + OperandPack _rets; + //! Function arguments. + OperandPack* _args; + + //! 
\} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `InvokeNode` instance. + inline InvokeNode(BaseBuilder* ASMJIT_NONNULL(cb), InstId instId, InstOptions options) noexcept + : InstNodeWithOperands(cb, instId, options, 0), + _funcDetail(), + _args(nullptr) { + setType(NodeType::kInvoke); + _resetOps(); + _rets.reset(); + addFlags(NodeFlags::kIsRemovable); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Sets the function signature. + inline Error init(const FuncSignature& signature, const Environment& environment) noexcept { + return _funcDetail.init(signature, environment); + } + + //! Returns the function detail. + ASMJIT_INLINE_NODEBUG FuncDetail& detail() noexcept { return _funcDetail; } + //! Returns the function detail. + ASMJIT_INLINE_NODEBUG const FuncDetail& detail() const noexcept { return _funcDetail; } + + //! Returns the target operand. + ASMJIT_INLINE_NODEBUG Operand& target() noexcept { return op(0); } + //! \overload + ASMJIT_INLINE_NODEBUG const Operand& target() const noexcept { return op(0); } + + //! Returns the number of function return values. + ASMJIT_INLINE_NODEBUG bool hasRet() const noexcept { return _funcDetail.hasRet(); } + //! Returns the number of function arguments. + ASMJIT_INLINE_NODEBUG uint32_t argCount() const noexcept { return _funcDetail.argCount(); } + + //! Returns operand pack representing function return value(s). + ASMJIT_INLINE_NODEBUG OperandPack& retPack() noexcept { return _rets; } + //! Returns operand pack representing function return value(s). + ASMJIT_INLINE_NODEBUG const OperandPack& retPack() const noexcept { return _rets; } + + //! Returns the return value at the given `valueIndex`. + ASMJIT_INLINE_NODEBUG Operand& ret(size_t valueIndex = 0) noexcept { return _rets[valueIndex]; } + //! \overload + ASMJIT_INLINE_NODEBUG const Operand& ret(size_t valueIndex = 0) const noexcept { return _rets[valueIndex]; } + + //! Returns operand pack representing function return value(s). 
+ inline OperandPack& argPack(size_t argIndex) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex]; + } + //! \overload + inline const OperandPack& argPack(size_t argIndex) const noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex]; + } + + //! Returns a function argument at the given `argIndex`. + inline Operand& arg(size_t argIndex, size_t valueIndex) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex][valueIndex]; + } + //! \overload + inline const Operand& arg(size_t argIndex, size_t valueIndex) const noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex][valueIndex]; + } + + //! Sets the function return value at `i` to `op`. + inline void _setRet(size_t valueIndex, const Operand_& op) noexcept { _rets[valueIndex] = op; } + //! Sets the function argument at `i` to `op`. + inline void _setArg(size_t argIndex, size_t valueIndex, const Operand_& op) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][valueIndex] = op; + } + + //! Sets the function return value at `valueIndex` to `reg`. + ASMJIT_INLINE_NODEBUG void setRet(size_t valueIndex, const BaseReg& reg) noexcept { _setRet(valueIndex, reg); } + + //! Sets the first function argument in a value-pack at `argIndex` to `reg`. + ASMJIT_INLINE_NODEBUG void setArg(size_t argIndex, const BaseReg& reg) noexcept { _setArg(argIndex, 0, reg); } + //! Sets the first function argument in a value-pack at `argIndex` to `imm`. + ASMJIT_INLINE_NODEBUG void setArg(size_t argIndex, const Imm& imm) noexcept { _setArg(argIndex, 0, imm); } + + //! Sets the function argument at `argIndex` and `valueIndex` to `reg`. + ASMJIT_INLINE_NODEBUG void setArg(size_t argIndex, size_t valueIndex, const BaseReg& reg) noexcept { _setArg(argIndex, valueIndex, reg); } + //! Sets the function argument at `argIndex` and `valueIndex` to `imm`. 
+ ASMJIT_INLINE_NODEBUG void setArg(size_t argIndex, size_t valueIndex, const Imm& imm) noexcept { _setArg(argIndex, valueIndex, imm); } + + //! \} +}; + +//! Function pass extends \ref Pass with \ref FuncPass::runOnFunction(). +class ASMJIT_VIRTAPI FuncPass : public Pass { +public: + ASMJIT_NONCOPYABLE(FuncPass) + typedef Pass Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API FuncPass(const char* name) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the associated `BaseCompiler`. + ASMJIT_INLINE_NODEBUG BaseCompiler* cc() const noexcept { return static_cast(_cb); } + + //! \} + + //! \name Pass Interface + //! \{ + + //! Calls `runOnFunction()` on each `FuncNode` node found. + ASMJIT_API Error run(Zone* zone, Logger* logger) override; + + //! Called once per `FuncNode`. + ASMJIT_API virtual Error runOnFunction(Zone* zone, Logger* logger, FuncNode* func); + + //! \} +}; + +#if !defined(ASMJIT_NO_DEPRECATED) +inline Error BaseCompiler::_setArg(size_t argIndex, size_t valueIndex, const BaseReg& reg) { + FuncNode* func = _func; + + if (ASMJIT_UNLIKELY(!func)) + return reportError(DebugUtils::errored(kErrorInvalidState)); + + func->setArg(argIndex, valueIndex, reg); + return kErrorOk; +} +#endif + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // !ASMJIT_NO_COMPILER +#endif // ASMJIT_CORE_COMPILER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/compilerdefs.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/compilerdefs.h new file mode 100644 index 0000000000000000000000000000000000000000..e2e74ce6790ea7acd67257f5a9a0bc79bc9c5946 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/compilerdefs.h @@ -0,0 +1,171 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_COMPILERDEFS_H_INCLUDED +#define ASMJIT_CORE_COMPILERDEFS_H_INCLUDED + +#include "../core/api-config.h" +#include "../core/operand.h" +#include "../core/type.h" +#include "../core/zonestring.h" + +ASMJIT_BEGIN_NAMESPACE + +class RAWorkReg; + +//! \addtogroup asmjit_compiler +//! \{ + +//! Virtual register data, managed by \ref BaseCompiler. +class VirtReg { +public: + ASMJIT_NONCOPYABLE(VirtReg) + + //! \name Members + //! \{ + + //! Virtual register signature. + OperandSignature _signature {}; + //! Virtual register id. + uint32_t _id = 0; + //! Virtual register size (can be smaller than `_signature._size`). + uint32_t _virtSize = 0; + //! Virtual register alignment (for spilling). + uint8_t _alignment = 0; + //! Type-id. + TypeId _typeId = TypeId::kVoid; + //! Virtual register weight for alloc/spill decisions. + uint8_t _weight = 1; + //! True if this is a fixed register, never reallocated. + uint8_t _isFixed : 1; + //! True if the virtual register is only used as a stack (never accessed as register). + uint8_t _isStack : 1; + //! True if this virtual register has assigned stack offset (can be only valid after register allocation pass). + uint8_t _hasStackSlot : 1; + uint8_t _reservedBits : 5; + + //! Stack offset assigned by the register allocator relative to stack pointer (can be negative as well). 
+ int32_t _stackOffset = 0; + + //! Reserved for future use (padding). + uint32_t _reservedU32 = 0; + + //! Virtual register name (user provided or automatically generated). + ZoneString<16> _name {}; + + // The following members are used exclusively by RAPass. They are initialized when the VirtReg is created to + // null pointers and then changed during RAPass execution. RAPass sets them back to NULL before it returns. + + //! Reference to `RAWorkReg`, used during register allocation. + RAWorkReg* _workReg = nullptr; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG VirtReg(OperandSignature signature, uint32_t id, uint32_t virtSize, uint32_t alignment, TypeId typeId) noexcept + : _signature(signature), + _id(id), + _virtSize(virtSize), + _alignment(uint8_t(alignment)), + _typeId(typeId), + _isFixed(0), + _isStack(0), + _hasStackSlot(0), + _reservedBits(0) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the virtual register id. + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _id; } + + //! Returns the virtual register name. + ASMJIT_INLINE_NODEBUG const char* name() const noexcept { return _name.data(); } + //! Returns the size of the virtual register name. + ASMJIT_INLINE_NODEBUG uint32_t nameSize() const noexcept { return _name.size(); } + + //! Returns a register signature of this virtual register. + ASMJIT_INLINE_NODEBUG OperandSignature signature() const noexcept { return _signature; } + //! Returns a virtual register type (maps to the physical register type as well). + ASMJIT_INLINE_NODEBUG RegType type() const noexcept { return _signature.regType(); } + //! Returns a virtual register group (maps to the physical register group as well). + ASMJIT_INLINE_NODEBUG RegGroup group() const noexcept { return _signature.regGroup(); } + + //! Returns a real size of the register this virtual register maps to. + //! + //! 
For example if this is a 128-bit SIMD register used for a scalar single precision floating point value then + //! its virtSize would be 4, however, the `regSize` would still say 16 (128-bits), because it's the smallest size + //! of that register type. + ASMJIT_INLINE_NODEBUG uint32_t regSize() const noexcept { return _signature.size(); } + + //! Returns the virtual register size. + //! + //! The virtual register size describes how many bytes the virtual register needs to store its content. It can be + //! smaller than the physical register size, see `regSize()`. + ASMJIT_INLINE_NODEBUG uint32_t virtSize() const noexcept { return _virtSize; } + + //! Returns the virtual register alignment. + ASMJIT_INLINE_NODEBUG uint32_t alignment() const noexcept { return _alignment; } + + //! Returns the virtual register type id. + ASMJIT_INLINE_NODEBUG TypeId typeId() const noexcept { return _typeId; } + + //! Returns the virtual register weight - the register allocator can use it as explicit hint for alloc/spill + //! decisions. + ASMJIT_INLINE_NODEBUG uint32_t weight() const noexcept { return _weight; } + //! Sets the virtual register weight (0 to 255) - the register allocator can use it as explicit hint for + //! alloc/spill decisions and initial bin-packing. + ASMJIT_INLINE_NODEBUG void setWeight(uint32_t weight) noexcept { _weight = uint8_t(weight); } + + //! Returns whether the virtual register is always allocated to a fixed physical register (and never reallocated). + //! + //! \note This is only used for special purposes and it's mostly internal. + ASMJIT_INLINE_NODEBUG bool isFixed() const noexcept { return bool(_isFixed); } + + //! Tests whether the virtual register is in fact a stack that only uses the virtual register id. + //! + //! \note It's an error if a stack is accessed as a register. + ASMJIT_INLINE_NODEBUG bool isStack() const noexcept { return bool(_isStack); } + + //! Tests whether this virtual register (or stack) has assigned a stack offset. + //! + //! 
If this is a virtual register that was never allocated on stack, it would return false, otherwise if + //! it's a virtual register that was spilled or explicitly allocated stack, the return value would be true. + ASMJIT_INLINE_NODEBUG bool hasStackSlot() const noexcept { return bool(_hasStackSlot); } + + //! Assigns a stack offset of this virtual register to `stackOffset` and sets `_hasStackSlot` to true. + ASMJIT_INLINE_NODEBUG void assignStackSlot(int32_t stackOffset) noexcept { + _hasStackSlot = 1; + _stackOffset = stackOffset; + } + + //! Returns a stack offset associated with a virtual register or explicit stack allocation. + //! + //! \note Always verify that the stack offset has been assigned by calling \ref hasStackSlot(). The return + //! value will be zero when the stack offset was not assigned. + ASMJIT_INLINE_NODEBUG int32_t stackOffset() const noexcept { return _stackOffset; } + + //! Tests whether the virtual register has an associated `RAWorkReg` at the moment. + ASMJIT_INLINE_NODEBUG bool hasWorkReg() const noexcept { return _workReg != nullptr; } + //! Returns an associated RAWorkReg with this virtual register (only valid during register allocation). + ASMJIT_INLINE_NODEBUG RAWorkReg* workReg() const noexcept { return _workReg; } + //! Associates a RAWorkReg with this virtual register (used by register allocator). + ASMJIT_INLINE_NODEBUG void setWorkReg(RAWorkReg* workReg) noexcept { _workReg = workReg; } + //! Reset the RAWorkReg association (used by register allocator). + ASMJIT_INLINE_NODEBUG void resetWorkReg() noexcept { _workReg = nullptr; } + + //! \} +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_COMPILERDEFS_H_INCLUDED + diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/constpool.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/constpool.h new file mode 100644 index 0000000000000000000000000000000000000000..673c11d6a93258f7787516e6a589b78607b10e22 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/constpool.h @@ -0,0 +1,261 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_CONSTPOOL_H_INCLUDED +#define ASMJIT_CORE_CONSTPOOL_H_INCLUDED + +#include "../core/support.h" +#include "../core/zone.h" +#include "../core/zonetree.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_utilities +//! \{ + +//! Constant pool scope. +enum class ConstPoolScope : uint32_t { + //! Local constant, always embedded right after the current function. + kLocal = 0, + //! Global constant, embedded at the end of the currently compiled code. + kGlobal = 1, + + //! Maximum value of `ConstPoolScope`. + kMaxValue = kGlobal +}; + +//! Constant pool. +//! +//! Constant pool is designed to hold 1, 2, 4, 8, 16, 32, and 64 byte constants. It's not designed to hold constants +//! having arbitrary length like strings and arrays. +class ConstPool { +public: + ASMJIT_NONCOPYABLE(ConstPool) + + //! \cond INTERNAL + + //! Index of a given size in const-pool table. + enum Index : uint32_t { + kIndex1 = 0, + kIndex2 = 1, + kIndex4 = 2, + kIndex8 = 3, + kIndex16 = 4, + kIndex32 = 5, + kIndex64 = 6, + kIndexCount = 7 + }; + + //! Zone-allocated const-pool gap created by two differently aligned constants. + struct Gap { + //! Pointer to the next gap + Gap* _next; + //! Offset of the gap. + size_t _offset; + //! Remaining bytes of the gap (basically a gap size). + size_t _size; + }; + + //! Zone-allocated const-pool node. 
+ class Node : public ZoneTreeNodeT { + public: + ASMJIT_NONCOPYABLE(Node) + + //! If this constant is shared with another. + uint32_t _shared : 1; + //! Data offset from the beginning of the pool. + uint32_t _offset; + + ASMJIT_INLINE_NODEBUG Node(size_t offset, bool shared) noexcept + : ZoneTreeNodeT(), + _shared(shared), + _offset(uint32_t(offset)) {} + + ASMJIT_INLINE_NODEBUG void* data() const noexcept { + return static_cast(const_cast(this) + 1); + } + }; + + //! Data comparer used internally. + class Compare { + public: + size_t _dataSize; + + ASMJIT_INLINE_NODEBUG Compare(size_t dataSize) noexcept + : _dataSize(dataSize) {} + + ASMJIT_INLINE_NODEBUG int operator()(const Node& a, const Node& b) const noexcept { + return ::memcmp(a.data(), b.data(), _dataSize); + } + + ASMJIT_INLINE_NODEBUG int operator()(const Node& a, const void* data) const noexcept { + return ::memcmp(a.data(), data, _dataSize); + } + }; + + //! Zone-allocated const-pool tree. + struct Tree { + //! RB tree. + ZoneTree _tree; + //! Size of the tree (number of nodes). + size_t _size; + //! Size of the data. 
+ size_t _dataSize; + + ASMJIT_INLINE_NODEBUG explicit Tree(size_t dataSize = 0) noexcept + : _tree(), + _size(0), + _dataSize(dataSize) {} + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _tree.reset(); + _size = 0; + } + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + + inline void setDataSize(size_t dataSize) noexcept { + ASMJIT_ASSERT(empty()); + _dataSize = dataSize; + } + + ASMJIT_INLINE_NODEBUG Node* get(const void* data) noexcept { + Compare cmp(_dataSize); + return _tree.get(data, cmp); + } + + ASMJIT_INLINE_NODEBUG void insert(Node* node) noexcept { + Compare cmp(_dataSize); + _tree.insert(node, cmp); + _size++; + } + + template + inline void forEach(Visitor& visitor) const noexcept { + Node* node = _tree.root(); + if (!node) return; + + Node* stack[Globals::kMaxTreeHeight]; + size_t top = 0; + + for (;;) { + Node* left = node->left(); + if (left != nullptr) { + ASMJIT_ASSERT(top != Globals::kMaxTreeHeight); + stack[top++] = node; + + node = left; + continue; + } + + for (;;) { + visitor(node); + node = node->right(); + + if (node != nullptr) + break; + + if (top == 0) + return; + + node = stack[--top]; + } + } + } + + static inline Node* _newNode(Zone* zone, const void* data, size_t size, size_t offset, bool shared) noexcept { + Node* node = zone->allocT(sizeof(Node) + size); + if (ASMJIT_UNLIKELY(!node)) return nullptr; + + node = new(Support::PlacementNew{node}) Node(offset, shared); + memcpy(node->data(), data, size); + return node; + } + }; + + //! \endcond + + //! \name Members + //! \{ + + //! Zone allocator. + Zone* _zone; + //! Tree per size. + Tree _tree[kIndexCount]; + //! Gaps per size. + Gap* _gaps[kIndexCount]; + //! Gaps pool + Gap* _gapPool; + + //! Size of the pool (in bytes). + size_t _size; + //! Required pool alignment. + size_t _alignment; + //! Minimum item size in the pool. + size_t _minItemSize; + + //! \} + + //! 
\name Construction & Destruction + //! \{ + + //! Creates a new constant pool that would use `zone` as a memory allocator. + ASMJIT_API explicit ConstPool(Zone* zone) noexcept; + //! Destroys this constant pool. + ASMJIT_API ~ConstPool() noexcept; + + //! \} + + //! \name Reset + //! \{ + + //! Resets this constant pool and its allocator to `zone`. + ASMJIT_API void reset(Zone* zone) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the constant-pool is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + //! Returns the size of the constant-pool in bytes. + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + //! Returns minimum alignment. + ASMJIT_INLINE_NODEBUG size_t alignment() const noexcept { return _alignment; } + //! Returns the minimum size of all items added to the constant pool. + ASMJIT_INLINE_NODEBUG size_t minItemSize() const noexcept { return _minItemSize; } + + //! \} + + //! \name Utilities + //! \{ + + //! Adds a constant to the constant pool. + //! + //! The constant must have known size, which is 1, 2, 4, 8, 16 or 32 bytes. The constant is added to the pool only + //! if it doesn't not exist, otherwise cached value is returned. + //! + //! AsmJit is able to subdivide added constants, so for example if you add 8-byte constant 0x1122334455667788 it + //! will create the following slots: + //! + //! 8-byte: 0x1122334455667788 + //! 4-byte: 0x11223344, 0x55667788 + //! + //! The reason is that when combining MMX/SSE/AVX code some patterns are used frequently. However, AsmJit is not + //! able to reallocate a constant that has been already added. For example if you try to add 4-byte constant and + //! then 8-byte constant having the same 4-byte pattern as the previous one, two independent slots will be used. + ASMJIT_API Error add(const void* data, size_t size, size_t& dstOffset) noexcept; + + //! Fills the destination with the content of this constant pool. 
+ ASMJIT_API void fill(void* dst) const noexcept; +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_CONSTPOOL_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/cpuinfo.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/cpuinfo.h new file mode 100644 index 0000000000000000000000000000000000000000..2638146a4cc41f72f404419e3a335d30f7e7d9d7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/cpuinfo.h @@ -0,0 +1,1224 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_CPUINFO_H_INCLUDED +#define ASMJIT_CORE_CPUINFO_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/environment.h" +#include "../core/globals.h" +#include "../core/string.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! CPU features information. +//! +//! Each feature is represented by a single bit in an embedded bit array. +class CpuFeatures { +public: + //! \name Constants + //! \{ + + //! \cond INTERNAL + enum : uint32_t { + kMaxFeatures = 256, + kNumBitWords = kMaxFeatures / Support::kBitWordSizeInBits + }; + //! \endcond + + //! A word that is used to represents feature bits. + typedef Support::BitWord BitWord; + //! Iterator that can iterate all CPU features set. + typedef Support::BitVectorIterator Iterator; + + typedef Support::Array Bits; + + //! \} + + //! \name Data + //! \{ + + //! CPU features data. + struct Data { + //! \name Members + //! \{ + + //! Data bits. + Bits _bits; + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG bool operator==(const Data& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const Data& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + //! 
Returns true if there are no features set. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _bits.aggregate(0) == 0; } + + //! Returns all features as array of bitwords (see \ref Support::BitWord). + ASMJIT_INLINE_NODEBUG BitWord* bits() noexcept { return _bits.data(); } + //! Returns all features as array of bitwords (const). + ASMJIT_INLINE_NODEBUG const BitWord* bits() const noexcept { return _bits.data(); } + + //! Returns the number of BitWords returned by \ref bits(). + ASMJIT_INLINE_NODEBUG size_t bitWordCount() const noexcept { return kNumBitWords; } + + //! Returns \ref Support::BitVectorIterator, that can be used to iterate over all features efficiently. + ASMJIT_INLINE_NODEBUG Iterator iterator() const noexcept { return Iterator(_bits.data(), kNumBitWords); } + + //! Tests whether the feature `featureId` is present. + template + ASMJIT_INLINE_NODEBUG bool has(const FeatureId& featureId) const noexcept { + ASMJIT_ASSERT(uint32_t(featureId) < kMaxFeatures); + + uint32_t idx = uint32_t(featureId) / Support::kBitWordSizeInBits; + uint32_t bit = uint32_t(featureId) % Support::kBitWordSizeInBits; + + return bool((_bits[idx] >> bit) & 0x1); + } + + //! \cond NONE + template + ASMJIT_INLINE_NODEBUG bool hasAny(const FeatureId& featureId) const noexcept { + return has(featureId); + } + //! \endcond + + //! Tests whether any feature given is present. + //! + //! \note This is a variadic function template that can be used with multiple features. + template + ASMJIT_INLINE_NODEBUG bool hasAny(const FeatureId& featureId, Args&&... otherFeatureIds) const noexcept { + return bool(unsigned(has(featureId)) | unsigned(hasAny(std::forward(otherFeatureIds)...))); + } + + //! Tests whether all features as defined by `other` are present. 
+ ASMJIT_INLINE_NODEBUG bool hasAll(const Data& other) const noexcept { + uint32_t result = 1; + for (uint32_t i = 0; i < kNumBitWords; i++) + result &= uint32_t((_bits[i] & other._bits[i]) == other._bits[i]); + return bool(result); + } + + //! \} + + //! \name Manipulation + //! \{ + + //! Clears all features set. + ASMJIT_INLINE_NODEBUG void reset() noexcept { _bits.fill(0); } + + //! Adds the given CPU `featureId` to the list of features. + template + ASMJIT_INLINE_NODEBUG void add(const FeatureId& featureId) noexcept { + ASMJIT_ASSERT(uint32_t(featureId) < kMaxFeatures); + + uint32_t idx = uint32_t(featureId) / Support::kBitWordSizeInBits; + uint32_t bit = uint32_t(featureId) % Support::kBitWordSizeInBits; + + _bits[idx] |= BitWord(1) << bit; + } + + template + ASMJIT_INLINE_NODEBUG void add(const FeatureId& featureId, Args&&... otherFeatureIds) noexcept { + add(featureId); + add(std::forward(otherFeatureIds)...); + } + + template + ASMJIT_INLINE_NODEBUG void addIf(bool condition, const FeatureId& featureId) noexcept { + ASMJIT_ASSERT(uint32_t(featureId) < kMaxFeatures); + + uint32_t idx = uint32_t(featureId) / Support::kBitWordSizeInBits; + uint32_t bit = uint32_t(featureId) % Support::kBitWordSizeInBits; + + _bits[idx] |= BitWord(condition) << bit; + } + + template + ASMJIT_INLINE_NODEBUG void addIf(bool condition, const FeatureId& featureId, Args&&... otherFeatureIds) noexcept { + addIf(condition, featureId); + addIf(condition, std::forward(otherFeatureIds)...); + } + + //! Removes the given CPU `featureId` from the list of features. + template + ASMJIT_INLINE_NODEBUG void remove(const FeatureId& featureId) noexcept { + ASMJIT_ASSERT(uint32_t(featureId) < kMaxFeatures); + + uint32_t idx = uint32_t(featureId) / Support::kBitWordSizeInBits; + uint32_t bit = uint32_t(featureId) % Support::kBitWordSizeInBits; + + _bits[idx] &= ~(BitWord(1) << bit); + } + + template + ASMJIT_INLINE_NODEBUG void remove(const FeatureId& featureId, Args&&... 
otherFeatureIds) noexcept { + remove(featureId); + remove(std::forward(otherFeatureIds)...); + } + + //! Tests whether this CPU features data matches `other`. + ASMJIT_INLINE_NODEBUG bool equals(const Data& other) const noexcept { return _bits == other._bits; } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use CpuFeatures::Data::equals() instead") + ASMJIT_INLINE_NODEBUG bool eq(const Data& other) const noexcept { return equals(other); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} + }; + + //! X86 specific features data. + struct X86 : public Data { + //! X86 CPU feature identifiers. + enum Id : uint8_t { + // @EnumValuesBegin{"enum": "CpuFeatures::X86"}@ + kNone, //!< No feature (never set, used internally). + + kMT, //!< CPU has multi-threading capabilities. + kNX, //!< CPU has Not-Execute-Bit aka DEP (data-execution prevention). + k3DNOW, //!< CPU has 3DNOW (3DNOW base instructions) {AMD} (deprecated). + k3DNOW2, //!< CPU has 3DNOW2 (enhanced 3DNOW) {AMD} (deprecated). + kADX, //!< CPU has ADX (multi-precision add-carry instruction extensions). + kAESNI, //!< CPU has AESNI (AES encode/decode instructions). + kALTMOVCR8, //!< CPU has LOCK MOV R<->CR0 (supports `MOV R<->CR8` via `LOCK MOV R<->CR0` in 32-bit mode) {AMD}. + kAMX_BF16, //!< CPU has AMX_BF16 (AMX-BF16 instructions). + kAMX_COMPLEX, //!< CPU has AMX_COMPLEX (AMX-COMPLEX instructions). + kAMX_FP16, //!< CPU has AMX_FP16 (AMX-FP16 instructions). + kAMX_INT8, //!< CPU has AMX_INT8 (AMX-INT8 instructions). + kAMX_TILE, //!< CPU has AMX_TILE (advanced matrix extensions). + kAPX_F, //!< CPU has APX_F (advanced performance extensions - 32 GP registers, REX2 prefix, ...) {X86_64}. + kAVX, //!< CPU has AVX (advanced vector extensions). + kAVX2, //!< CPU has AVX2 (advanced vector extensions 2). + kAVX512_4FMAPS, //!< CPU has AVX512_FMAPS (FMA packed single). + kAVX512_4VNNIW, //!< CPU has AVX512_VNNIW (vector NN instructions word variable precision). 
+ kAVX512_BF16, //!< CPU has AVX512_BF16 (AVX512 BFLOAT16 support instructions). + kAVX512_BITALG, //!< CPU has AVX512_BITALG (AVX512 VPOPCNT[B|W] and VPSHUFBITQMB instructions). + kAVX512_BW, //!< CPU has AVX512_BW (AVX512 integer BYTE|WORD instructions). + kAVX512_CD, //!< CPU has AVX512_CD (AVX512 conflict detection DWORD|QWORD instructions). + kAVX512_DQ, //!< CPU has AVX512_DQ (AVX512 integer DWORD|QWORD instructions). + kAVX512_ER, //!< CPU has AVX512_ER (AVX512 exponential and reciprocal instructions). + kAVX512_F, //!< CPU has AVX512_F (AVX512 foundation). + kAVX512_FP16, //!< CPU has AVX512_FP16 (AVX512 FP16 instructions). + kAVX512_IFMA, //!< CPU has AVX512_IFMA (AVX512 integer fused-multiply-add using 52-bit precision). + kAVX512_PF, //!< CPU has AVX512_PF (AVX512 prefetch instructions). + kAVX512_VBMI, //!< CPU has AVX512_VBMI (AVX152 vector byte manipulation instructions). + kAVX512_VBMI2, //!< CPU has AVX512_VBMI2 (AVX512 vector byte manipulation instructions v2). + kAVX512_VL, //!< CPU has AVX512_VL (AVX512 vector length extensions). + kAVX512_VNNI, //!< CPU has AVX512_VNNI (AVX512 vector neural network instructions). + kAVX512_VP2INTERSECT, //!< CPU has AVX512_VP2INTERSECT + kAVX512_VPOPCNTDQ, //!< CPU has AVX512_VPOPCNTDQ (AVX512 VPOPCNT[D|Q] instructions). + kAVX_IFMA, //!< CPU has AVX_IFMA (AVX/VEX encoding of vpmadd52huq/vpmadd52luq). + kAVX_NE_CONVERT, //!< CPU has AVX_NE_CONVERT. + kAVX_VNNI, //!< CPU has AVX_VNNI (AVX/VEX encoding of vpdpbusd/vpdpbusds/vpdpwssd/vpdpwssds). + kAVX_VNNI_INT16, //!< CPU has AVX_VNNI_INT16. + kAVX_VNNI_INT8, //!< CPU has AVX_VNNI_INT8. + kBMI, //!< CPU has BMI (bit manipulation instructions #1). + kBMI2, //!< CPU has BMI2 (bit manipulation instructions #2). + kCET_IBT, //!< CPU has CET-IBT (indirect branch tracking). + kCET_SS, //!< CPU has CET-SS. + kCET_SSS, //!< CPU has CET-SSS. + kCLDEMOTE, //!< CPU has CLDEMOTE (cache line demote). + kCLFLUSH, //!< CPU has CLFUSH (cache Line flush). 
+ kCLFLUSHOPT, //!< CPU has CLFUSHOPT (cache Line flush - optimized). + kCLWB, //!< CPU has CLWB. + kCLZERO, //!< CPU has CLZERO. + kCMOV, //!< CPU has CMOV (CMOV and FCMOV instructions). + kCMPCCXADD, //!< CPU has CMPCCXADD. + kCMPXCHG16B, //!< CPU has CMPXCHG16B (compare-exchange 16 bytes) {X86_64}. + kCMPXCHG8B, //!< CPU has CMPXCHG8B (compare-exchange 8 bytes). + kENCLV, //!< CPU has ENCLV. + kENQCMD, //!< CPU has ENQCMD (enqueue stores). + kERMS, //!< CPU has ERMS (enhanced REP MOVSB/STOSB). + kF16C, //!< CPU has F16C (AVX FP16 conversion instructions). + kFMA, //!< CPU has FMA (AVX fused-multiply-add - 3 operand form). + kFMA4, //!< CPU has FMA4 (AVX fused-multiply-add - 4 operand form) (deprecated). + kFPU, //!< CPU has FPU (FPU support). + kFSGSBASE, //!< CPU has FSGSBASE. + kFSRM, //!< CPU has FSRM (fast short REP MOVSB). + kFSRC, //!< CPU has FSRC (fast short REP CMPSB|SCASB). + kFSRS, //!< CPU has FSRS (fast short REP STOSB) + kFXSR, //!< CPU has FXSR (FXSAVE/FXRSTOR instructions). + kFXSROPT, //!< CPU has FXSROTP (FXSAVE/FXRSTOR is optimized). + kFZRM, //!< CPU has FZRM (fast zero-length REP MOVSB). + kGEODE, //!< CPU has GEODE extensions (GEODE 3DNOW additions) (deprecated). + kGFNI, //!< CPU has GFNI (galois field instructions). + kHLE, //!< CPU has HLE. + kHRESET, //!< CPU has HRESET. + kI486, //!< CPU has I486 features (I486+ support). + kINVLPGB, //!< CPU has INVLPGB. + kLAHFSAHF, //!< CPU has LAHF/SAHF (LAHF/SAHF in 64-bit mode) {X86_64}. + kLAM, //!< CPU has LAM (linear address masking) {X86_64}. + kLWP, //!< CPU has LWP (lightweight profiling) {AMD}. + kLZCNT, //!< CPU has LZCNT (LZCNT instruction). + kMCOMMIT, //!< CPU has MCOMMIT (MCOMMIT instruction). + kMMX, //!< CPU has MMX (MMX base instructions) (deprecated). + kMMX2, //!< CPU has MMX2 (MMX2 extensions or initial SSE extensions) (deprecated). + kMONITOR, //!< CPU has MONITOR (MONITOR/MWAIT instructions). + kMONITORX, //!< CPU has MONITORX (MONITORX/MWAITX instructions). 
+ kMOVBE, //!< CPU has MOVBE (move with byte-order swap). + kMOVDIR64B, //!< CPU has MOVDIR64B (move 64 bytes as direct store). + kMOVDIRI, //!< CPU has MOVDIRI (move dword/qword as direct store). + kMPX, //!< CPU has MPX (memory protection extensions). + kMSR, //!< CPU has MSR (RDMSR/WRMSR instructions). + kMSRLIST, //!< CPU has MSRLIST. + kMSSE, //!< CPU has MSSE (misaligned SSE support). + kOSXSAVE, //!< CPU has OSXSAVE (XSAVE enabled by OS). + kOSPKE, //!< CPU has OSPKE (PKE enabled by OS). + kPCLMULQDQ, //!< CPU has PCLMULQDQ (packed carry-less multiplication). + kPCONFIG, //!< CPU has PCONFIG (PCONFIG instruction). + kPOPCNT, //!< CPU has POPCNT (POPCNT instruction). + kPREFETCHI, //!< CPU has PREFETCHI. + kPREFETCHW, //!< CPU has PREFETCHW. + kPREFETCHWT1, //!< CPU has PREFETCHWT1. + kPTWRITE, //!< CPU has PTWRITE. + kRAO_INT, //!< CPU has RAO_INT (AADD, AAND, AOR, AXOR instructions). + kRMPQUERY, //!< CPU has RMPQUERY (RMPQUERY instruction). + kRDPID, //!< CPU has RDPID (RDPID instruction). + kRDPRU, //!< CPU has RDPRU (RDPRU instruction). + kRDRAND, //!< CPU has RDRAND (RDRAND instruction). + kRDSEED, //!< CPU has RDSEED (RDSEED instruction). + kRDTSC, //!< CPU has RDTSC. + kRDTSCP, //!< CPU has RDTSCP. + kRTM, //!< CPU has RTM. + kSEAM, //!< CPU has SEAM. + kSERIALIZE, //!< CPU has SERIALIZE. + kSEV, //!< CPU has SEV (secure encrypted virtualization). + kSEV_ES, //!< CPU has SEV_ES (SEV encrypted state). + kSEV_SNP, //!< CPU has SEV_SNP (SEV secure nested paging). + kSHA, //!< CPU has SHA (SHA-1 and SHA-256 instructions). + kSHA512, //!< CPU has SHA512 (SHA-512 instructions). + kSKINIT, //!< CPU has SKINIT (SKINIT/STGI instructions) {AMD}. + kSM3, //!< CPU has SM3 (SM3 hash extensions). + kSM4, //!< CPU has SM4 (SM4 cipher extensions). + kSMAP, //!< CPU has SMAP (supervisor-mode access prevention). + kSME , //!< CPU has SME (secure memory encryption). + kSMEP, //!< CPU has SMEP (supervisor-mode execution prevention). 
+ kSMX, //!< CPU has SMX (safer mode extensions). + kSSE, //!< CPU has SSE (SSE instructions). + kSSE2, //!< CPU has SSE2 (SSE2 instructions). + kSSE3, //!< CPU has SSE3 (SSE3 instructions). + kSSE4_1, //!< CPU has SSE4.1 (SSE4.1 instructions). + kSSE4_2, //!< CPU has SSE4.2 (SSE4.2 instructions). + kSSE4A, //!< CPU has SSE4A (SSE4.A instructions) {AMD} (deprecated). + kSSSE3, //!< CPU has SSSE3 (SSSE3 instructions). + kSVM, //!< CPU has SVM (virtualization) {AMD}. + kTBM, //!< CPU has TBM (trailing bit manipulation) {AMD}. + kTSE, //!< CPU has TSE. + kTSX, //!< CPU has TSX. + kTSXLDTRK, //!< CPU has TSXLDTRK. + kUINTR, //!< CPU has UINTR (user interrupts). + kVAES, //!< CPU has VAES (vector AES 256|512 bit support). + kVMX, //!< CPU has VMX (virtualization) {INTEL}. + kVPCLMULQDQ, //!< CPU has VPCLMULQDQ (vector PCLMULQDQ 256|512-bit support). + kWAITPKG, //!< CPU has WAITPKG (UMONITOR, UMWAIT, TPAUSE). + kWBNOINVD, //!< CPU has WBNOINVD. + kWRMSRNS, //!< CPU has WRMSRNS. + kXOP, //!< CPU has XOP (XOP instructions) {AMD} (deprecated). + kXSAVE, //!< CPU has XSAVE. + kXSAVEC, //!< CPU has XSAVEC. + kXSAVEOPT, //!< CPU has XSAVEOPT. + kXSAVES, //!< CPU has XSAVES. + // @EnumValuesEnd@ + +#ifndef ASMJIT_NO_DEPRECATED + kAVX512_CDI = kAVX512_CD, + kAVX512_ERI = kAVX512_ER, + kAVX512_PFI = kAVX512_PF, +#endif + + kMaxValue = kXSAVES + }; + + #define ASMJIT_X86_FEATURE(FEATURE) \ + /*! Tests whether FEATURE is present. 
*/ \ + ASMJIT_INLINE_NODEBUG bool has##FEATURE() const noexcept { return has(X86::k##FEATURE); } + + ASMJIT_X86_FEATURE(MT) + ASMJIT_X86_FEATURE(NX) + ASMJIT_X86_FEATURE(3DNOW) + ASMJIT_X86_FEATURE(3DNOW2) + ASMJIT_X86_FEATURE(ADX) + ASMJIT_X86_FEATURE(AESNI) + ASMJIT_X86_FEATURE(ALTMOVCR8) + ASMJIT_X86_FEATURE(AMX_BF16) + ASMJIT_X86_FEATURE(AMX_COMPLEX) + ASMJIT_X86_FEATURE(AMX_FP16) + ASMJIT_X86_FEATURE(AMX_INT8) + ASMJIT_X86_FEATURE(AMX_TILE) + ASMJIT_X86_FEATURE(APX_F) + ASMJIT_X86_FEATURE(AVX) + ASMJIT_X86_FEATURE(AVX2) + ASMJIT_X86_FEATURE(AVX512_4FMAPS) + ASMJIT_X86_FEATURE(AVX512_4VNNIW) + ASMJIT_X86_FEATURE(AVX512_BF16) + ASMJIT_X86_FEATURE(AVX512_BITALG) + ASMJIT_X86_FEATURE(AVX512_BW) + ASMJIT_X86_FEATURE(AVX512_CD) + ASMJIT_X86_FEATURE(AVX512_DQ) + ASMJIT_X86_FEATURE(AVX512_ER) + ASMJIT_X86_FEATURE(AVX512_F) + ASMJIT_X86_FEATURE(AVX512_FP16) + ASMJIT_X86_FEATURE(AVX512_IFMA) + ASMJIT_X86_FEATURE(AVX512_PF) + ASMJIT_X86_FEATURE(AVX512_VBMI) + ASMJIT_X86_FEATURE(AVX512_VBMI2) + ASMJIT_X86_FEATURE(AVX512_VL) + ASMJIT_X86_FEATURE(AVX512_VNNI) + ASMJIT_X86_FEATURE(AVX512_VP2INTERSECT) + ASMJIT_X86_FEATURE(AVX512_VPOPCNTDQ) + ASMJIT_X86_FEATURE(AVX_IFMA) + ASMJIT_X86_FEATURE(AVX_NE_CONVERT) + ASMJIT_X86_FEATURE(AVX_VNNI) + ASMJIT_X86_FEATURE(AVX_VNNI_INT16) + ASMJIT_X86_FEATURE(AVX_VNNI_INT8) + ASMJIT_X86_FEATURE(BMI) + ASMJIT_X86_FEATURE(BMI2) + ASMJIT_X86_FEATURE(CET_IBT) + ASMJIT_X86_FEATURE(CET_SS) + ASMJIT_X86_FEATURE(CET_SSS) + ASMJIT_X86_FEATURE(CLDEMOTE) + ASMJIT_X86_FEATURE(CLFLUSH) + ASMJIT_X86_FEATURE(CLFLUSHOPT) + ASMJIT_X86_FEATURE(CLWB) + ASMJIT_X86_FEATURE(CLZERO) + ASMJIT_X86_FEATURE(CMOV) + ASMJIT_X86_FEATURE(CMPXCHG16B) + ASMJIT_X86_FEATURE(CMPXCHG8B) + ASMJIT_X86_FEATURE(ENCLV) + ASMJIT_X86_FEATURE(ENQCMD) + ASMJIT_X86_FEATURE(ERMS) + ASMJIT_X86_FEATURE(F16C) + ASMJIT_X86_FEATURE(FMA) + ASMJIT_X86_FEATURE(FMA4) + ASMJIT_X86_FEATURE(FPU) + ASMJIT_X86_FEATURE(FSGSBASE) + ASMJIT_X86_FEATURE(FSRM) + ASMJIT_X86_FEATURE(FSRC) + 
ASMJIT_X86_FEATURE(FSRS) + ASMJIT_X86_FEATURE(FXSR) + ASMJIT_X86_FEATURE(FXSROPT) + ASMJIT_X86_FEATURE(FZRM) + ASMJIT_X86_FEATURE(GEODE) + ASMJIT_X86_FEATURE(GFNI) + ASMJIT_X86_FEATURE(HLE) + ASMJIT_X86_FEATURE(HRESET) + ASMJIT_X86_FEATURE(I486) + ASMJIT_X86_FEATURE(INVLPGB) + ASMJIT_X86_FEATURE(LAHFSAHF) + ASMJIT_X86_FEATURE(LAM) + ASMJIT_X86_FEATURE(LWP) + ASMJIT_X86_FEATURE(LZCNT) + ASMJIT_X86_FEATURE(MCOMMIT) + ASMJIT_X86_FEATURE(MMX) + ASMJIT_X86_FEATURE(MMX2) + ASMJIT_X86_FEATURE(MONITOR) + ASMJIT_X86_FEATURE(MONITORX) + ASMJIT_X86_FEATURE(MOVBE) + ASMJIT_X86_FEATURE(MOVDIR64B) + ASMJIT_X86_FEATURE(MOVDIRI) + ASMJIT_X86_FEATURE(MPX) + ASMJIT_X86_FEATURE(MSR) + ASMJIT_X86_FEATURE(MSRLIST) + ASMJIT_X86_FEATURE(MSSE) + ASMJIT_X86_FEATURE(OSXSAVE) + ASMJIT_X86_FEATURE(OSPKE) + ASMJIT_X86_FEATURE(PCLMULQDQ) + ASMJIT_X86_FEATURE(PCONFIG) + ASMJIT_X86_FEATURE(POPCNT) + ASMJIT_X86_FEATURE(PREFETCHI) + ASMJIT_X86_FEATURE(PREFETCHW) + ASMJIT_X86_FEATURE(PREFETCHWT1) + ASMJIT_X86_FEATURE(PTWRITE) + ASMJIT_X86_FEATURE(RAO_INT) + ASMJIT_X86_FEATURE(RMPQUERY) + ASMJIT_X86_FEATURE(RDPID) + ASMJIT_X86_FEATURE(RDPRU) + ASMJIT_X86_FEATURE(RDRAND) + ASMJIT_X86_FEATURE(RDSEED) + ASMJIT_X86_FEATURE(RDTSC) + ASMJIT_X86_FEATURE(RDTSCP) + ASMJIT_X86_FEATURE(RTM) + ASMJIT_X86_FEATURE(SEAM) + ASMJIT_X86_FEATURE(SERIALIZE) + ASMJIT_X86_FEATURE(SEV) + ASMJIT_X86_FEATURE(SEV_ES) + ASMJIT_X86_FEATURE(SEV_SNP) + ASMJIT_X86_FEATURE(SHA) + ASMJIT_X86_FEATURE(SKINIT) + ASMJIT_X86_FEATURE(SMAP) + ASMJIT_X86_FEATURE(SMEP) + ASMJIT_X86_FEATURE(SMX) + ASMJIT_X86_FEATURE(SSE) + ASMJIT_X86_FEATURE(SSE2) + ASMJIT_X86_FEATURE(SSE3) + ASMJIT_X86_FEATURE(SSE4_1) + ASMJIT_X86_FEATURE(SSE4_2) + ASMJIT_X86_FEATURE(SSE4A) + ASMJIT_X86_FEATURE(SSSE3) + ASMJIT_X86_FEATURE(SVM) + ASMJIT_X86_FEATURE(TBM) + ASMJIT_X86_FEATURE(TSE) + ASMJIT_X86_FEATURE(TSX) + ASMJIT_X86_FEATURE(TSXLDTRK) + ASMJIT_X86_FEATURE(UINTR) + ASMJIT_X86_FEATURE(VAES) + ASMJIT_X86_FEATURE(VMX) + ASMJIT_X86_FEATURE(VPCLMULQDQ) + 
ASMJIT_X86_FEATURE(WAITPKG) + ASMJIT_X86_FEATURE(WBNOINVD) + ASMJIT_X86_FEATURE(WRMSRNS) + ASMJIT_X86_FEATURE(XOP) + ASMJIT_X86_FEATURE(XSAVE) + ASMJIT_X86_FEATURE(XSAVEC) + ASMJIT_X86_FEATURE(XSAVEOPT) + ASMJIT_X86_FEATURE(XSAVES) + +#ifndef ASMJIT_NO_DEPRECATED + ASMJIT_DEPRECATED("Use hasAVX512_CD() instead") + ASMJIT_X86_FEATURE(AVX512_CDI) + + ASMJIT_DEPRECATED("Use hasAVX512_ER() instead") + ASMJIT_X86_FEATURE(AVX512_ERI) + + ASMJIT_DEPRECATED("Use hasAVX512_PF() instead") + ASMJIT_X86_FEATURE(AVX512_PFI) +#endif + + #undef ASMJIT_X86_FEATURE + }; + + //! ARM specific features data. + //! + //! Naming reference: + //! - https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile + struct ARM : public Data { + //! ARM CPU feature identifiers. + enum Id : uint8_t { + // @EnumValuesBegin{"enum": "CpuFeatures::ARM"}@ + kNone = 0, //!< No feature (never set, used internally). + + kARMv6, //!< CPU is at least ARMv6 {A32}. + kARMv7, //!< CPU is at least ARMv7 {A32}. + kARMv8a, //!< CPU is at least ARMv8A. + kTHUMB, //!< CPU has THUMB (16-bit THUMB encoding) {A32}. + kTHUMBv2, //!< CPU has THUMBv2 (32-bit THUMB encoding) {A32}. + + kABLE, //!< CPU has ABLE (address breakpoint linking extension) {A64}. + kADERR, //!< CPU has ADERR (asynchronous device error exceptions) {A64}. + kAES, //!< CPU has AES (ASIMD AES instructions). + kAFP, //!< CPU has AFP (alternate floating-point behavior) {A64}. + kAIE, //!< CPU has AIE (memory attribute index enhancement) {A64}. + kAMU1, //!< CPU has AMUv1 (activity monitors extension version 1) {A64}. + kAMU1_1, //!< CPU has AMUv1p1 (activity monitors extension version 1.1) {A64}. + kANERR, //!< CPU has ANERR (asynchronous normal error exception) {A64}. + kASIMD, //!< CPU has ASIMD (NEON on ARM/THUMB). + kBF16, //!< CPU has BF16 (BFloat16 instructions) {A64}. + kBRBE, //!< CPU has BRBE (branch record buffer extension) {A64}. + kBTI, //!< CPU has BTI (branch target identification). 
+ kBWE, //!< CPU has BWE (breakpoint mismatch and range extension) {A64}. + kCCIDX, //!< CPU has CCIDX (extend of the CCSIDR number of sets). + kCHK, //!< CPU has CHK (check feature status - CHKFEAT instruction) {A64}. + kCLRBHB, //!< CPU has CLRBHB (clear BHB instruction). + kCMOW, //!< CPU has CMOW (control for cache maintenance permission) {A64}. + kCONSTPACFIELD, //!< CPU has CONSTPACFIELD (PAC algorithm enhancement) {A64}. + kCPA, //!< CPU has CPA (instruction-only Checked Pointer Arithmetic) {A64}. + kCPA2, //!< CPU has CPA2 (checked Pointer Arithmetic) {A64}. + kCPUID, //!< CPU has CPUID (CPUID registers accessible in user-space). + kCRC32, //!< CPU has CRC32 (CRC32 instructions). + kCSSC, //!< CPU has CSSC (common short sequence compression) {A64}. + kCSV2, //!< CPU has CSV2 (cache speculation variant 2 version 2.1) {A64}. + kCSV2_3, //!< CPU has CSV2_3 (cache speculation variant 2 version 3) {A64}. + kCSV3, //!< CPU has CSV3 (cache speculation Variant 3) {A64}. + kD128, //!< CPU has D128 (128-bit translation tables, 56 bit PA) {A64}. + kDGH, //!< CPU has DGH (data gathering hint) {A64}. + kDIT, //!< CPU has DIT (data independent timing of instructions). + kDOTPROD, //!< CPU has DOTPROD (ASIMD Int8 dot product instructions). + kDPB, //!< CPU has DPB (DC CVAP instruction) {A64}. + kDPB2, //!< CPU has DPB2 (DC CVADP instruction) {A64}. + kEBEP, //!< CPU has EBEP (exception-based event profiling) {A64}. + kEBF16, //!< CPU has EBF16 (extended BFloat16 mode) {A64}. + kECBHB, //!< CPU has ECBHB (exploitative control using branch history information) {A64}. + kECV, //!< CPU has ECV (enhanced counter virtualization). + kEDHSR, //!< CPU has EDHSR (support for EDHSR) {A64}. + kEDSP, //!< CPU has EDSP (ARM/THUMB only). + kFAMINMAX, //!< CPU has FAMINMAX (floating-point maximum and minimum absolute value instructions) {A64}. + kFCMA, //!< CPU has FCMA (FCADD/FCMLA). + kFGT, //!< CPU has FGT (fine-grained traps). + kFGT2, //!< CPU has FGT2 (fine-grained traps 2). 
+ kFHM, //!< CPU has FHM (half-precision floating-point FMLAL instructions). + kFLAGM, //!< CPU has FLAGM (condition flag manipulation) {A64}. + kFLAGM2, //!< CPU has FLAGM2 (condition flag manipulation version v2) {A64}. + kFMAC, //!< CPU has FMAC (ARM/THUMB only). + kFP, //!< CPU has FP (floating-point) (on 32-bit ARM this means VFPv3). + kFP16, //!< CPU has FP16 (half-precision floating-point data processing). + kFP16CONV, //!< CPU has FP16CONV (half-precision float conversion). + kFP8, //!< CPU has FP8 (FP8 convert instructions) {A64}. + kFP8DOT2, //!< CPU has FP8DOT2 (FP8 2-way dot product to half-precision instructions) {A64}. + kFP8DOT4, //!< CPU has FP8DOT4 (FP8 4-way dot product to single-precision instructions) {A64}. + kFP8FMA, //!< CPU has FP8FMA (FP8 multiply-accumulate to half-precision and single-precision instructions) {A64}. + kFPMR, //!< CPU has FPMR (floating-point Mode Register) {A64}. + kFRINTTS, //!< CPU has FRINTTS (FRINT[32|64][X|Z] instructions) {A64}. + kGCS, //!< CPU has GCS (guarded control stack extension) {A64}. + kHACDBS, //!< CPU has HACDBS (hardware accelerator for cleaning Dirty state) {A64}. + kHAFDBS, //!< CPU has HAFDBS (hardware management of the access flag and dirty state) {A64}. + kHAFT, //!< CPU has HAFT (hardware managed access flag for table descriptors) {A64}. + kHDBSS, //!< CPU has HDBSS (hardware Dirty state tracking Structure) {A64}. + kHBC, //!< CPU has HBC (hinted conditional branches) {A64}. + kHCX, //!< CPU has HCX (support for the HCRX_EL2 register) {A64}. + kHPDS, //!< CPU has HPDS (hierarchical permission disables in translation tables ) {A64}. + kHPDS2, //!< CPU has HPDS2 (hierarchical permission disables) {A64}. + kI8MM, //!< CPU has I8MM (int8 matrix multiplication) {A64}. + kIDIVA, //!< CPU has IDIV (hardware SDIV and UDIV in ARM mode). + kIDIVT, //!< CPU has IDIV (hardware SDIV and UDIV in THUMB mode). + kITE, //!< CPU has ITE (instrumentation extension) {A64}. 
+ kJSCVT, //!< CPU has JSCVT (JavaScript FJCVTS conversion instruction) {A64}. + kLOR, //!< CPU has LOR (limited ordering regions extension). + kLRCPC, //!< CPU has LRCPC (load-acquire RCpc instructions) {A64}. + kLRCPC2, //!< CPU has LRCPC2 (load-acquire RCpc instructions v2) {A64}. + kLRCPC3, //!< CPU has LRCPC3 (load-Acquire RCpc instructions v3) {A64}. + kLS64, //!< CPU has LS64 (64 byte loads/stores without return) {A64}. + kLS64_ACCDATA, //!< CPU has LS64_ACCDATA (64-byte EL0 stores with return) {A64}. + kLS64_V, //!< CPU has LS64_V (64-byte stores with return) {A64}. + kLSE, //!< CPU has LSE (large system extensions) {A64}. + kLSE128, //!< CPU has LSE128 (128-bit atomics) {A64}. + kLSE2, //!< CPU has LSE2 (large system extensions v2) {A64}. + kLUT, //!< CPU has LUT (lookup table instructions with 2-bit and 4-bit indices) {A64}. + kLVA, //!< CPU has LVA (large VA support) {A64}. + kLVA3, //!< CPU has LVA3 (56-bit VA) {A64}. + kMEC, //!< CPU has MEC (memory encryption contexts) {A64}. + kMOPS, //!< CPU has MOPS (memcpy and memset acceleration instructions) {A64}. + kMPAM, //!< CPU has MPAM (memory system partitioning and monitoring extension) {A64}. + kMTE, //!< CPU has MTE (instruction-only memory tagging extension) {A64}. + kMTE2, //!< CPU has MTE2 (full memory tagging extension) {A64}. + kMTE3, //!< CPU has MTE3 (MTE asymmetric fault handling) {A64}. + kMTE4, //!< CPU has MTE4 (MTE v4) {A64}. + kMTE_ASYM_FAULT, //!< CPU has MTE_ASYM_FAULT (memory tagging asymmetric faults) {A64}. + kMTE_ASYNC, //!< CPU has MTE_ASYNC (memory tagging asynchronous faulting) {A64}. + kMTE_CANONICAL_TAGS, //!< CPU has MTE_CANONICAL_TAGS (canonical tag checking for untagged memory) {A64}. + kMTE_NO_ADDRESS_TAGS, //!< CPU has MTE_NO_ADDRESS_TAGS (memory tagging with address tagging disabled) {A64}. + kMTE_PERM_S1, //!< CPU has MTE_PERM_S1 (allocation tag access permission) {A64}. + kMTE_STORE_ONLY, //!< CPU has MTE_STORE_ONLY (store-only tag checking) {A64}. 
+ kMTE_TAGGED_FAR, //!< CPU has MTE_TAGGED_FAR (FAR_ELx on a tag check fault) {A64}. + kMTPMU, //!< CPU has MTPMU (multi-threaded PMU extensions) {A64}. + kNMI, //!< CPU has NMI (non-maskable Interrupt) {A64}. + kNV, //!< CPU has NV (nested virtualization enchancement) {A64}. + kNV2, //!< CPU has NV2 (enhanced support for nested virtualization) {A64}. + kPAN, //!< CPU has PAN (privileged access-never extension) {A64}. + kPAN2, //!< CPU has PAN2 (PAN s1e1R and s1e1W variants) {A64}. + kPAN3, //!< CPU has PAN3 (support for SCTLR_ELx.EPAN) {A64}. + kPAUTH, //!< CPU has PAUTH (pointer authentication extension) {A64}. + kPFAR, //!< CPU has PFAR (physical fault address registers) {A64}. + kPMU, //!< CPU has PMU {A64}. + kPMULL, //!< CPU has PMULL (ASIMD PMULL instructions) {A64}. + kPRFMSLC, //!< CPU has PRFMSLC (PRFM instructions support the SLC target) {A64}. + kRAS, //!< CPU has RAS (reliability, availability and serviceability extensions). + kRAS1_1, //!< CPU has RASv1p1 (RAS v1.1). + kRAS2, //!< CPU has RASv2 (RAS v2). + kRASSA2, //!< CPU has RASSAv2 (RAS v2 system architecture). + kRDM, //!< CPU has RDM (rounding double multiply accumulate) {A64}. + kRME, //!< CPU has RME (memory encryption contexts extension) {A64}. + kRNG, //!< CPU has RNG (random number generation). + kRNG_TRAP, //!< CPU has RNG_TRAP (random number trap to EL3 field) {A64}. + kRPRES, //!< CPU has RPRES (increased precision of reciprocal estimate and RSQRT estimate) {A64}. + kRPRFM, //!< CPU has RPRFM (range prefetch hint instruction). + kS1PIE, //!< CPU has S1PIE (permission model enhancements) {A64}. + kS1POE, //!< CPU has S1POE (permission model enhancements) {A64}. + kS2PIE, //!< CPU has S2PIE (permission model enhancements) {A64}. + kS2POE, //!< CPU has S2POE (permission model enhancements) {A64}. + kSB, //!< CPU has SB (speculative barrier). + kSCTLR2, //!< CPU has SCTLR2 (extension to SCTLR_ELx) {A64}. + kSEBEP, //!< CPU has SEBEP (synchronous exception-based event profiling) {A64}. 
+ kSEL2, //!< CPU has SEL2 (secure EL2) {A64}. + kSHA1, //!< CPU has SHA1 (ASIMD SHA1 instructions). + kSHA256, //!< CPU has SHA256 (ASIMD SHA256 instructions). + kSHA3, //!< CPU has SHA3 (ASIMD EOR3, RAX1, XAR, and BCAX instructions). + kSHA512, //!< CPU has SHA512 (ASIMD SHA512 instructions). + kSM3, //!< CPU has SM3 (ASIMD SM3 instructions). + kSM4, //!< CPU has SM4 (ASIMD SM4 instructions). + kSME, //!< CPU has SME (SME v1 - scalable matrix extension) {A64}. + kSME2, //!< CPU has SME2 (SME v2) {A64}. + kSME2_1, //!< CPU has SME2p1 (SME v2.1) {A64}. + kSME_B16B16, //!< CPU has SME_B16B16 (SME non-widening BFloat16 to BFloat16 arithmetic) {A64}. + kSME_B16F32, //!< CPU has SME_B16F32 (BFMOPA and BFMOPS instructions that accumulate BFloat16 outer products into single-precision tiles) {A64}. + kSME_BI32I32, //!< CPU has SME_BI32I32 (BMOPA and BMOPS instructions that accumulate 1-bit binary outer products into 32-bit integer tiles) {A64}. + kSME_F16F16, //!< CPU has SME_F16F16 (SME2.1 non-widening half-precision FP16 to FP16 arithmetic) {A64}. + kSME_F16F32, //!< CPU has SME_F16F32 {A64}. + kSME_F32F32, //!< CPU has SME_F32F32 {A64}. + kSME_F64F64, //!< CPU has SME_F64F64 {A64}. + kSME_F8F16, //!< CPU has SME_F8F16 (SME2 ZA-targeting FP8 multiply-accumulate, dot product, and outer product to half-precision instructions) {A64}. + kSME_F8F32, //!< CPU has SME_F8F32 (SME2 ZA-targeting FP8 multiply-accumulate, dot product, and outer product to single-precision instructions) {A64}. + kSME_FA64, //!< CPU has SME_FA64 {A64}. + kSME_I16I32, //!< CPU has SME_I16I32 {A64}. + kSME_I16I64, //!< CPU has SME_I16I64 {A64}. + kSME_I8I32, //!< CPU has SME_I8I32 {A64}. + kSME_LUTv2, //!< CPU has SME_LUTv2 (lookup table instructions with 4-bit indices and 8-bit elements) {A64}. + kSPE, //!< CPU has SPE (statistical profiling extension) {A64}. + kSPE1_1, //!< CPU has SPEv1p1 (statistical profiling extensions version 1.1) {A64}. 
+ kSPE1_2, //!< CPU has SPEv1p2 (statistical profiling extensions version 1.2) {A64}. + kSPE1_3, //!< CPU has SPEv1p3 (statistical profiling extensions version 1.3) {A64}. + kSPE1_4, //!< CPU has SPEv1p4 (statistical profiling extensions version 1.4) {A64}. + kSPE_ALTCLK, //!< CPU has SPE_ALTCLK (statistical profiling alternate clock domain extension) {A64}. + kSPE_CRR, //!< CPU has SPE_CRR (statistical profiling call return branch records) {A64}. + kSPE_EFT, //!< CPU has SPE_EFT (statistical profiling extended filtering by type) {A64}. + kSPE_FDS, //!< CPU has SPE_FDS (statistical profiling data source filtering) {A64}. + kSPE_FPF, //!< CPU has SPE_FPF (statistical profiling floating-point flag extension) {A64}. + kSPE_SME, //!< CPU has SPE_SME (statistical profiling extensions for SME) {A64}. + kSPECRES, //!< CPU has SPECRES (speculation restriction instructions). + kSPECRES2, //!< CPU has SPECRES2 (clear other speculative predictions). + kSPMU, //!< CPU has SPMU (system performance monitors extension) {A64}. + kSSBS, //!< CPU has SSBS (speculative store bypass safe instruction). + kSSBS2, //!< CPU has SSBS2 (MRS and MSR instructions for SSBS). + kSSVE_FP8DOT2, //!< CPU has SSVE_FP8DOT2 (SVE2 FP8 2-way dot product to half-precision instructions in Streaming SVE mode) {A64}. + kSSVE_FP8DOT4, //!< CPU has SSVE_FP8DOT4 (SVE2 FP8 4-way dot product to single-precision instructions in Streaming SVE mode) {A64}. + kSSVE_FP8FMA, //!< CPU has SSVE_FP8FMA (SVE2 FP8 multiply-accumulate to half-precision and single-precision instructions in Streaming SVE mode) {A64}. + kSVE, //!< CPU has SVE (SVE v1 - scalable vector extension) {A64}. + kSVE2, //!< CPU has SVE2 (SVE v2) {A64}. + kSVE2_1, //!< CPU has SVE2p1 (SVE v2.1) {A64}. + kSVE_AES, //!< CPU has SVE_AES (SVE AES instructions) {A64}. + kSVE_B16B16, //!< CPU has SVE_B16B16 (SVE non-widening BFloat16 to BFloat16 arithmetic) {A64}. + kSVE_BF16, //!< CPU has SVE_BF16 (SVE BF16 instructions) {A64}. 
+ kSVE_BITPERM, //!< CPU has SVE_BITPERM (SVE bit permute) {A64}. + kSVE_EBF16, //!< CPU has SVE_EBF16 (SVE extended BFloat16 mode) {A64}. + kSVE_F32MM, //!< CPU has SVE_F32MM (SVE single-precision floating-point matrix multiply instruction) {A64}. + kSVE_F64MM, //!< CPU has SVE_F64MM (SVE double-precision floating-point matrix multiply instruction) {A64}. + kSVE_I8MM, //!< CPU has SVE_I8MM (SVE int8 matrix multiplication) {A64}. + kSVE_PMULL128, //!< CPU has SVE_PMULL128 (SVE PMULL instructions) {A64}. + kSVE_SHA3, //!< CPU has SVE_SHA3 (SVE SHA-3 instructions) {A64}. + kSVE_SM4, //!< CPU has SVE_SM4 (SVE SM4 instructions {A64}. + kSYSINSTR128, //!< CPU has SYSINSTR128 (128-bit system instructions) {A64}. + kSYSREG128, //!< CPU has SYSREG128 (128-bit system registers) {A64}. + kTHE, //!< CPU has THE (translation hardening extension). + kTLBIOS, //!< CPU has TLBIOS (TLBI instructions in Outer Shareable domain) {A64}. + kTLBIRANGE, //!< CPU has TLBIRANGE (TLBI range instructions) {A64}. + kTLBIW, //!< CPU has TLBIW (TLBI VMALL for dirty state) {A64}. + kTME, //!< CPU has TME (transactional memory extensions). + kTRF, //!< CPU has TRF (self-hosted trace extensions). + kUAO, //!< CPU has UAO (AArch64 v8.2 UAO PState) {A64}. + kVFP_D32, //!< CPU has VFP_D32 (32 VFP-D registers) (ARM/THUMB only). + kVHE, //!< CPU has VHE (virtual host extension). + kVMID16, //!< CPU has VMID16 (16-bit VMID) {A64}. + kWFXT, //!< CPU has WFxT (WFE and WFI instructions with timeout) {A64}. + kXNX, //!< CPU has XNX (translation table stage 2 unprivileged execute-never) {A64}. + kXS, //!< CPU has XS (XS attribute in TLBI and DSB instructions) {A64}. + // @EnumValuesEnd@ + + kMaxValue = kXS + }; + + #define ASMJIT_ARM_FEATURE(FEATURE) \ + /*! Tests whether FEATURE is present. 
*/ \ + ASMJIT_INLINE_NODEBUG bool has##FEATURE() const noexcept { return has(ARM::k##FEATURE); } + + ASMJIT_ARM_FEATURE(THUMB) + ASMJIT_ARM_FEATURE(THUMBv2) + + ASMJIT_ARM_FEATURE(ARMv6) + ASMJIT_ARM_FEATURE(ARMv7) + ASMJIT_ARM_FEATURE(ARMv8a) + + ASMJIT_ARM_FEATURE(ABLE) + ASMJIT_ARM_FEATURE(ADERR) + ASMJIT_ARM_FEATURE(AES) + ASMJIT_ARM_FEATURE(AFP) + ASMJIT_ARM_FEATURE(AIE) + ASMJIT_ARM_FEATURE(AMU1) + ASMJIT_ARM_FEATURE(AMU1_1) + ASMJIT_ARM_FEATURE(ANERR) + ASMJIT_ARM_FEATURE(ASIMD) + ASMJIT_ARM_FEATURE(BF16) + ASMJIT_ARM_FEATURE(BRBE) + ASMJIT_ARM_FEATURE(BTI) + ASMJIT_ARM_FEATURE(BWE) + ASMJIT_ARM_FEATURE(CCIDX) + ASMJIT_ARM_FEATURE(CHK) + ASMJIT_ARM_FEATURE(CLRBHB) + ASMJIT_ARM_FEATURE(CMOW) + ASMJIT_ARM_FEATURE(CONSTPACFIELD) + ASMJIT_ARM_FEATURE(CPA) + ASMJIT_ARM_FEATURE(CPA2) + ASMJIT_ARM_FEATURE(CPUID) + ASMJIT_ARM_FEATURE(CRC32) + ASMJIT_ARM_FEATURE(CSSC) + ASMJIT_ARM_FEATURE(CSV2) + ASMJIT_ARM_FEATURE(CSV2_3) + ASMJIT_ARM_FEATURE(CSV3) + ASMJIT_ARM_FEATURE(D128) + ASMJIT_ARM_FEATURE(DGH) + ASMJIT_ARM_FEATURE(DIT) + ASMJIT_ARM_FEATURE(DOTPROD) + ASMJIT_ARM_FEATURE(DPB) + ASMJIT_ARM_FEATURE(DPB2) + ASMJIT_ARM_FEATURE(EBEP) + ASMJIT_ARM_FEATURE(EBF16) + ASMJIT_ARM_FEATURE(ECBHB) + ASMJIT_ARM_FEATURE(ECV) + ASMJIT_ARM_FEATURE(EDHSR) + ASMJIT_ARM_FEATURE(EDSP) + ASMJIT_ARM_FEATURE(FAMINMAX) + ASMJIT_ARM_FEATURE(FCMA) + ASMJIT_ARM_FEATURE(FGT) + ASMJIT_ARM_FEATURE(FGT2) + ASMJIT_ARM_FEATURE(FHM) + ASMJIT_ARM_FEATURE(FLAGM) + ASMJIT_ARM_FEATURE(FLAGM2) + ASMJIT_ARM_FEATURE(FMAC) + ASMJIT_ARM_FEATURE(FP) + ASMJIT_ARM_FEATURE(FP16) + ASMJIT_ARM_FEATURE(FP16CONV) + ASMJIT_ARM_FEATURE(FP8) + ASMJIT_ARM_FEATURE(FP8DOT2) + ASMJIT_ARM_FEATURE(FP8DOT4) + ASMJIT_ARM_FEATURE(FP8FMA) + ASMJIT_ARM_FEATURE(FPMR) + ASMJIT_ARM_FEATURE(FRINTTS) + ASMJIT_ARM_FEATURE(GCS) + ASMJIT_ARM_FEATURE(HACDBS) + ASMJIT_ARM_FEATURE(HAFDBS) + ASMJIT_ARM_FEATURE(HAFT) + ASMJIT_ARM_FEATURE(HDBSS) + ASMJIT_ARM_FEATURE(HBC) + ASMJIT_ARM_FEATURE(HCX) + ASMJIT_ARM_FEATURE(HPDS) + 
ASMJIT_ARM_FEATURE(HPDS2) + ASMJIT_ARM_FEATURE(I8MM) + ASMJIT_ARM_FEATURE(IDIVA) + ASMJIT_ARM_FEATURE(IDIVT) + ASMJIT_ARM_FEATURE(ITE) + ASMJIT_ARM_FEATURE(JSCVT) + ASMJIT_ARM_FEATURE(LOR) + ASMJIT_ARM_FEATURE(LRCPC) + ASMJIT_ARM_FEATURE(LRCPC2) + ASMJIT_ARM_FEATURE(LRCPC3) + ASMJIT_ARM_FEATURE(LS64) + ASMJIT_ARM_FEATURE(LS64_ACCDATA) + ASMJIT_ARM_FEATURE(LS64_V) + ASMJIT_ARM_FEATURE(LSE) + ASMJIT_ARM_FEATURE(LSE128) + ASMJIT_ARM_FEATURE(LSE2) + ASMJIT_ARM_FEATURE(LUT) + ASMJIT_ARM_FEATURE(LVA) + ASMJIT_ARM_FEATURE(LVA3) + ASMJIT_ARM_FEATURE(MEC) + ASMJIT_ARM_FEATURE(MOPS) + ASMJIT_ARM_FEATURE(MPAM) + ASMJIT_ARM_FEATURE(MTE) + ASMJIT_ARM_FEATURE(MTE2) + ASMJIT_ARM_FEATURE(MTE3) + ASMJIT_ARM_FEATURE(MTE4) + ASMJIT_ARM_FEATURE(MTE_ASYM_FAULT) + ASMJIT_ARM_FEATURE(MTE_ASYNC) + ASMJIT_ARM_FEATURE(MTE_CANONICAL_TAGS) + ASMJIT_ARM_FEATURE(MTE_NO_ADDRESS_TAGS) + ASMJIT_ARM_FEATURE(MTE_PERM_S1) + ASMJIT_ARM_FEATURE(MTE_STORE_ONLY) + ASMJIT_ARM_FEATURE(MTE_TAGGED_FAR) + ASMJIT_ARM_FEATURE(MTPMU) + ASMJIT_ARM_FEATURE(NMI) + ASMJIT_ARM_FEATURE(NV) + ASMJIT_ARM_FEATURE(NV2) + ASMJIT_ARM_FEATURE(PAN) + ASMJIT_ARM_FEATURE(PAN2) + ASMJIT_ARM_FEATURE(PAN3) + ASMJIT_ARM_FEATURE(PAUTH) + ASMJIT_ARM_FEATURE(PFAR) + ASMJIT_ARM_FEATURE(PMU) + ASMJIT_ARM_FEATURE(PMULL) + ASMJIT_ARM_FEATURE(PRFMSLC) + ASMJIT_ARM_FEATURE(RAS) + ASMJIT_ARM_FEATURE(RAS1_1) + ASMJIT_ARM_FEATURE(RAS2) + ASMJIT_ARM_FEATURE(RASSA2) + ASMJIT_ARM_FEATURE(RDM) + ASMJIT_ARM_FEATURE(RME) + ASMJIT_ARM_FEATURE(RNG) + ASMJIT_ARM_FEATURE(RNG_TRAP) + ASMJIT_ARM_FEATURE(RPRES) + ASMJIT_ARM_FEATURE(RPRFM) + ASMJIT_ARM_FEATURE(S1PIE) + ASMJIT_ARM_FEATURE(S1POE) + ASMJIT_ARM_FEATURE(S2PIE) + ASMJIT_ARM_FEATURE(S2POE) + ASMJIT_ARM_FEATURE(SB) + ASMJIT_ARM_FEATURE(SCTLR2) + ASMJIT_ARM_FEATURE(SEBEP) + ASMJIT_ARM_FEATURE(SEL2) + ASMJIT_ARM_FEATURE(SHA1) + ASMJIT_ARM_FEATURE(SHA256) + ASMJIT_ARM_FEATURE(SHA3) + ASMJIT_ARM_FEATURE(SHA512) + ASMJIT_ARM_FEATURE(SM3) + ASMJIT_ARM_FEATURE(SM4) + ASMJIT_ARM_FEATURE(SME) + 
ASMJIT_ARM_FEATURE(SME2) + ASMJIT_ARM_FEATURE(SME2_1) + ASMJIT_ARM_FEATURE(SME_B16B16) + ASMJIT_ARM_FEATURE(SME_B16F32) + ASMJIT_ARM_FEATURE(SME_BI32I32) + ASMJIT_ARM_FEATURE(SME_F16F16) + ASMJIT_ARM_FEATURE(SME_F16F32) + ASMJIT_ARM_FEATURE(SME_F32F32) + ASMJIT_ARM_FEATURE(SME_F64F64) + ASMJIT_ARM_FEATURE(SME_F8F16) + ASMJIT_ARM_FEATURE(SME_F8F32) + ASMJIT_ARM_FEATURE(SME_FA64) + ASMJIT_ARM_FEATURE(SME_I16I32) + ASMJIT_ARM_FEATURE(SME_I16I64) + ASMJIT_ARM_FEATURE(SME_I8I32) + ASMJIT_ARM_FEATURE(SME_LUTv2) + ASMJIT_ARM_FEATURE(SPE) + ASMJIT_ARM_FEATURE(SPE1_1) + ASMJIT_ARM_FEATURE(SPE1_2) + ASMJIT_ARM_FEATURE(SPE1_3) + ASMJIT_ARM_FEATURE(SPE1_4) + ASMJIT_ARM_FEATURE(SPE_ALTCLK) + ASMJIT_ARM_FEATURE(SPE_CRR) + ASMJIT_ARM_FEATURE(SPE_EFT) + ASMJIT_ARM_FEATURE(SPE_FDS) + ASMJIT_ARM_FEATURE(SPE_FPF) + ASMJIT_ARM_FEATURE(SPE_SME) + ASMJIT_ARM_FEATURE(SPECRES) + ASMJIT_ARM_FEATURE(SPECRES2) + ASMJIT_ARM_FEATURE(SPMU) + ASMJIT_ARM_FEATURE(SSBS) + ASMJIT_ARM_FEATURE(SSBS2) + ASMJIT_ARM_FEATURE(SSVE_FP8DOT2) + ASMJIT_ARM_FEATURE(SSVE_FP8DOT4) + ASMJIT_ARM_FEATURE(SSVE_FP8FMA) + ASMJIT_ARM_FEATURE(SVE) + ASMJIT_ARM_FEATURE(SVE2) + ASMJIT_ARM_FEATURE(SVE2_1) + ASMJIT_ARM_FEATURE(SVE_AES) + ASMJIT_ARM_FEATURE(SVE_B16B16) + ASMJIT_ARM_FEATURE(SVE_BF16) + ASMJIT_ARM_FEATURE(SVE_BITPERM) + ASMJIT_ARM_FEATURE(SVE_EBF16) + ASMJIT_ARM_FEATURE(SVE_F32MM) + ASMJIT_ARM_FEATURE(SVE_F64MM) + ASMJIT_ARM_FEATURE(SVE_I8MM) + ASMJIT_ARM_FEATURE(SVE_PMULL128) + ASMJIT_ARM_FEATURE(SVE_SHA3) + ASMJIT_ARM_FEATURE(SVE_SM4) + ASMJIT_ARM_FEATURE(SYSINSTR128) + ASMJIT_ARM_FEATURE(SYSREG128) + ASMJIT_ARM_FEATURE(THE) + ASMJIT_ARM_FEATURE(TLBIOS) + ASMJIT_ARM_FEATURE(TLBIRANGE) + ASMJIT_ARM_FEATURE(TLBIW) + ASMJIT_ARM_FEATURE(TME) + ASMJIT_ARM_FEATURE(TRF) + ASMJIT_ARM_FEATURE(UAO) + ASMJIT_ARM_FEATURE(VFP_D32) + ASMJIT_ARM_FEATURE(VHE) + ASMJIT_ARM_FEATURE(VMID16) + ASMJIT_ARM_FEATURE(WFXT) + ASMJIT_ARM_FEATURE(XNX) + ASMJIT_ARM_FEATURE(XS) + + #undef ASMJIT_ARM_FEATURE + }; + + 
static_assert(uint32_t(X86::kMaxValue) < kMaxFeatures, "The number of X86 CPU features cannot exceed CpuFeatures::kMaxFeatures"); + static_assert(uint32_t(ARM::kMaxValue) < kMaxFeatures, "The number of ARM CPU features cannot exceed CpuFeatures::kMaxFeatures"); + + //! \} + + //! \name Members + //! \{ + + Data _data {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG CpuFeatures() noexcept {} + ASMJIT_INLINE_NODEBUG CpuFeatures(const CpuFeatures& other) noexcept = default; + ASMJIT_INLINE_NODEBUG explicit CpuFeatures(const Data& other) noexcept : _data{other._bits} {} + ASMJIT_INLINE_NODEBUG explicit CpuFeatures(Globals::NoInit_) noexcept {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG CpuFeatures& operator=(const CpuFeatures& other) noexcept = default; + + ASMJIT_INLINE_NODEBUG bool operator==(const CpuFeatures& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const CpuFeatures& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns true if there are no features set. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _data.empty(); } + + //! Casts this base class into a derived type `T`. + template + ASMJIT_INLINE_NODEBUG T& data() noexcept { return static_cast(_data); } + + //! Casts this base class into a derived type `T` (const). + template + ASMJIT_INLINE_NODEBUG const T& data() const noexcept { return static_cast(_data); } + + //! Returns CpuFeatures::Data as \ref CpuFeatures::X86. + ASMJIT_INLINE_NODEBUG X86& x86() noexcept { return data(); } + //! Returns CpuFeatures::Data as \ref CpuFeatures::X86 (const). + ASMJIT_INLINE_NODEBUG const X86& x86() const noexcept { return data(); } + + //! Returns CpuFeatures::Data as \ref CpuFeatures::ARM. + ASMJIT_INLINE_NODEBUG ARM& arm() noexcept { return data(); } + //! Returns CpuFeatures::Data as \ref CpuFeatures::ARM (const). 
+ ASMJIT_INLINE_NODEBUG const ARM& arm() const noexcept { return data(); } + + //! Returns all features as array of bitwords (see \ref Support::BitWord). + ASMJIT_INLINE_NODEBUG BitWord* bits() noexcept { return _data.bits(); } + //! Returns all features as array of bitwords (const). + ASMJIT_INLINE_NODEBUG const BitWord* bits() const noexcept { return _data.bits(); } + //! Returns the number of BitWords returned by \ref bits(). + ASMJIT_INLINE_NODEBUG size_t bitWordCount() const noexcept { return _data.bitWordCount(); } + + //! Returns \ref Support::BitVectorIterator, that can be used to iterate over all features efficiently. + ASMJIT_INLINE_NODEBUG Iterator iterator() const noexcept { return _data.iterator(); } + + //! Tests whether the feature `featureId` is present. + template + ASMJIT_INLINE_NODEBUG bool has(const FeatureId& featureId) const noexcept { return _data.has(featureId); } + + //! Tests whether any of the features is present. + template + ASMJIT_INLINE_NODEBUG bool hasAny(Args&&... args) const noexcept { return _data.hasAny(std::forward(args)...); } + + //! Tests whether all features as defined by `other` are present. + ASMJIT_INLINE_NODEBUG bool hasAll(const CpuFeatures& other) const noexcept { return _data.hasAll(other._data); } + + //! \} + + //! \name Manipulation + //! \{ + + //! Clears all features set. + ASMJIT_INLINE_NODEBUG void reset() noexcept { _data.reset(); } + + //! Adds the given CPU `featureId` to the list of features. + template + ASMJIT_INLINE_NODEBUG void add(Args&&... args) noexcept { return _data.add(std::forward(args)...); } + + //! Adds the given CPU `featureId` to the list of features if `condition` is true. + template + ASMJIT_INLINE_NODEBUG void addIf(bool condition, Args&&... args) noexcept { return _data.addIf(condition, std::forward(args)...); } + + //! Removes the given CPU `featureId` from the list of features. + template + ASMJIT_INLINE_NODEBUG void remove(Args&&... 
args) noexcept { return _data.remove(std::forward(args)...); } + + //! Tests whether this CPU features matches `other`. + ASMJIT_INLINE_NODEBUG bool equals(const CpuFeatures& other) const noexcept { return _data.equals(other._data); } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use CpuFeatures::equals() instead") + ASMJIT_INLINE_NODEBUG bool eq(const CpuFeatures& other) const noexcept { return equals(other); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} +}; + +//! CPU information. +class CpuInfo { +public: + //! \name Members + //! \{ + + //! Architecture. + Arch _arch {}; + //! Sub-architecture. + SubArch _subArch {}; + //! True if the CPU was detected, false if the detection failed or it's not available. + bool _wasDetected {}; + //! Reserved for future use. + uint8_t _reserved {}; + //! CPU family ID. + uint32_t _familyId {}; + //! CPU model ID. + uint32_t _modelId {}; + //! CPU brand ID. + uint32_t _brandId {}; + //! CPU stepping. + uint32_t _stepping {}; + //! Processor type. + uint32_t _processorType {}; + //! Maximum number of addressable IDs for logical processors. + uint32_t _maxLogicalProcessors {}; + //! Cache line size (in bytes). + uint32_t _cacheLineSize {}; + //! Number of hardware threads. + uint32_t _hwThreadCount {}; + + //! CPU vendor string. + FixedString<16> _vendor {}; + //! CPU brand string. + FixedString<64> _brand {}; + //! CPU features. + CpuFeatures _features {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new CpuInfo instance. + ASMJIT_INLINE_NODEBUG CpuInfo() noexcept {} + //! Creates a copy of `other` instance. + ASMJIT_INLINE_NODEBUG CpuInfo(const CpuInfo& other) noexcept = default; + + //! Creates an unitialized `CpuInfo` instance. + ASMJIT_INLINE_NODEBUG explicit CpuInfo(Globals::NoInit_) noexcept + : _features(Globals::NoInit) {}; + + //! \} + + //! \name CPU Information Detection + //! \{ + + //! Returns the host CPU information. + //! + //! 
\note The returned reference is global - it's setup only once and then shared. + ASMJIT_API static const CpuInfo& host() noexcept; + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Copy assignment. + ASMJIT_INLINE_NODEBUG CpuInfo& operator=(const CpuInfo& other) noexcept = default; + + //! \} + + //! \name Initialization & Reset + //! \{ + + //! Initializes CpuInfo architecture and sub-architecture members to `arch` and `subArch`, respectively. + ASMJIT_INLINE_NODEBUG void initArch(Arch arch, SubArch subArch = SubArch::kUnknown) noexcept { + _arch = arch; + _subArch = subArch; + } + + //! Resets this \ref CpuInfo to a default constructed state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = CpuInfo{}; } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the CPU architecture this information relates to. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _arch; } + + //! Returns the CPU sub-architecture this information relates to. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return _subArch; } + + //! Returns whether the CPU was detected successfully. + //! + //! If the returned value is false it means that AsmJit either failed to detect the CPU or it doesn't have + //! implementation targeting the host architecture and operating system. + ASMJIT_INLINE_NODEBUG bool wasDetected() const noexcept { return _wasDetected; } + + //! Returns the CPU family ID. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Family identifier matches the FamilyId read by using CPUID. + //! - ARM: + //! - Apple - returns Apple Family identifier returned by sysctlbyname("hw.cpufamily"). + ASMJIT_INLINE_NODEBUG uint32_t familyId() const noexcept { return _familyId; } + + //! Returns the CPU model ID. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Model identifier matches the ModelId read by using CPUID. 
+ ASMJIT_INLINE_NODEBUG uint32_t modelId() const noexcept { return _modelId; } + + //! Returns the CPU brand id. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Brand identifier matches the BrandId read by using CPUID. + ASMJIT_INLINE_NODEBUG uint32_t brandId() const noexcept { return _brandId; } + + //! Returns the CPU stepping. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Stepping identifier matches the Stepping information read by using CPUID. + ASMJIT_INLINE_NODEBUG uint32_t stepping() const noexcept { return _stepping; } + + //! Returns the processor type. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Processor type identifier matches the ProcessorType read by using CPUID. + ASMJIT_INLINE_NODEBUG uint32_t processorType() const noexcept { return _processorType; } + + //! Returns the maximum number of logical processors. + ASMJIT_INLINE_NODEBUG uint32_t maxLogicalProcessors() const noexcept { return _maxLogicalProcessors; } + + //! Returns the size of a CPU cache line. + //! + //! On a multi-architecture system this should return the smallest cache line of all CPUs. + ASMJIT_INLINE_NODEBUG uint32_t cacheLineSize() const noexcept { return _cacheLineSize; } + + //! Returns number of hardware threads available. + ASMJIT_INLINE_NODEBUG uint32_t hwThreadCount() const noexcept { return _hwThreadCount; } + + //! Returns a CPU vendor string. + ASMJIT_INLINE_NODEBUG const char* vendor() const noexcept { return _vendor.str; } + //! Tests whether the CPU vendor string is equal to `s`. + ASMJIT_INLINE_NODEBUG bool isVendor(const char* s) const noexcept { return _vendor.equals(s); } + + //! Returns a CPU brand string. + ASMJIT_INLINE_NODEBUG const char* brand() const noexcept { return _brand.str; } + + //! Returns CPU features. + ASMJIT_INLINE_NODEBUG CpuFeatures& features() noexcept { return _features; } + //! Returns CPU features (const). 
+ ASMJIT_INLINE_NODEBUG const CpuFeatures& features() const noexcept { return _features; } + + //! Tests whether the CPU has the given `feature`. + template + ASMJIT_INLINE_NODEBUG bool hasFeature(const FeatureId& featureId) const noexcept { return _features.has(featureId); } + + //! Adds the given CPU `featureId` to the list of features. + template + ASMJIT_INLINE_NODEBUG void addFeature(Args&&... args) noexcept { return _features.add(std::forward(args)...); } + + //! Removes the given CPU `featureId` from the list of features. + template + ASMJIT_INLINE_NODEBUG void removeFeature(Args&&... args) noexcept { return _features.remove(std::forward(args)...); } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_CPUINFO_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/emitter.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..09c72a68eae77f682b75c38d6fd57c1df017c2e5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/emitter.h @@ -0,0 +1,817 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_EMITTER_H_INCLUDED +#define ASMJIT_CORE_EMITTER_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/codeholder.h" +#include "../core/formatter.h" +#include "../core/inst.h" +#include "../core/operand.h" +#include "../core/type.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +class ConstPool; +class FuncFrame; +class FuncArgsAssignment; + +//! Align mode, used by \ref BaseEmitter::align(). +enum class AlignMode : uint8_t { + //! Align executable code. + kCode = 0, + //! Align non-executable code. + kData = 1, + //! Align by a sequence of zeros. + kZero = 2, + + //! Maximum value of `AlignMode`. + kMaxValue = kZero +}; + +//! 
Emitter type used by \ref BaseEmitter. +enum class EmitterType : uint8_t { + //! Unknown or uninitialized. + kNone = 0, + //! Emitter inherits from \ref BaseAssembler. + kAssembler = 1, + //! Emitter inherits from \ref BaseBuilder. + kBuilder = 2, + //! Emitter inherits from \ref BaseCompiler. + kCompiler = 3, + + //! Maximum value of `EmitterType`. + kMaxValue = kCompiler +}; + +//! Emitter flags, used by \ref BaseEmitter. +enum class EmitterFlags : uint8_t { + //! No flags. + kNone = 0u, + //! Emitter is attached to CodeHolder. + kAttached = 0x01u, + //! The emitter must emit comments. + kLogComments = 0x08u, + //! The emitter has its own \ref Logger (not propagated from \ref CodeHolder). + kOwnLogger = 0x10u, + //! The emitter has its own \ref ErrorHandler (not propagated from \ref CodeHolder). + kOwnErrorHandler = 0x20u, + //! The emitter was finalized. + kFinalized = 0x40u, + //! The emitter was destroyed. + //! + //! This flag is used for a very short time when an emitter is being destroyed by + //! CodeHolder. + kDestroyed = 0x80u +}; +ASMJIT_DEFINE_ENUM_FLAGS(EmitterFlags) + +//! Encoding options. +enum class EncodingOptions : uint32_t { + //! No encoding options. + kNone = 0, + + //! Emit instructions that are optimized for size, if possible. + //! + //! Default: false. + //! + //! X86 Specific + //! ------------ + //! + //! When this option is set it the assembler will try to fix instructions if possible into operation equivalent + //! instructions that take less bytes by taking advantage of implicit zero extension. For example instruction + //! like `mov r64, imm` and `and r64, imm` can be translated to `mov r32, imm` and `and r32, imm` when the + //! immediate constant is lesser than `2^31`. + kOptimizeForSize = 0x00000001u, + + //! Emit optimized code-alignment sequences. + //! + //! Default: false. + //! + //! X86 Specific + //! ------------ + //! + //! 
Default align sequence used by X86 architecture is one-byte (0x90) opcode that is often shown by disassemblers + //! as NOP. However there are more optimized align sequences for 2-11 bytes that may execute faster on certain CPUs. + //! If this feature is enabled AsmJit will generate specialized sequences for alignment between 2 to 11 bytes. + kOptimizedAlign = 0x00000002u, + + //! Emit jump-prediction hints. + //! + //! Default: false. + //! + //! X86 Specific + //! ------------ + //! + //! Jump prediction is usually based on the direction of the jump. If the jump is backward it is usually predicted as + //! taken; and if the jump is forward it is usually predicted as not-taken. The reason is that loops generally use + //! backward jumps and conditions usually use forward jumps. However this behavior can be overridden by using + //! instruction prefixes. If this option is enabled these hints will be emitted. + //! + //! This feature is disabled by default, because the only processor that used to take into consideration prediction + //! hints was P4. Newer processors implement heuristics for branch prediction and ignore static hints. This means + //! that this feature can be only used for annotation purposes. + kPredictedJumps = 0x00000010u +}; +ASMJIT_DEFINE_ENUM_FLAGS(EncodingOptions) + +//! Diagnostic options are used to tell emitters and their passes to perform diagnostics when emitting or processing +//! user code. These options control validation and extra diagnostics that can be performed by higher level emitters. +//! +//! Instruction Validation +//! ---------------------- +//! +//! \ref BaseAssembler implementation perform by default only basic checks that are necessary to identify all +//! variations of an instruction so the correct encoding can be selected. This is fine for production-ready code +//! as the assembler doesn't have to perform checks that would slow it down. However, sometimes these checks are +//! 
beneficial especially when the project that uses AsmJit is in a development phase, in which mistakes happen +//! often. To make the experience of using AsmJit seamless it offers validation features that can be controlled +//! by \ref DiagnosticOptions. +//! +//! Compiler Diagnostics +//! -------------------- +//! +//! Diagnostic options work with \ref BaseCompiler passes (precisely with its register allocation pass). These options +//! can be used to enable logging of all operations that the Compiler does. +enum class DiagnosticOptions : uint32_t { + //! No validation options. + kNone = 0, + + //! Perform strict validation in \ref BaseAssembler::emit() implementations. + //! + //! This flag ensures that each instruction is checked before it's encoded into a binary representation. This flag + //! is only relevant for \ref BaseAssembler implementations, but can be set in any other emitter type, in that case + //! if that emitter needs to create an assembler on its own, for the purpose of \ref BaseEmitter::finalize() it + //! would propagate this flag to such assembler so all instructions passed to it are explicitly validated. + //! + //! Default: false. + kValidateAssembler = 0x00000001u, + + //! Perform strict validation in \ref BaseBuilder::emit() and \ref BaseCompiler::emit() implementations. + //! + //! This flag ensures that each instruction is checked before an \ref InstNode representing the instruction is + //! created by \ref BaseBuilder or \ref BaseCompiler. This option could be more useful than \ref kValidateAssembler + //! in cases in which there is an invalid instruction passed to an assembler, which was invalid much earlier, most + //! likely when such instruction was passed to Builder/Compiler. + //! + //! This is a separate option that was introduced, because it's possible to manipulate the instruction stream + //! emitted by \ref BaseBuilder and \ref BaseCompiler - this means that it's allowed to emit invalid instructions + //! 
(for example with missing operands) that will be fixed later before finalizing it. + //! + //! Default: false. + kValidateIntermediate = 0x00000002u, + + //! Annotate all nodes processed by register allocator (Compiler/RA). + //! + //! \note Annotations don't need debug options, however, some debug options like `kRADebugLiveness` may influence + //! their output (for example the mentioned option would add liveness information to per-instruction annotation). + kRAAnnotate = 0x00000080u, + + //! Debug CFG generation and other related algorithms / operations (Compiler/RA). + kRADebugCFG = 0x00000100u, + + //! Debug liveness analysis (Compiler/RA). + kRADebugLiveness = 0x00000200u, + + //! Debug register allocation assignment (Compiler/RA). + kRADebugAssignment = 0x00000400u, + + //! Debug the removal of code part of unreachable blocks. + kRADebugUnreachable = 0x00000800u, + + //! Enable all debug options (Compiler/RA). + kRADebugAll = 0x0000FF00u, +}; +ASMJIT_DEFINE_ENUM_FLAGS(DiagnosticOptions) + +//! Provides a base foundation to emitting code - specialized by \ref BaseAssembler and \ref BaseBuilder. +class ASMJIT_VIRTAPI BaseEmitter { +public: + ASMJIT_BASE_CLASS(BaseEmitter) + ASMJIT_NONCOPYABLE(BaseEmitter) + + //! \name Members + //! \{ + + //! See \ref EmitterType. + EmitterType _emitterType = EmitterType::kNone; + //! See \ref EmitterFlags. + EmitterFlags _emitterFlags = EmitterFlags::kNone; + //! Instruction alignment. + uint8_t _instructionAlignment = 0u; + //! \cond + uint8_t _reservedBaseEmitter = 0u; + //! \endcond + //! Validation flags in case validation is used. + //! + //! \note Validation flags are specific to the emitter and they are setup at construction time and then never + //! changed. + ValidationFlags _validationFlags = ValidationFlags::kNone; + //! Validation options. + DiagnosticOptions _diagnosticOptions = DiagnosticOptions::kNone; + + //! All supported architectures in a bit-mask, where LSB is the bit with a zero index. 
+ uint64_t _archMask = 0; + + //! Encoding options. + EncodingOptions _encodingOptions = EncodingOptions::kNone; + + //! Forced instruction options, combined with \ref _instOptions by \ref emit(). + InstOptions _forcedInstOptions = InstOptions::kReserved; + //! Internal private data used freely by any emitter. + uint32_t _privateData = 0; + + //! CodeHolder the emitter is attached to. + CodeHolder* _code = nullptr; + //! Attached \ref Logger. + Logger* _logger = nullptr; + //! Attached \ref ErrorHandler. + ErrorHandler* _errorHandler = nullptr; + + //! Describes the target environment, matches \ref CodeHolder::environment(). + Environment _environment {}; + //! Native GP register signature and signature related information. + OperandSignature _gpSignature {}; + + //! Emitter state that can be used to specify options and inline comment of a next node or instruction. + struct State { + InstOptions options; + RegOnly extraReg; + const char* comment; + }; + + //! Next instruction options (affects the next instruction). + InstOptions _instOptions = InstOptions::kNone; + //! Extra register (op-mask {k} on AVX-512) (affects the next instruction). + RegOnly _extraReg {}; + //! Inline comment of the next instruction (affects the next instruction). + const char* _inlineComment = nullptr; + + //! Function callbacks used by emitter implementation. + //! + //! These are typically shared between Assembler/Builder/Compiler of a single backend. 
+ struct Funcs { + typedef Error (ASMJIT_CDECL* EmitProlog)(BaseEmitter* emitter, const FuncFrame& frame); + typedef Error (ASMJIT_CDECL* EmitEpilog)(BaseEmitter* emitter, const FuncFrame& frame); + typedef Error (ASMJIT_CDECL* EmitArgsAssignment)(BaseEmitter* emitter, const FuncFrame& frame, const FuncArgsAssignment& args); + + typedef Error (ASMJIT_CDECL* FormatInstruction)( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + Arch arch, + const BaseInst& inst, const Operand_* operands, size_t opCount) ASMJIT_NOEXCEPT_TYPE; + + typedef Error (ASMJIT_CDECL* ValidateFunc)(const BaseInst& inst, const Operand_* operands, size_t opCount, ValidationFlags validationFlags) ASMJIT_NOEXCEPT_TYPE; + + //! Emit prolog implementation. + EmitProlog emitProlog; + //! Emit epilog implementation. + EmitEpilog emitEpilog; + //! Emit arguments assignment implementation. + EmitArgsAssignment emitArgsAssignment; + //! Instruction formatter implementation. + FormatInstruction formatInstruction; + //! Instruction validation implementation. + ValidateFunc validate; + + //! Resets all functions to nullptr. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + emitProlog = nullptr; + emitEpilog = nullptr; + emitArgsAssignment = nullptr; + validate = nullptr; + } + }; + + Funcs _funcs {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit BaseEmitter(EmitterType emitterType) noexcept; + ASMJIT_API virtual ~BaseEmitter() noexcept; + + //! \} + + //! \name Cast + //! \{ + + template + ASMJIT_INLINE_NODEBUG T* as() noexcept { return reinterpret_cast(this); } + + template + ASMJIT_INLINE_NODEBUG const T* as() const noexcept { return reinterpret_cast(this); } + + //! \} + + //! \name Emitter Type & Flags + //! \{ + + //! Returns the type of this emitter, see `EmitterType`. + ASMJIT_INLINE_NODEBUG EmitterType emitterType() const noexcept { return _emitterType; } + //! Returns emitter flags , see `Flags`. 
+ ASMJIT_INLINE_NODEBUG EmitterFlags emitterFlags() const noexcept { return _emitterFlags; } + + //! Tests whether the emitter inherits from `BaseAssembler`. + ASMJIT_INLINE_NODEBUG bool isAssembler() const noexcept { return _emitterType == EmitterType::kAssembler; } + //! Tests whether the emitter inherits from `BaseBuilder`. + //! + //! \note Both Builder and Compiler emitters would return `true`. + ASMJIT_INLINE_NODEBUG bool isBuilder() const noexcept { return uint32_t(_emitterType) >= uint32_t(EmitterType::kBuilder); } + //! Tests whether the emitter inherits from `BaseCompiler`. + ASMJIT_INLINE_NODEBUG bool isCompiler() const noexcept { return _emitterType == EmitterType::kCompiler; } + + //! Tests whether the emitter has the given `flag` enabled. + ASMJIT_INLINE_NODEBUG bool hasEmitterFlag(EmitterFlags flag) const noexcept { return Support::test(_emitterFlags, flag); } + //! Tests whether the emitter is finalized. + ASMJIT_INLINE_NODEBUG bool isFinalized() const noexcept { return hasEmitterFlag(EmitterFlags::kFinalized); } + //! Tests whether the emitter is destroyed (only used during destruction). + ASMJIT_INLINE_NODEBUG bool isDestroyed() const noexcept { return hasEmitterFlag(EmitterFlags::kDestroyed); } + + //! \} + + //! \cond INTERNAL + //! \name Internal Functions + //! \{ + + ASMJIT_INLINE_NODEBUG void _addEmitterFlags(EmitterFlags flags) noexcept { _emitterFlags |= flags; } + ASMJIT_INLINE_NODEBUG void _clearEmitterFlags(EmitterFlags flags) noexcept { _emitterFlags &= _emitterFlags & ~flags; } + + //! \} + //! \endcond + + //! \name Target Information + //! \{ + + //! Returns the CodeHolder this emitter is attached to. + ASMJIT_INLINE_NODEBUG CodeHolder* code() const noexcept { return _code; } + + //! Returns the target environment. + //! + //! The returned \ref Environment reference matches \ref CodeHolder::environment(). + ASMJIT_INLINE_NODEBUG const Environment& environment() const noexcept { return _environment; } + + //! 
Tests whether the target architecture is 32-bit. + ASMJIT_INLINE_NODEBUG bool is32Bit() const noexcept { return environment().is32Bit(); } + //! Tests whether the target architecture is 64-bit. + ASMJIT_INLINE_NODEBUG bool is64Bit() const noexcept { return environment().is64Bit(); } + + //! Returns the target architecture type. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return environment().arch(); } + //! Returns the target architecture sub-type. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return environment().subArch(); } + + //! Returns the target architecture's GP register size (4 or 8 bytes). + ASMJIT_INLINE_NODEBUG uint32_t registerSize() const noexcept { return environment().registerSize(); } + + //! Returns instruction alignment. + //! + //! The following values are returned based on the target architecture: + //! - X86 and X86_64 - instruction alignment is 1 + //! - AArch32 - instruction alignment is 4 in A32 mode and 2 in THUMB mode. + //! - AArch64 - instruction alignment is 4 + ASMJIT_INLINE_NODEBUG uint32_t instructionAlignment() const noexcept { return _instructionAlignment; } + + //! \} + + //! \name Initialization & Finalization + //! \{ + + //! Tests whether the emitter is initialized (i.e. attached to \ref CodeHolder). + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _code != nullptr; } + + //! Finalizes this emitter. + //! + //! Materializes the content of the emitter by serializing it to the attached \ref CodeHolder through an architecture + //! specific \ref BaseAssembler. This function won't do anything if the emitter inherits from \ref BaseAssembler as + //! assemblers emit directly to a \ref CodeBuffer held by \ref CodeHolder. However, if this is an emitter that + //! inherits from \ref BaseBuilder or \ref BaseCompiler then these emitters need the materialization phase as they + //! store their content in a representation not visible to \ref CodeHolder. 
+ ASMJIT_API virtual Error finalize(); + + //! \} + + //! \name Logging + //! \{ + + //! Tests whether the emitter has a logger. + ASMJIT_INLINE_NODEBUG bool hasLogger() const noexcept { return _logger != nullptr; } + + //! Tests whether the emitter has its own logger. + //! + //! Own logger means that it overrides the possible logger that may be used by \ref CodeHolder this emitter is + //! attached to. + ASMJIT_INLINE_NODEBUG bool hasOwnLogger() const noexcept { return hasEmitterFlag(EmitterFlags::kOwnLogger); } + + //! Returns the logger this emitter uses. + //! + //! The returned logger is either the emitter's own logger or it's logger used by \ref CodeHolder this emitter + //! is attached to. + ASMJIT_INLINE_NODEBUG Logger* logger() const noexcept { return _logger; } + + //! Sets or resets the logger of the emitter. + //! + //! If the `logger` argument is non-null then the logger will be considered emitter's own logger, see \ref + //! hasOwnLogger() for more details. If the given `logger` is null then the emitter will automatically use logger + //! that is attached to the \ref CodeHolder this emitter is attached to. + ASMJIT_API void setLogger(Logger* logger) noexcept; + + //! Resets the logger of this emitter. + //! + //! The emitter will bail to using a logger attached to \ref CodeHolder this emitter is attached to, or no logger + //! at all if \ref CodeHolder doesn't have one. + ASMJIT_INLINE_NODEBUG void resetLogger() noexcept { return setLogger(nullptr); } + + //! \} + + //! \name Error Handling + //! \{ + + //! Tests whether the emitter has an error handler attached. + ASMJIT_INLINE_NODEBUG bool hasErrorHandler() const noexcept { return _errorHandler != nullptr; } + + //! Tests whether the emitter has its own error handler. + //! + //! Own error handler means that it overrides the possible error handler that may be used by \ref CodeHolder this + //! emitter is attached to. 
+ ASMJIT_INLINE_NODEBUG bool hasOwnErrorHandler() const noexcept { return hasEmitterFlag(EmitterFlags::kOwnErrorHandler); } + + //! Returns the error handler this emitter uses. + //! + //! The returned error handler is either the emitter's own error handler or it's error handler used by + //! \ref CodeHolder this emitter is attached to. + ASMJIT_INLINE_NODEBUG ErrorHandler* errorHandler() const noexcept { return _errorHandler; } + + //! Sets or resets the error handler of the emitter. + ASMJIT_API void setErrorHandler(ErrorHandler* errorHandler) noexcept; + + //! Resets the error handler. + ASMJIT_INLINE_NODEBUG void resetErrorHandler() noexcept { setErrorHandler(nullptr); } + + //! Handles the given error in the following way: + //! 1. If the emitter has \ref ErrorHandler attached, it calls its \ref ErrorHandler::handleError() member function + //! first, and then returns the error. The `handleError()` function may throw. + //! 2. if the emitter doesn't have \ref ErrorHandler, the error is simply returned. + ASMJIT_API Error reportError(Error err, const char* message = nullptr); + + //! \} + + //! \name Encoding Options + //! \{ + + //! Returns encoding options. + ASMJIT_INLINE_NODEBUG EncodingOptions encodingOptions() const noexcept { return _encodingOptions; } + //! Tests whether the encoding `option` is set. + ASMJIT_INLINE_NODEBUG bool hasEncodingOption(EncodingOptions option) const noexcept { return Support::test(_encodingOptions, option); } + + //! Enables the given encoding `options`. + ASMJIT_INLINE_NODEBUG void addEncodingOptions(EncodingOptions options) noexcept { _encodingOptions |= options; } + //! Disables the given encoding `options`. + ASMJIT_INLINE_NODEBUG void clearEncodingOptions(EncodingOptions options) noexcept { _encodingOptions &= ~options; } + + //! \} + + //! \name Diagnostic Options + //! \{ + + //! Returns the emitter's diagnostic options. 
+ ASMJIT_INLINE_NODEBUG DiagnosticOptions diagnosticOptions() const noexcept { return _diagnosticOptions; } + + //! Tests whether the given `option` is present in the emitter's diagnostic options. + ASMJIT_INLINE_NODEBUG bool hasDiagnosticOption(DiagnosticOptions option) const noexcept { return Support::test(_diagnosticOptions, option); } + + //! Activates the given diagnostic `options`. + //! + //! This function is used to activate explicit validation options that will be then used by all emitter + //! implementations. There are in general two possibilities: + //! + //! - Architecture specific assembler is used. In this case a \ref DiagnosticOptions::kValidateAssembler can be + //! used to turn on explicit validation that will be used before an instruction is emitted. This means that + //! internally an extra step will be performed to make sure that the instruction is correct. This is needed, + //! because by default assemblers prefer speed over strictness. + //! + //! This option should be used in debug builds as it's pretty expensive. + //! + //! - Architecture specific builder or compiler is used. In this case the user can turn on + //! \ref DiagnosticOptions::kValidateIntermediate option that adds explicit validation step before the Builder + //! or Compiler creates an \ref InstNode to represent an emitted instruction. Error will be returned if the + //! instruction is ill-formed. In addition, also \ref DiagnosticOptions::kValidateAssembler can be used, which + //! would not be consumed by Builder / Compiler directly, but it would be propagated to an architecture specific + //! \ref BaseAssembler implementation it creates during \ref BaseEmitter::finalize(). + ASMJIT_API void addDiagnosticOptions(DiagnosticOptions options) noexcept; + + //! Deactivates the given validation `options`. + //! + //! See \ref addDiagnosticOptions() and \ref DiagnosticOptions for more details. + ASMJIT_API void clearDiagnosticOptions(DiagnosticOptions options) noexcept; + + //! 
\} + + //! \name Instruction Options + //! \{ + + //! Returns forced instruction options. + //! + //! Forced instruction options are merged with next instruction options before the instruction is encoded. These + //! options have some bits reserved that are used by error handling, logging, and instruction validation purposes. + //! Other options are globals that affect each instruction. + ASMJIT_INLINE_NODEBUG InstOptions forcedInstOptions() const noexcept { return _forcedInstOptions; } + + //! Returns options of the next instruction. + ASMJIT_INLINE_NODEBUG InstOptions instOptions() const noexcept { return _instOptions; } + //! Returns options of the next instruction. + ASMJIT_INLINE_NODEBUG void setInstOptions(InstOptions options) noexcept { _instOptions = options; } + //! Adds options of the next instruction. + ASMJIT_INLINE_NODEBUG void addInstOptions(InstOptions options) noexcept { _instOptions |= options; } + //! Resets options of the next instruction. + ASMJIT_INLINE_NODEBUG void resetInstOptions() noexcept { _instOptions = InstOptions::kNone; } + + //! Tests whether the extra register operand is valid. + ASMJIT_INLINE_NODEBUG bool hasExtraReg() const noexcept { return _extraReg.isReg(); } + //! Returns an extra operand that will be used by the next instruction (architecture specific). + ASMJIT_INLINE_NODEBUG const RegOnly& extraReg() const noexcept { return _extraReg; } + //! Sets an extra operand that will be used by the next instruction (architecture specific). + ASMJIT_INLINE_NODEBUG void setExtraReg(const BaseReg& reg) noexcept { _extraReg.init(reg); } + //! Sets an extra operand that will be used by the next instruction (architecture specific). + ASMJIT_INLINE_NODEBUG void setExtraReg(const RegOnly& reg) noexcept { _extraReg.init(reg); } + //! Resets an extra operand that will be used by the next instruction (architecture specific). + ASMJIT_INLINE_NODEBUG void resetExtraReg() noexcept { _extraReg.reset(); } + + //! 
Returns comment/annotation of the next instruction. + ASMJIT_INLINE_NODEBUG const char* inlineComment() const noexcept { return _inlineComment; } + //! Sets comment/annotation of the next instruction. + //! + //! \note This string is set back to null by `_emit()`, but until that it has to remain valid as the Emitter is not + //! required to make a copy of it (and it would be slow to do that for each instruction). + ASMJIT_INLINE_NODEBUG void setInlineComment(const char* s) noexcept { _inlineComment = s; } + //! Resets the comment/annotation to nullptr. + ASMJIT_INLINE_NODEBUG void resetInlineComment() noexcept { _inlineComment = nullptr; } + + //! \} + + //! \name Emitter State + //! \{ + + //! Resets the emitter state, which contains instruction options, extra register, and inline comment. + //! + //! Emitter can have a state that describes instruction options and extra register used by the instruction. Most + //! instructions don't need nor use the state, however, if an instruction uses a prefix such as REX or REP prefix, + //! which is set explicitly, then the state would contain it. This allows to mimic the syntax of assemblers such + //! as X86. For example `rep().movs(...)` would map to a `REP MOVS` instuction on X86. The same applies to various + //! hints and the use of a mask register in AVX-512 mode. + ASMJIT_INLINE_NODEBUG void resetState() noexcept { + resetInstOptions(); + resetExtraReg(); + resetInlineComment(); + } + + //! \cond INTERNAL + + //! Grabs the current emitter state and resets the emitter state at the same time, returning the state the emitter + //! had before the state was reset. + ASMJIT_INLINE_NODEBUG State _grabState() noexcept { + State s{_instOptions | _forcedInstOptions, _extraReg, _inlineComment}; + resetState(); + return s; + } + //! \endcond + + //! \} + + //! \name Sections + //! \{ + + //! Switches the given `section`. + //! + //! Once switched, everything is added to the given `section`. 
+ ASMJIT_API virtual Error section(Section* section); + + //! \} + + //! \name Labels + //! \{ + + //! Creates a new label. + ASMJIT_API virtual Label newLabel(); + //! Creates a new named label. + ASMJIT_API virtual Label newNamedLabel(const char* name, size_t nameSize = SIZE_MAX, LabelType type = LabelType::kGlobal, uint32_t parentId = Globals::kInvalidId); + + //! Creates a new anonymous label with a name, which can only be used for debugging purposes. + ASMJIT_INLINE_NODEBUG Label newAnonymousLabel(const char* name, size_t nameSize = SIZE_MAX) { return newNamedLabel(name, nameSize, LabelType::kAnonymous); } + //! Creates a new external label. + ASMJIT_INLINE_NODEBUG Label newExternalLabel(const char* name, size_t nameSize = SIZE_MAX) { return newNamedLabel(name, nameSize, LabelType::kExternal); } + + //! Returns `Label` by `name`. + //! + //! Returns invalid Label in case that the name is invalid or label was not found. + //! + //! \note This function doesn't trigger ErrorHandler in case the name is invalid or no such label exist. You must + //! always check the validity of the `Label` returned. + ASMJIT_API Label labelByName(const char* name, size_t nameSize = SIZE_MAX, uint32_t parentId = Globals::kInvalidId) noexcept; + + //! Binds the `label` to the current position of the current section. + //! + //! \note Attempt to bind the same label multiple times will return an error. + ASMJIT_API virtual Error bind(const Label& label); + + //! Tests whether the label `id` is valid (i.e. registered). + ASMJIT_API bool isLabelValid(uint32_t labelId) const noexcept; + //! Tests whether the `label` is valid (i.e. registered). + ASMJIT_INLINE_NODEBUG bool isLabelValid(const Label& label) const noexcept { return isLabelValid(label.id()); } + + //! \} + + //! \name Emit + //! \{ + + // NOTE: These `emit()` helpers are designed to address a code-bloat generated by C++ compilers to call a function + // having many arguments. 
Each parameter to `_emit()` requires some code to pass it, which means that if we default + // to 5 arguments in `_emit()` and instId the C++ compiler would have to generate a virtual function call having 5 + // parameters and additional `this` argument, which is quite a lot. Since by default most instructions have 2 to 3 + // operands it's better to introduce helpers that pass from 0 to 6 operands that help to reduce the size of emit(...) + // function call. + + //! Emits an instruction (internal). + ASMJIT_API Error _emitI(InstId instId); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3, const Operand_& o4); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3, const Operand_& o4, const Operand_& o5); + + //! Emits an instruction `instId` with the given `operands`. + //! + //! This is the most universal way of emitting code, which accepts an instruction identifier and instruction + //! operands. This is called an "unchecked" API as emit doesn't provide any type checks at compile-time. This + //! allows to emit instruction with just \ref Operand instances, which could be handy in some cases - for + //! example emitting generic code where you don't know whether some operand is register, memory, or immediate. + template + ASMJIT_INLINE_NODEBUG Error emit(InstId instId, Args&&... operands) { + return _emitI(instId, Support::ForwardOp::forward(operands)...); + } + + //! 
Similar to \ref emit(), but uses array of `operands` instead. + ASMJIT_INLINE_NODEBUG Error emitOpArray(InstId instId, const Operand_* operands, size_t opCount) { + return _emitOpArray(instId, operands, opCount); + } + + //! Similar to \ref emit(), but emits instruction with both instruction options and extra register, followed + //! by an array of `operands`. + ASMJIT_FORCE_INLINE Error emitInst(const BaseInst& inst, const Operand_* operands, size_t opCount) { + setInstOptions(inst.options()); + setExtraReg(inst.extraReg()); + return _emitOpArray(inst.id(), operands, opCount); + } + + //! \} + + //! \cond INTERNAL + //! \name Emit Internals + //! \{ + + //! Emits an instruction - all 6 operands must be defined. + ASMJIT_API virtual Error _emit(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_* oExt); + //! Emits instruction having operands stored in array. + ASMJIT_API virtual Error _emitOpArray(InstId instId, const Operand_* operands, size_t opCount); + + //! \} + //! \endcond + + //! \name Emit Utilities + //! \{ + + //! Emits a function prolog described by the given function `frame`. + ASMJIT_API Error emitProlog(const FuncFrame& frame); + //! Emits a function epilog described by the given function `frame`. + ASMJIT_API Error emitEpilog(const FuncFrame& frame); + //! Emits code that reassigns function `frame` arguments to the given `args`. + ASMJIT_API Error emitArgsAssignment(const FuncFrame& frame, const FuncArgsAssignment& args); + + //! \} + + //! \name Align + //! \{ + + //! Aligns the current CodeBuffer position to the `alignment` specified. + //! + //! The sequence that is used to fill the gap between the aligned location and the current location depends on the + //! align `mode`, see \ref AlignMode. The `alignment` argument specifies alignment in bytes, so for example when + //! it's `32` it means that the code buffer will be aligned to `32` bytes. 
+ ASMJIT_API virtual Error align(AlignMode alignMode, uint32_t alignment); + + //! \} + + //! \name Embed + //! \{ + + //! Embeds raw data into the \ref CodeBuffer. + ASMJIT_API virtual Error embed(const void* data, size_t dataSize); + + //! Embeds a typed data array. + //! + //! This is the most flexible function for embedding data as it allows to: + //! + //! - Assign a `typeId` to the data, so the emitter knows the type of items stored in `data`. Binary data should + //! use \ref TypeId::kUInt8. + //! + //! - Repeat the given data `repeatCount` times, so the data can be used as a fill pattern for example, or as a + //! pattern used by SIMD instructions. + ASMJIT_API virtual Error embedDataArray(TypeId typeId, const void* data, size_t itemCount, size_t repeatCount = 1); + + //! Embeds int8_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedInt8(int8_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kInt8, &value, 1, repeatCount); } + //! Embeds uint8_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedUInt8(uint8_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kUInt8, &value, 1, repeatCount); } + //! Embeds int16_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedInt16(int16_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kInt16, &value, 1, repeatCount); } + //! Embeds uint16_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedUInt16(uint16_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kUInt16, &value, 1, repeatCount); } + //! Embeds int32_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedInt32(int32_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kInt32, &value, 1, repeatCount); } + //! Embeds uint32_t `value` repeated by `repeatCount`. 
+ ASMJIT_INLINE_NODEBUG Error embedUInt32(uint32_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kUInt32, &value, 1, repeatCount); } + //! Embeds int64_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedInt64(int64_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kInt64, &value, 1, repeatCount); } + //! Embeds uint64_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedUInt64(uint64_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kUInt64, &value, 1, repeatCount); } + //! Embeds a floating point `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedFloat(float value, size_t repeatCount = 1) { return embedDataArray(TypeId(TypeUtils::TypeIdOfT::kTypeId), &value, 1, repeatCount); } + //! Embeds a floating point `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedDouble(double value, size_t repeatCount = 1) { return embedDataArray(TypeId(TypeUtils::TypeIdOfT::kTypeId), &value, 1, repeatCount); } + + //! Embeds a constant pool at the current offset by performing the following: + //! 1. Aligns by using AlignMode::kData to the minimum `pool` alignment. + //! 2. Binds the ConstPool label so it's bound to an aligned location. + //! 3. Emits ConstPool content. + ASMJIT_API virtual Error embedConstPool(const Label& label, const ConstPool& pool); + + //! Embeds an absolute `label` address as data. + //! + //! The `dataSize` is an optional argument that can be used to specify the size of the address data. If it's zero + //! (default) the address size is deduced from the target architecture (either 4 or 8 bytes). + ASMJIT_API virtual Error embedLabel(const Label& label, size_t dataSize = 0); + + //! Embeds a delta (distance) between the `label` and `base` calculating it as `label - base`. This function was + //! designed to make it easier to embed lookup tables where each index is a relative distance of two labels. 
+ ASMJIT_API virtual Error embedLabelDelta(const Label& label, const Label& base, size_t dataSize = 0); + + //! \} + + //! \name Comment + //! \{ + + //! Emits a comment stored in `data` with an optional `size` parameter. + ASMJIT_API virtual Error comment(const char* data, size_t size = SIZE_MAX); + + //! Emits a formatted comment specified by `fmt` and variable number of arguments. + ASMJIT_API Error commentf(const char* fmt, ...); + //! Emits a formatted comment specified by `fmt` and `ap`. + ASMJIT_API Error commentv(const char* fmt, va_list ap); + + //! \} + + //! \name Events + //! \{ + + //! Called after the emitter was attached to `CodeHolder`. + ASMJIT_API virtual Error onAttach(CodeHolder* ASMJIT_NONNULL(code)) noexcept; + //! Called after the emitter was detached from `CodeHolder`. + ASMJIT_API virtual Error onDetach(CodeHolder* ASMJIT_NONNULL(code)) noexcept; + + //! Called when \ref CodeHolder has updated an important setting, which involves the following: + //! + //! - \ref Logger has been changed (\ref CodeHolder::setLogger() has been called). + //! + //! - \ref ErrorHandler has been changed (\ref CodeHolder::setErrorHandler() has been called). + //! + //! This function ensures that the settings are properly propagated from \ref CodeHolder to the emitter. + //! + //! \note This function is virtual and can be overridden, however, if you do so, always call \ref + //! BaseEmitter::onSettingsUpdated() within your own implementation to ensure that the emitter is + //! in a consistent state. + ASMJIT_API virtual void onSettingsUpdated() noexcept; + + //! \} +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_EMITTER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/environment.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..c3678dc65bca45dd27a5b903026bb779400674d8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/environment.h @@ -0,0 +1,534 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ENVIRONMENT_H_INCLUDED +#define ASMJIT_CORE_ENVIRONMENT_H_INCLUDED + +#include "../core/archtraits.h" + +#if defined(__APPLE__) + #include +#endif + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Vendor. +//! +//! \note AsmJit doesn't use vendor information at the moment. It's provided for future use, if required. +enum class Vendor : uint8_t { + //! Unknown or uninitialized platform vendor. + kUnknown = 0, + + //! Maximum value of `Vendor`. + kMaxValue = kUnknown, + + //! Platform vendor detected at compile-time. + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#else + kUnknown +#endif +}; + +//! Platform - runtime environment or operating system. +enum class Platform : uint8_t { + //! Unknown or uninitialized platform. + kUnknown = 0, + + //! Windows OS. + kWindows, + + //! Other platform that is not Windows, most likely POSIX based. + kOther, + + //! Linux OS. + kLinux, + //! GNU/Hurd OS. + kHurd, + + //! FreeBSD OS. + kFreeBSD, + //! OpenBSD OS. + kOpenBSD, + //! NetBSD OS. + kNetBSD, + //! DragonFly BSD OS. + kDragonFlyBSD, + + //! Haiku OS. + kHaiku, + + //! Apple OSX. + kOSX, + //! Apple iOS. + kIOS, + //! Apple TVOS. + kTVOS, + //! Apple WatchOS. + kWatchOS, + + //! Emscripten platform. + kEmscripten, + + //! Maximum value of `Platform`. + kMaxValue = kEmscripten, + + //! 
Platform detected at compile-time (platform of the host). + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#elif defined(__EMSCRIPTEN__) + kEmscripten +#elif defined(_WIN32) + kWindows +#elif defined(__linux__) + kLinux +#elif defined(__gnu_hurd__) + kHurd +#elif defined(__FreeBSD__) + kFreeBSD +#elif defined(__OpenBSD__) + kOpenBSD +#elif defined(__NetBSD__) + kNetBSD +#elif defined(__DragonFly__) + kDragonFlyBSD +#elif defined(__HAIKU__) + kHaiku +#elif defined(__APPLE__) && TARGET_OS_OSX + kOSX +#elif defined(__APPLE__) && TARGET_OS_TV + kTVOS +#elif defined(__APPLE__) && TARGET_OS_WATCH + kWatchOS +#elif defined(__APPLE__) && TARGET_OS_IPHONE + kIOS +#else + kOther +#endif +}; + +//! Platform ABI (application binary interface). +enum class PlatformABI : uint8_t { + //! Unknown or uninitialized environment. + kUnknown = 0, + //! Microsoft ABI. + kMSVC, + //! GNU ABI. + kGNU, + //! Android Environment / ABI. + kAndroid, + //! Cygwin ABI. + kCygwin, + //! Darwin ABI. + kDarwin, + + //! Maximum value of `PlatformABI`. + kMaxValue, + + //! Host ABI detected at compile-time. + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#elif defined(_MSC_VER) + kMSVC +#elif defined(__CYGWIN__) + kCygwin +#elif defined(__MINGW32__) || defined(__GLIBC__) + kGNU +#elif defined(__ANDROID__) + kAndroid +#elif defined(__APPLE__) + kDarwin +#else + kUnknown +#endif +}; + +//! Floating point ABI (ARM). +enum class FloatABI : uint8_t { + kHardFloat = 0, + kSoftFloat, + + kHost = +#if ASMJIT_ARCH_ARM == 32 && defined(__SOFTFP__) + kSoftFloat +#else + kHardFloat +#endif +}; + +//! Object format. +//! +//! \note AsmJit doesn't really use anything except \ref ObjectFormat::kUnknown and \ref ObjectFormat::kJIT at +//! the moment. Object file formats are provided for future extensibility and a possibility to generate object +//! files at some point. +enum class ObjectFormat : uint8_t { + //! Unknown or uninitialized object format. + kUnknown = 0, + + //! 
JIT code generation object, most likely \ref JitRuntime or a custom + //! \ref Target implementation. + kJIT, + + //! Executable and linkable format (ELF). + kELF, + //! Common object file format. + kCOFF, + //! Extended COFF object format. + kXCOFF, + //! Mach object file format. + kMachO, + + //! Maximum value of `ObjectFormat`. + kMaxValue +}; + +//! Represents an environment, which is usually related to a \ref Target. +//! +//! Environment has usually an 'arch-subarch-vendor-os-abi' format, which is sometimes called "Triple" (historically +//! it used to be 3 only parts) or "Tuple", which is a convention used by Debian Linux. +//! +//! AsmJit doesn't support all possible combinations or architectures and ABIs, however, it models the environment +//! similarly to other compilers for future extensibility. +class Environment { +public: + //! \name Members + //! \{ + + //! Architecture. + Arch _arch = Arch::kUnknown; + //! Sub-architecture type. + SubArch _subArch = SubArch::kUnknown; + //! Vendor type. + Vendor _vendor = Vendor::kUnknown; + //! Platform. + Platform _platform = Platform::kUnknown; + //! Platform ABI. + PlatformABI _platformABI = PlatformABI::kUnknown; + //! Object format. + ObjectFormat _objectFormat = ObjectFormat::kUnknown; + //! Floating point ABI. + FloatABI _floatABI = FloatABI::kHardFloat; + //! Reserved for future use, must be zero. + uint8_t _reserved = 0; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a default initialized environment (all values either unknown or set to safe defaults). + ASMJIT_INLINE_NODEBUG constexpr Environment() noexcept = default; + //! Creates a copy of `other` instance. + ASMJIT_INLINE_NODEBUG constexpr Environment(const Environment& other) noexcept = default; + + //! Creates \ref Environment initialized to `arch`, `subArch`, `vendor`, `platform`, `platformABI`, `objectFormat`, + //! and `floatABI`. 
+ ASMJIT_INLINE_NODEBUG constexpr explicit Environment( + Arch arch, + SubArch subArch = SubArch::kUnknown, + Vendor vendor = Vendor::kUnknown, + Platform platform = Platform::kUnknown, + PlatformABI platformABI = PlatformABI::kUnknown, + ObjectFormat objectFormat = ObjectFormat::kUnknown, + FloatABI floatABI = FloatABI::kHardFloat) noexcept + : _arch(arch), + _subArch(subArch), + _vendor(vendor), + _platform(platform), + _platformABI(platformABI), + _objectFormat(objectFormat), + _floatABI(floatABI) {} + + //! Returns the host environment constructed from preprocessor macros defined by the compiler. + //! + //! The returned environment should precisely match the target host architecture, sub-architecture, platform, + //! and ABI. + static ASMJIT_INLINE_NODEBUG Environment host() noexcept { + return Environment(Arch::kHost, SubArch::kHost, Vendor::kHost, Platform::kHost, PlatformABI::kHost, ObjectFormat::kUnknown, FloatABI::kHost); + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Environment& operator=(const Environment& other) noexcept = default; + + ASMJIT_INLINE_NODEBUG bool operator==(const Environment& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const Environment& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the environment is not set up. + //! + //! Returns true if all members are zero, and thus unknown. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { + // Unfortunately compilers won't optimize fields are checked one by one... + return _packed() == 0; + } + + //! Tests whether the environment is initialized, which means it must have + //! a valid architecture. + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { + return _arch != Arch::kUnknown; + } + + ASMJIT_INLINE_NODEBUG uint64_t _packed() const noexcept { + uint64_t x; + memcpy(&x, this, 8); + return x; + } + + //! 
Resets all members of the environment to zero / unknown. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = Environment{}; } + + //! Tests whether this environment is equal to `other`. + ASMJIT_INLINE_NODEBUG bool equals(const Environment& other) const noexcept { return _packed() == other._packed(); } + + //! Returns the architecture. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _arch; } + //! Returns the sub-architecture. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return _subArch; } + //! Returns vendor. + ASMJIT_INLINE_NODEBUG Vendor vendor() const noexcept { return _vendor; } + //! Returns target's platform or operating system. + ASMJIT_INLINE_NODEBUG Platform platform() const noexcept { return _platform; } + //! Returns target's ABI. + ASMJIT_INLINE_NODEBUG PlatformABI platformABI() const noexcept { return _platformABI; } + //! Returns target's object format. + ASMJIT_INLINE_NODEBUG ObjectFormat objectFormat() const noexcept { return _objectFormat; } + //! Returns floating point ABI. + ASMJIT_INLINE_NODEBUG FloatABI floatABI() const noexcept { return _floatABI; } + + //! Initializes \ref Environment to `arch`, `subArch`, `vendor`, `platform`, `platformABI`, `objectFormat`, + //! and `floatABI`. + inline void init( + Arch arch, + SubArch subArch = SubArch::kUnknown, + Vendor vendor = Vendor::kUnknown, + Platform platform = Platform::kUnknown, + PlatformABI platformABI = PlatformABI::kUnknown, + ObjectFormat objectFormat = ObjectFormat::kUnknown, + FloatABI floatABI = FloatABI::kHardFloat) noexcept { + + _arch = arch; + _subArch = subArch; + _vendor = vendor; + _platform = platform; + _platformABI = platformABI; + _objectFormat = objectFormat; + _floatABI = floatABI; + _reserved = 0; + } + + //! Tests whether this environment describes a 32-bit X86. + ASMJIT_INLINE_NODEBUG bool isArchX86() const noexcept { return _arch == Arch::kX86; } + //! Tests whether this environment describes a 64-bit X86. 
+ ASMJIT_INLINE_NODEBUG bool isArchX64() const noexcept { return _arch == Arch::kX64; } + //! Tests whether this environment describes a 32-bit ARM. + ASMJIT_INLINE_NODEBUG bool isArchARM() const noexcept { return isArchARM(_arch); } + //! Tests whether this environment describes a 32-bit ARM in THUMB mode. + ASMJIT_INLINE_NODEBUG bool isArchThumb() const noexcept { return isArchThumb(_arch); } + //! Tests whether this environment describes a 64-bit X86. + ASMJIT_INLINE_NODEBUG bool isArchAArch64() const noexcept { return isArchAArch64(_arch); } + //! Tests whether this environment describes a 32-bit MIPS. + ASMJIT_INLINE_NODEBUG bool isArchMIPS32() const noexcept { return isArchMIPS32(_arch); } + //! Tests whether this environment describes a 64-bit MIPS. + ASMJIT_INLINE_NODEBUG bool isArchMIPS64() const noexcept { return isArchMIPS64(_arch); } + //! Tests whether this environment describes a 32-bit RISC-V. + ASMJIT_INLINE_NODEBUG bool isArchRISCV32() const noexcept { return _arch == Arch::kRISCV32; } + //! Tests whether this environment describes a 64-bit RISC-V. + ASMJIT_INLINE_NODEBUG bool isArchRISCV64() const noexcept { return _arch == Arch::kRISCV64; } + + //! Tests whether the architecture is 32-bit. + ASMJIT_INLINE_NODEBUG bool is32Bit() const noexcept { return is32Bit(_arch); } + //! Tests whether the architecture is 64-bit. + ASMJIT_INLINE_NODEBUG bool is64Bit() const noexcept { return is64Bit(_arch); } + + //! Tests whether the architecture is little endian. + ASMJIT_INLINE_NODEBUG bool isLittleEndian() const noexcept { return isLittleEndian(_arch); } + //! Tests whether the architecture is big endian. + ASMJIT_INLINE_NODEBUG bool isBigEndian() const noexcept { return isBigEndian(_arch); } + + //! Tests whether this architecture is of X86 family. + ASMJIT_INLINE_NODEBUG bool isFamilyX86() const noexcept { return isFamilyX86(_arch); } + //! Tests whether this architecture family is ARM, THUMB, or AArch64. 
+ ASMJIT_INLINE_NODEBUG bool isFamilyARM() const noexcept { return isFamilyARM(_arch); } + //! Tests whether this architecture family is AArch32 (ARM or THUMB). + ASMJIT_INLINE_NODEBUG bool isFamilyAArch32() const noexcept { return isFamilyAArch32(_arch); } + //! Tests whether this architecture family is AArch64. + ASMJIT_INLINE_NODEBUG bool isFamilyAArch64() const noexcept { return isFamilyAArch64(_arch); } + //! Tests whether this architecture family is MISP or MIPS64. + ASMJIT_INLINE_NODEBUG bool isFamilyMIPS() const noexcept { return isFamilyMIPS(_arch); } + //! Tests whether this architecture family is RISC-V (both 32-bit and 64-bit). + ASMJIT_INLINE_NODEBUG bool isFamilyRISCV() const noexcept { return isFamilyRISCV(_arch); } + + //! Tests whether the environment platform is Windows. + ASMJIT_INLINE_NODEBUG bool isPlatformWindows() const noexcept { return _platform == Platform::kWindows; } + //! Tests whether the environment platform is Linux. + ASMJIT_INLINE_NODEBUG bool isPlatformLinux() const noexcept { return _platform == Platform::kLinux; } + //! Tests whether the environment platform is Hurd. + ASMJIT_INLINE_NODEBUG bool isPlatformHurd() const noexcept { return _platform == Platform::kHurd; } + //! Tests whether the environment platform is Haiku. + ASMJIT_INLINE_NODEBUG bool isPlatformHaiku() const noexcept { return _platform == Platform::kHaiku; } + + //! Tests whether the environment platform is any BSD. + ASMJIT_INLINE_NODEBUG bool isPlatformBSD() const noexcept { + return _platform == Platform::kFreeBSD || + _platform == Platform::kOpenBSD || + _platform == Platform::kNetBSD || + _platform == Platform::kDragonFlyBSD; + } + + //! Tests whether the environment platform is any Apple platform (OSX, iOS, TVOS, WatchOS). + ASMJIT_INLINE_NODEBUG bool isPlatformApple() const noexcept { + return _platform == Platform::kOSX || + _platform == Platform::kIOS || + _platform == Platform::kTVOS || + _platform == Platform::kWatchOS; + } + + //! 
Tests whether the ABI is MSVC. + ASMJIT_INLINE_NODEBUG bool isMSVC() const noexcept { return _platformABI == PlatformABI::kMSVC; } + //! Tests whether the ABI is GNU. + ASMJIT_INLINE_NODEBUG bool isGNU() const noexcept { return _platformABI == PlatformABI::kGNU; } + //! Tests whether the ABI is GNU. + ASMJIT_INLINE_NODEBUG bool isDarwin() const noexcept { return _platformABI == PlatformABI::kDarwin; } + + //! Returns a calculated stack alignment for this environment. + ASMJIT_API uint32_t stackAlignment() const noexcept; + + //! Returns a native register size of this architecture. + ASMJIT_INLINE_NODEBUG uint32_t registerSize() const noexcept { return registerSizeFromArch(_arch); } + + //! Sets the architecture to `arch`. + ASMJIT_INLINE_NODEBUG void setArch(Arch arch) noexcept { _arch = arch; } + //! Sets the sub-architecture to `subArch`. + ASMJIT_INLINE_NODEBUG void setSubArch(SubArch subArch) noexcept { _subArch = subArch; } + //! Sets the vendor to `vendor`. + ASMJIT_INLINE_NODEBUG void setVendor(Vendor vendor) noexcept { _vendor = vendor; } + //! Sets the platform to `platform`. + ASMJIT_INLINE_NODEBUG void setPlatform(Platform platform) noexcept { _platform = platform; } + //! Sets the ABI to `platformABI`. + ASMJIT_INLINE_NODEBUG void setPlatformABI(PlatformABI platformABI) noexcept { _platformABI = platformABI; } + //! Sets the object format to `objectFormat`. + ASMJIT_INLINE_NODEBUG void setObjectFormat(ObjectFormat objectFormat) noexcept { _objectFormat = objectFormat; } + + //! Sets floating point ABI to `floatABI`. + ASMJIT_INLINE_NODEBUG void setFloatABI(FloatABI floatABI) noexcept { _floatABI = floatABI; } + + //! \} + + //! \name Static Utilities + //! \{ + + static ASMJIT_INLINE_NODEBUG bool isDefinedArch(Arch arch) noexcept { + return uint32_t(arch) <= uint32_t(Arch::kMaxValue); + } + + static ASMJIT_INLINE_NODEBUG bool isValidArch(Arch arch) noexcept { + return arch != Arch::kUnknown && uint32_t(arch) <= uint32_t(Arch::kMaxValue); + } + + //! 
Tests whether the given architecture `arch` is 32-bit. + static ASMJIT_INLINE_NODEBUG bool is32Bit(Arch arch) noexcept { + return (uint32_t(arch) & uint32_t(Arch::k32BitMask)) == uint32_t(Arch::k32BitMask); + } + + //! Tests whether the given architecture `arch` is 64-bit. + static ASMJIT_INLINE_NODEBUG bool is64Bit(Arch arch) noexcept { + return (uint32_t(arch) & uint32_t(Arch::k32BitMask)) == 0; + } + + //! Tests whether the given architecture `arch` is little endian. + static ASMJIT_INLINE_NODEBUG bool isLittleEndian(Arch arch) noexcept { + return uint32_t(arch) < uint32_t(Arch::kBigEndian); + } + + //! Tests whether the given architecture `arch` is big endian. + static ASMJIT_INLINE_NODEBUG bool isBigEndian(Arch arch) noexcept { + return uint32_t(arch) >= uint32_t(Arch::kBigEndian); + } + + //! Tests whether the given architecture is Thumb or Thumb_BE. + static ASMJIT_INLINE_NODEBUG bool isArchThumb(Arch arch) noexcept { + return arch == Arch::kThumb || arch == Arch::kThumb_BE; + } + + //! Tests whether the given architecture is ARM or ARM_BE. + static ASMJIT_INLINE_NODEBUG bool isArchARM(Arch arch) noexcept { + return arch == Arch::kARM || arch == Arch::kARM_BE; + } + + //! Tests whether the given architecture is AArch64 or AArch64_BE. + static ASMJIT_INLINE_NODEBUG bool isArchAArch64(Arch arch) noexcept { + return arch == Arch::kAArch64 || arch == Arch::kAArch64_BE; + } + + //! Tests whether the given architecture is MIPS32_LE or MIPS32_BE. + static ASMJIT_INLINE_NODEBUG bool isArchMIPS32(Arch arch) noexcept { + return arch == Arch::kMIPS32_LE || arch == Arch::kMIPS32_BE; + } + + //! Tests whether the given architecture is MIPS64_LE or MIPS64_BE. + static ASMJIT_INLINE_NODEBUG bool isArchMIPS64(Arch arch) noexcept { + return arch == Arch::kMIPS64_LE || arch == Arch::kMIPS64_BE; + } + + //! Tests whether the given architecture family is X86 or X64. 
+ static ASMJIT_INLINE_NODEBUG bool isFamilyX86(Arch arch) noexcept { + return arch == Arch::kX86 || arch == Arch::kX64; + } + + //! Tests whether the given architecture family is AArch32 (ARM or THUMB). + static ASMJIT_INLINE_NODEBUG bool isFamilyAArch32(Arch arch) noexcept { + return isArchARM(arch) || isArchThumb(arch); + } + + //! Tests whether the given architecture family is AArch64. + static ASMJIT_INLINE_NODEBUG bool isFamilyAArch64(Arch arch) noexcept { + return isArchAArch64(arch); + } + + //! Tests whether the given architecture family is ARM, THUMB, or AArch64. + static ASMJIT_INLINE_NODEBUG bool isFamilyARM(Arch arch) noexcept { + return isFamilyAArch32(arch) || isFamilyAArch64(arch); + } + + //! Tests whether the given architecture family is MIPS or MIPS64. + static ASMJIT_INLINE_NODEBUG bool isFamilyMIPS(Arch arch) noexcept { + return isArchMIPS32(arch) || isArchMIPS64(arch); + } + + //! Tests whether the given architecture family is RISC-V (both 32-bit and 64-bit). + static ASMJIT_INLINE_NODEBUG bool isFamilyRISCV(Arch arch) noexcept { + return arch == Arch::kRISCV32 || arch == Arch::kRISCV64; + } + + //! Returns a native general purpose register size from the given architecture. + static ASMJIT_INLINE_NODEBUG uint32_t registerSizeFromArch(Arch arch) noexcept { + return is32Bit(arch) ? 4u : 8u; + } + + //! \} +}; + +static_assert(sizeof(Environment) == 8, + "Environment must occupy exactly 8 bytes."); + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ENVIRONMENT_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/errorhandler.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/errorhandler.h new file mode 100644 index 0000000000000000000000000000000000000000..a1a2dd2d5dbeeb52ca57a3c4954252c7e99aafee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/errorhandler.h @@ -0,0 +1,228 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ERRORHANDLER_H_INCLUDED +#define ASMJIT_CORE_ERRORHANDLER_H_INCLUDED + +#include "../core/globals.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_error_handling +//! \{ + +class BaseEmitter; + +//! Error handler can be used to override the default behavior of error handling. +//! +//! It's available to all classes that inherit `BaseEmitter`. Override \ref ErrorHandler::handleError() to implement +//! your own error handler. +//! +//! The following use-cases are supported: +//! +//! - Record the error and continue code generation. This is the simplest approach that can be used to at least log +//! possible errors. +//! - Throw an exception. AsmJit doesn't use exceptions and is completely exception-safe, but it's perfectly legal +//! to throw an exception from the error handler. +//! - Use plain old C's `setjmp()` and `longjmp()`. Asmjit always puts Assembler, Builder and Compiler to +//! a consistent state before calling \ref handleError(), so `longjmp()` can be used without issues to cancel the +//! code generation if an error occurred. This method can be used if exception handling in your project is turned +//! off and you still want some comfort. In most cases it should be safe as AsmJit uses \ref Zone memory and the +//! ownership of memory it allocates always ends with the instance that allocated it. 
If using this approach please +//! never jump outside the life-time of \ref CodeHolder and \ref BaseEmitter. +//! +//! \ref ErrorHandler can be attached to \ref CodeHolder or \ref BaseEmitter, which has a priority. The example below +//! uses error handler that just prints the error, but lets AsmJit continue: +//! +//! ``` +//! // Error Handling #1 - Logging and returning Error. +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Error handler that just prints the error and lets AsmJit ignore it. +//! class SimpleErrorHandler : public ErrorHandler { +//! public: +//! Error err; +//! +//! inline SimpleErrorHandler() : err(kErrorOk) {} +//! +//! void handleError(Error err, const char* message, BaseEmitter* origin) override { +//! this->err = err; +//! fprintf(stderr, "ERROR: %s\n", message); +//! } +//! }; +//! +//! int main() { +//! JitRuntime rt; +//! SimpleErrorHandler eh; +//! +//! CodeHolder code; +//! code.init(rt.environment(), rt.cpuFeatures()); +//! code.setErrorHandler(&eh); +//! +//! // Try to emit instruction that doesn't exist. +//! x86::Assembler a(&code); +//! a.emit(x86::Inst::kIdMov, x86::xmm0, x86::xmm1); +//! +//! if (eh.err) { +//! // Assembler failed! +//! return 1; +//! } +//! +//! return 0; +//! } +//! ``` +//! +//! If error happens during instruction emitting / encoding the assembler behaves transactionally - the output buffer +//! won't advance if encoding failed, thus either a fully encoded instruction or nothing is emitted. The error handling +//! shown above is useful, but it's still not the best way of dealing with errors in AsmJit. The following example +//! shows how to use exception handling to handle errors in a more C++ way: +//! +//! ``` +//! // Error Handling #2 - Throwing an exception. +//! #include +//! #include +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Error handler that throws a user-defined `AsmJitException`. +//! class AsmJitException : public std::exception { +//! 
public: +//! Error err; +//! std::string message; +//! +//! AsmJitException(Error err, const char* message) noexcept +//! : err(err), +//! message(message) {} +//! +//! const char* what() const noexcept override { return message.c_str(); } +//! }; +//! +//! class ThrowableErrorHandler : public ErrorHandler { +//! public: +//! // Throw is possible, functions that use ErrorHandler are never 'noexcept'. +//! void handleError(Error err, const char* message, BaseEmitter* origin) override { +//! throw AsmJitException(err, message); +//! } +//! }; +//! +//! int main() { +//! JitRuntime rt; +//! ThrowableErrorHandler eh; +//! +//! CodeHolder code; +//! code.init(rt.environment(), rt.cpuFeatures()); +//! code.setErrorHandler(&eh); +//! +//! x86::Assembler a(&code); +//! +//! // Try to emit instruction that doesn't exist. +//! try { +//! a.emit(x86::Inst::kIdMov, x86::xmm0, x86::xmm1); +//! } +//! catch (const AsmJitException& ex) { +//! printf("EXCEPTION THROWN: %s\n", ex.what()); +//! return 1; +//! } +//! +//! return 0; +//! } +//! ``` +//! +//! If C++ exceptions are not what you like or your project turns off them completely there is still a way of reducing +//! the error handling to a minimum by using a standard setjmp/longjmp approach. AsmJit is exception-safe and cleans +//! up everything before calling the ErrorHandler, so any approach is safe. You can simply jump from the error handler +//! without causing any side-effects or memory leaks. The following example demonstrates how it could be done: +//! +//! ``` +//! // Error Handling #3 - Using setjmp/longjmp if exceptions are not allowed. +//! #include +//! #include +//! #include +//! +//! class LongJmpErrorHandler : public asmjit::ErrorHandler { +//! public: +//! inline LongJmpErrorHandler() : err(asmjit::kErrorOk) {} +//! +//! void handleError(asmjit::Error err, const char* message, asmjit::BaseEmitter* origin) override { +//! this->err = err; +//! longjmp(state, 1); +//! } +//! +//! jmp_buf state; +//! 
asmjit::Error err; +//! }; +//! +//! int main(int argc, char* argv[]) { +//! using namespace asmjit; +//! +//! JitRuntime rt; +//! LongJmpErrorHandler eh; +//! +//! CodeHolder code; +//! code.init(rt.environment(), rt.cpuFeatures()); +//! code.setErrorHandler(&eh); +//! +//! x86::Assembler a(&code); +//! +//! if (!setjmp(eh.state)) { +//! // Try to emit instruction that doesn't exist. +//! a.emit(x86::Inst::kIdMov, x86::xmm0, x86::xmm1); +//! } +//! else { +//! Error err = eh.err; +//! printf("ASMJIT ERROR: 0x%08X [%s]\n", err, DebugUtils::errorAsString(err)); +//! } +//! +//! return 0; +//! } +//! ``` +class ASMJIT_VIRTAPI ErrorHandler { +public: + ASMJIT_BASE_CLASS(ErrorHandler) + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `ErrorHandler` instance. + ASMJIT_API ErrorHandler() noexcept; + //! Destroys the `ErrorHandler` instance. + ASMJIT_API virtual ~ErrorHandler() noexcept; + + //! \} + + //! \name Interface + //! \{ + + //! Error handler (must be reimplemented). + //! + //! Error handler is called after an error happened and before it's propagated to the caller. There are multiple + //! ways how the error handler can be used: + //! + //! 1. User-based error handling without throwing exception or using C's`longjmp()`. This is for users that don't + //! use exceptions and want customized error handling. + //! + //! 2. Throwing an exception. AsmJit doesn't use exceptions and is completely exception-safe, but you can throw + //! exception from your error handler if this way is the preferred way of handling errors in your project. + //! + //! 3. Using plain old C's `setjmp()` and `longjmp()`. Asmjit always puts `BaseEmitter` to a consistent state before + //! calling `handleError()` so `longjmp()` can be used without any issues to cancel the code generation if an + //! error occurred. There is no difference between exceptions and `longjmp()` from AsmJit's perspective, however, + //! 
never jump outside of `CodeHolder` and `BaseEmitter` scope as you would leak memory. + ASMJIT_API virtual void handleError(Error err, const char* message, BaseEmitter* origin); + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ERRORHANDLER_H_INCLUDED + diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/formatter.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/formatter.h new file mode 100644 index 0000000000000000000000000000000000000000..392e47873020fc45d5bb66f6e1e2d3f91e7abbc7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/formatter.h @@ -0,0 +1,249 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_FORMATTER_H_INCLUDED +#define ASMJIT_CORE_FORMATTER_H_INCLUDED + +#include "../core/globals.h" +#include "../core/inst.h" +#include "../core/string.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_logging +//! \{ + +class BaseBuilder; +class BaseEmitter; +class BaseNode; +struct Operand_; + +//! Format flags used by \ref Logger and \ref FormatOptions. +enum class FormatFlags : uint32_t { + //! No formatting flags. + kNone = 0u, + + //! Show also binary form of each logged instruction (Assembler). + kMachineCode = 0x00000001u, + //! Show a text explanation of some immediate values. + kExplainImms = 0x00000002u, + //! Use hexadecimal notation of immediate values. + kHexImms = 0x00000004u, + //! Use hexadecimal notation of addresses and offsets in addresses. + kHexOffsets = 0x00000008u, + //! Show casts between virtual register types (Compiler output). + kRegCasts = 0x00000010u, + //! Show positions associated with nodes (Compiler output). + kPositions = 0x00000020u, + //! Always format a register type (Compiler output). + kRegType = 0x00000040u +}; +ASMJIT_DEFINE_ENUM_FLAGS(FormatFlags) + +//! 
Format indentation group, used by \ref FormatOptions. +enum class FormatIndentationGroup : uint32_t { + //! Indentation used for instructions and directives. + kCode = 0u, + //! Indentation used for labels and function nodes. + kLabel = 1u, + //! Indentation used for comments (not inline comments). + kComment = 2u, + + //! \cond INTERNAL + //! Reserved for future use. + kReserved = 3u, + //! \endcond + + //! Maximum value of `FormatIndentationGroup`. + kMaxValue = kReserved +}; + +//! Format padding group, used by \ref FormatOptions. +enum class FormatPaddingGroup : uint32_t { + //! Describes padding of a regular line, which can represent instruction, data, or assembler directives. + kRegularLine = 0, + //! Describes padding of machine code dump that is visible next to the instruction, if enabled. + kMachineCode = 1, + + //! Maximum value of `FormatPaddingGroup`. + kMaxValue = kMachineCode +}; + +//! Formatting options used by \ref Logger and \ref Formatter. +class FormatOptions { +public: + //! \name Members + //! \{ + + //! Format flags. + FormatFlags _flags = FormatFlags::kNone; + //! Indentations for each indentation group. + Support::Array _indentation {}; + //! Paddings for each padding group. + Support::Array _padding {}; + + //! \} + + //! \name Reset + //! \{ + + //! Resets FormatOptions to its default initialized state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _flags = FormatFlags::kNone; + _indentation.fill(uint8_t(0)); + _padding.fill(uint16_t(0)); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns format flags. + ASMJIT_INLINE_NODEBUG FormatFlags flags() const noexcept { return _flags; } + //! Tests whether the given `flag` is set in format flags. + ASMJIT_INLINE_NODEBUG bool hasFlag(FormatFlags flag) const noexcept { return Support::test(_flags, flag); } + + //! Resets all format flags to `flags`. + ASMJIT_INLINE_NODEBUG void setFlags(FormatFlags flags) noexcept { _flags = flags; } + //! Adds `flags` to format flags. 
+ ASMJIT_INLINE_NODEBUG void addFlags(FormatFlags flags) noexcept { _flags |= flags; } + //! Removes `flags` from format flags. + ASMJIT_INLINE_NODEBUG void clearFlags(FormatFlags flags) noexcept { _flags &= ~flags; } + + //! Returns indentation for the given indentation `group`. + ASMJIT_INLINE_NODEBUG uint8_t indentation(FormatIndentationGroup group) const noexcept { return _indentation[group]; } + //! Sets indentation for the given indentation `group`. + ASMJIT_INLINE_NODEBUG void setIndentation(FormatIndentationGroup group, uint32_t n) noexcept { _indentation[group] = uint8_t(n); } + //! Resets indentation for the given indentation `group` to zero. + ASMJIT_INLINE_NODEBUG void resetIndentation(FormatIndentationGroup group) noexcept { _indentation[group] = uint8_t(0); } + + //! Returns padding for the given padding `group`. + ASMJIT_INLINE_NODEBUG size_t padding(FormatPaddingGroup group) const noexcept { return _padding[group]; } + //! Sets padding for the given padding `group`. + ASMJIT_INLINE_NODEBUG void setPadding(FormatPaddingGroup group, size_t n) noexcept { _padding[group] = uint16_t(n); } + //! Resets padding for the given padding `group` to zero, which means that a default padding will be used + //! based on the target architecture properties. + ASMJIT_INLINE_NODEBUG void resetPadding(FormatPaddingGroup group) noexcept { _padding[group] = uint16_t(0); } + + //! \} +}; + +//! Provides formatting functionality to format operands, instructions, and nodes. +namespace Formatter { + +#ifndef ASMJIT_NO_LOGGING + +//! Appends a formatted `typeId` to the output string `sb`. +ASMJIT_API Error formatTypeId( + String& sb, + TypeId typeId) noexcept; + +//! Appends a formatted `featureId` to the output string `sb`. +//! +//! See \ref CpuFeatures. +ASMJIT_API Error formatFeature( + String& sb, + Arch arch, + uint32_t featureId) noexcept; + +//! Appends a formatted register to the output string `sb`. +//! +//! 
\note Emitter is optional, but it's required to format virtual registers, which won't be formatted properly +//! if the `emitter` is not provided. +ASMJIT_API Error formatRegister( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + Arch arch, + RegType regType, + uint32_t regId) noexcept; + +//! Appends a formatted label to the output string `sb`. +//! +//! \note Emitter is optional, but it's required to format named labels properly, otherwise the formatted as +//! it is an anonymous label. +ASMJIT_API Error formatLabel( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + uint32_t labelId) noexcept; + +//! Appends a formatted operand to the output string `sb`. +//! +//! \note Emitter is optional, but it's required to format named labels and virtual registers. See +//! \ref formatRegister() and \ref formatLabel() for more details. +ASMJIT_API Error formatOperand( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + Arch arch, + const Operand_& op) noexcept; + +//! Appends a formatted data-type to the output string `sb`. +ASMJIT_API Error formatDataType( + String& sb, + FormatFlags formatFlags, + Arch arch, + TypeId typeId) noexcept; + +//! Appends a formatted data to the output string `sb`. +ASMJIT_API Error formatData( + String& sb, + FormatFlags formatFlags, + Arch arch, + TypeId typeId, const void* data, size_t itemCount, size_t repeatCount = 1) noexcept; + +//! Appends a formatted instruction to the output string `sb`. +//! +//! \note Emitter is optional, but it's required to format named labels and virtual registers. See +//! \ref formatRegister() and \ref formatLabel() for more details. +ASMJIT_API Error formatInstruction( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + Arch arch, + const BaseInst& inst, const Operand_* operands, size_t opCount) noexcept; + +#ifndef ASMJIT_NO_BUILDER +//! Appends a formatted node to the output string `sb`. +//! +//! 
The `node` must belong to the provided `builder`. +ASMJIT_API Error formatNode( + String& sb, + const FormatOptions& formatOptions, + const BaseBuilder* builder, + const BaseNode* node) noexcept; + +//! Appends formatted nodes to the output string `sb`. +//! +//! All nodes that are part of the given `builder` will be appended. +ASMJIT_API Error formatNodeList( + String& sb, + const FormatOptions& formatOptions, + const BaseBuilder* builder) noexcept; + +//! Appends formatted nodes to the output string `sb`. +//! +//! This function works the same as \ref formatNode(), but appends more nodes to the output string, +//! separating each node with a newline '\n' character. +ASMJIT_API Error formatNodeList( + String& sb, + const FormatOptions& formatOptions, + const BaseBuilder* builder, + const BaseNode* begin, + const BaseNode* end) noexcept; +#endif + +#endif + +} // {Formatter} + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_FORMATTER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/func.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/func.h new file mode 100644 index 0000000000000000000000000000000000000000..695a23bbcf0d380bbc7fd596b8df11427e853704 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/func.h @@ -0,0 +1,1595 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_FUNC_H_INCLUDED +#define ASMJIT_CORE_FUNC_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/environment.h" +#include "../core/operand.h" +#include "../core/type.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_function +//! \{ + +//! Calling convention id. +//! +//! Calling conventions can be divided into the following groups: +//! +//! - Universal - calling conventions are applicable to any target. 
They will be converted to a target dependent +//! calling convention at runtime by \ref CallConv::init() with some help from \ref Environment. The purpose of +//! these calling conventions is to make using functions less target dependent and closer to C and C++. +//! +//! - Target specific - calling conventions that are used by a particular architecture and ABI. For example +//! Windows 64-bit calling convention and AMD64 SystemV calling convention. +enum class CallConvId : uint8_t { + // Universal Calling Conventions + // ----------------------------- + + //! Standard function call or explicit `__cdecl` where it can be specified. + //! + //! This is a universal calling convention, which is used to initialize specific calling conventions based on + //! architecture, platform, and its ABI. + kCDecl = 0, + + //! `__stdcall` on targets that support this calling convention (X86). + //! + //! \note This calling convention is only supported on 32-bit X86. If used on environment that doesn't support + //! this calling convention it will be replaced by \ref CallConvId::kCDecl. + kStdCall = 1, + + //! `__fastcall` on targets that support this calling convention (X86). + //! + //! \note This calling convention is only supported on 32-bit X86. If used on environment that doesn't support + //! this calling convention it will be replaced by \ref CallConvId::kCDecl. + kFastCall = 2, + + //! `__vectorcall` on targets that support this calling convention (X86/X64). + //! + //! \note This calling convention is only supported on 32-bit and 64-bit X86 architecture on Windows platform. + //! If used on environment that doesn't support this calling it will be replaced by \ref CallConvId::kCDecl. + kVectorCall = 3, + + //! `__thiscall` on targets that support this calling convention (X86). + //! + //! \note This calling convention is only supported on 32-bit X86 Windows platform. If used on environment that + //! 
doesn't support this calling convention it will be replaced by \ref CallConvId::kCDecl. + kThisCall = 4, + + //! `__attribute__((regparm(1)))` convention (GCC and Clang). + kRegParm1 = 5, + //! `__attribute__((regparm(2)))` convention (GCC and Clang). + kRegParm2 = 6, + //! `__attribute__((regparm(3)))` convention (GCC and Clang). + kRegParm3 = 7, + + //! AsmJit specific calling convention designed for calling functions inside a multimedia code that don't use many + //! registers internally, but are long enough to be called and not inlined. These functions are usually used to + //! calculate trigonometric functions, logarithms, etc... + kLightCall2 = 16, + kLightCall3 = 17, + kLightCall4 = 18, + + // ABI-Specific Calling Conventions + // -------------------------------- + + //! Soft-float calling convention (AArch32). + //! + //! Floating point arguments are passed via general purpose registers. + kSoftFloat = 30, + + //! Hard-float calling convention (AArch32). + //! + //! Floating point arguments are passed via SIMD registers. + kHardFloat = 31, + + //! X64 System-V calling convention. + kX64SystemV = 32, + //! X64 Windows calling convention. + kX64Windows = 33, + + //! Maximum value of `CallConvId`. + kMaxValue = kX64Windows + + // Deprecated Aliases + // ------------------ + +#if !defined(ASMJIT_NO_DEPRECATED) + , + kNone = kCDecl, + kHost = kCDecl +#endif // !ASMJIT_NO_DEPRECATED +}; + +//! Strategy used by calling conventions to assign registers to function arguments. +//! +//! Calling convention strategy describes how AsmJit should convert function arguments used by \ref FuncSignature +//! into register identifiers and stack offsets. The \ref CallConvStrategy::kDefault strategy assigns registers +//! and then stack whereas \ref CallConvStrategy::kX64Windows strategy does register shadowing as defined by WIN64 +//! calling convention, which is only used by 64-bit Windows. +enum class CallConvStrategy : uint8_t { + //! Default register assignment strategy. 
+ kDefault = 0, + //! Windows 64-bit ABI register assignment strategy. + kX64Windows = 1, + //! Windows 64-bit __vectorcall register assignment strategy. + kX64VectorCall = 2, + //! Apple's AArch64 calling convention (differs compared to AArch64 calling convention used by Linux). + kAArch64Apple = 3, + + //! Maximum value of `CallConvStrategy`. + kMaxValue = kX64VectorCall +}; + +//! Calling convention flags. +enum class CallConvFlags : uint32_t { + //! No flags. + kNone = 0, + //! Callee is responsible for cleaning up the stack. + kCalleePopsStack = 0x0001u, + //! Pass vector arguments indirectly (as a pointer). + kIndirectVecArgs = 0x0002u, + //! Pass F32 and F64 arguments via VEC128 register. + kPassFloatsByVec = 0x0004u, + //! Pass MMX and vector arguments via stack if the function has variable arguments. + kPassVecByStackIfVA = 0x0008u, + //! MMX registers are passed and returned via GP registers. + kPassMmxByGp = 0x0010u, + //! MMX registers are passed and returned via XMM registers. + kPassMmxByXmm = 0x0020u, + //! Calling convention can be used with variable arguments. + kVarArgCompatible = 0x0080u +}; +ASMJIT_DEFINE_ENUM_FLAGS(CallConvFlags) + +//! Function calling convention. +//! +//! Function calling convention is a scheme that defines how function parameters are passed and how function +//! returns its result. AsmJit defines a variety of architecture and OS specific calling conventions and also +//! provides a compile time detection to make the code-generation easier. +struct CallConv { + //! \name Constants + //! \{ + + //! Maximum number of register arguments per register group. + //! + //! \note This is not really AsmJit's limitation, it's just the number that makes sense considering all common + //! calling conventions. Usually even conventions that use registers to pass function arguments are limited to 8 + //! and less arguments passed via registers per group. + static constexpr uint32_t kMaxRegArgsPerGroup = 16; + + //! \} + + //! 
\name Members + //! \{ + + //! Target architecture. + Arch _arch; + //! Calling convention id. + CallConvId _id; + //! Register assignment strategy. + CallConvStrategy _strategy; + + //! Red zone size (AMD64 == 128 bytes). + uint8_t _redZoneSize; + //! Spill zone size (WIN-X64 == 32 bytes). + uint8_t _spillZoneSize; + //! Natural stack alignment as defined by OS/ABI. + uint8_t _naturalStackAlignment; + + //! \cond INTERNAL + //! Reserved for future use. + uint8_t _reserved[2]; + //! \endcond + + //! Calling convention flags. + CallConvFlags _flags; + + //! Size to save/restore per register group. + Support::Array _saveRestoreRegSize; + //! Alignment of save/restore groups. + Support::Array _saveRestoreAlignment; + + //! Mask of all passed registers, per group. + Support::Array _passedRegs; + //! Mask of all preserved registers, per group. + Support::Array _preservedRegs; + + //! Passed registers' order. + union RegOrder { + //! Passed registers, ordered. + uint8_t id[kMaxRegArgsPerGroup]; + //! Packed IDs in `uint32_t` array. + uint32_t packed[(kMaxRegArgsPerGroup + 3) / 4]; + }; + + //! Passed registers' order, per register group. + Support::Array _passedOrder; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Initializes this calling convention to the given `ccId` based on the `environment`. + //! + //! See \ref CallConvId and \ref Environment for more details. + ASMJIT_API Error init(CallConvId ccId, const Environment& environment) noexcept; + + //! Resets this CallConv struct into a defined state. + //! + //! It's recommended to reset the \ref CallConv struct in case you would like create a custom calling convention + //! as it prevents from using an uninitialized data (CallConv doesn't have a constructor that would initialize it, + //! it's just a struct). + ASMJIT_INLINE_NODEBUG void reset() noexcept { + *this = CallConv{}; + memset(_passedOrder.data(), 0xFF, sizeof(_passedOrder)); + } + + //! \} + + //! \name Accessors + //! \{ + + //! 
Returns the target architecture of this calling convention. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _arch; } + //! Sets the target architecture of this calling convention. + ASMJIT_INLINE_NODEBUG void setArch(Arch arch) noexcept { _arch = arch; } + + //! Returns the calling convention id. + ASMJIT_INLINE_NODEBUG CallConvId id() const noexcept { return _id; } + //! Sets the calling convention id. + ASMJIT_INLINE_NODEBUG void setId(CallConvId ccId) noexcept { _id = ccId; } + + //! Returns the strategy used to assign registers to arguments. + ASMJIT_INLINE_NODEBUG CallConvStrategy strategy() const noexcept { return _strategy; } + //! Sets the strategy used to assign registers to arguments. + ASMJIT_INLINE_NODEBUG void setStrategy(CallConvStrategy ccStrategy) noexcept { _strategy = ccStrategy; } + + //! Tests whether the calling convention has the given `flag` set. + ASMJIT_INLINE_NODEBUG bool hasFlag(CallConvFlags flag) const noexcept { return Support::test(_flags, flag); } + //! Returns the calling convention flags, see `Flags`. + ASMJIT_INLINE_NODEBUG CallConvFlags flags() const noexcept { return _flags; } + //! Adds the calling convention flags, see `Flags`. + ASMJIT_INLINE_NODEBUG void setFlags(CallConvFlags flag) noexcept { _flags = flag; }; + //! Adds the calling convention flags, see `Flags`. + ASMJIT_INLINE_NODEBUG void addFlags(CallConvFlags flags) noexcept { _flags |= flags; }; + + //! Tests whether this calling convention specifies 'RedZone'. + ASMJIT_INLINE_NODEBUG bool hasRedZone() const noexcept { return _redZoneSize != 0; } + //! Tests whether this calling convention specifies 'SpillZone'. + ASMJIT_INLINE_NODEBUG bool hasSpillZone() const noexcept { return _spillZoneSize != 0; } + + //! Returns size of 'RedZone'. + ASMJIT_INLINE_NODEBUG uint32_t redZoneSize() const noexcept { return _redZoneSize; } + //! Returns size of 'SpillZone'. + ASMJIT_INLINE_NODEBUG uint32_t spillZoneSize() const noexcept { return _spillZoneSize; } + + //! 
Sets size of 'RedZone'. + ASMJIT_INLINE_NODEBUG void setRedZoneSize(uint32_t size) noexcept { _redZoneSize = uint8_t(size); } + //! Sets size of 'SpillZone'. + ASMJIT_INLINE_NODEBUG void setSpillZoneSize(uint32_t size) noexcept { _spillZoneSize = uint8_t(size); } + + //! Returns a natural stack alignment. + ASMJIT_INLINE_NODEBUG uint32_t naturalStackAlignment() const noexcept { return _naturalStackAlignment; } + //! Sets a natural stack alignment. + //! + //! This function can be used to override the default stack alignment in case that you know that it's alignment is + //! different. For example it allows to implement custom calling conventions that guarantee higher stack alignment. + ASMJIT_INLINE_NODEBUG void setNaturalStackAlignment(uint32_t value) noexcept { _naturalStackAlignment = uint8_t(value); } + + //! Returns the size of a register (or its part) to be saved and restored of the given `group`. + ASMJIT_INLINE_NODEBUG uint32_t saveRestoreRegSize(RegGroup group) const noexcept { return _saveRestoreRegSize[group]; } + //! Sets the size of a vector register (or its part) to be saved and restored. + ASMJIT_INLINE_NODEBUG void setSaveRestoreRegSize(RegGroup group, uint32_t size) noexcept { _saveRestoreRegSize[group] = uint8_t(size); } + + //! Returns the alignment of a save-restore area of the given `group`. + ASMJIT_INLINE_NODEBUG uint32_t saveRestoreAlignment(RegGroup group) const noexcept { return _saveRestoreAlignment[group]; } + //! Sets the alignment of a save-restore area of the given `group`. + ASMJIT_INLINE_NODEBUG void setSaveRestoreAlignment(RegGroup group, uint32_t alignment) noexcept { _saveRestoreAlignment[group] = uint8_t(alignment); } + + //! Returns the order of passed registers of the given `group`. + inline const uint8_t* passedOrder(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _passedOrder[size_t(group)].id; + } + + //! Returns the mask of passed registers of the given `group`. 
+ inline RegMask passedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _passedRegs[size_t(group)]; + } + + inline void _setPassedPacked(RegGroup group, uint32_t p0, uint32_t p1, uint32_t p2, uint32_t p3) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + + _passedOrder[group].packed[0] = p0; + _passedOrder[group].packed[1] = p1; + _passedOrder[group].packed[2] = p2; + _passedOrder[group].packed[3] = p3; + } + + //! Resets the order and mask of passed registers. + inline void setPassedToNone(RegGroup group) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + + _setPassedPacked(group, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu); + _passedRegs[size_t(group)] = 0u; + } + + //! Sets the order and mask of passed registers. + inline void setPassedOrder(RegGroup group, uint32_t a0, uint32_t a1 = 0xFF, uint32_t a2 = 0xFF, uint32_t a3 = 0xFF, uint32_t a4 = 0xFF, uint32_t a5 = 0xFF, uint32_t a6 = 0xFF, uint32_t a7 = 0xFF) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + + // NOTE: This should always be called with all arguments known at compile time, so even if it looks scary it + // should be translated into few instructions. + _setPassedPacked(group, Support::bytepack32_4x8(a0, a1, a2, a3), + Support::bytepack32_4x8(a4, a5, a6, a7), + 0xFFFFFFFFu, + 0xFFFFFFFFu); + + _passedRegs[group] = (a0 != 0xFF ? 1u << a0 : 0u) | + (a1 != 0xFF ? 1u << a1 : 0u) | + (a2 != 0xFF ? 1u << a2 : 0u) | + (a3 != 0xFF ? 1u << a3 : 0u) | + (a4 != 0xFF ? 1u << a4 : 0u) | + (a5 != 0xFF ? 1u << a5 : 0u) | + (a6 != 0xFF ? 1u << a6 : 0u) | + (a7 != 0xFF ? 1u << a7 : 0u) ; + } + + //! Returns preserved register mask of the given `group`. + inline RegMask preservedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _preservedRegs[group]; + } + + //! Sets preserved register mask of the given `group`. 
+ inline void setPreservedRegs(RegGroup group, RegMask regs) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _preservedRegs[group] = regs; + } + + //! \} +}; + +//! Function signature. +//! +//! Contains information about a function return type, count of arguments, and their TypeIds. Function signature +//! is a low level structure which doesn't contain platform specific or calling convention specific information. +//! It's typically used to describe function arguments in a C-API like form, which is then used to calculate a +//! \ref FuncDetail instance, which then maps function signature into a platform and calling convention specific +//! format. +//! +//! Function signature can be built either dynamically by using \ref addArg() and \ref addArgT() functionality, +//! or dynamically by using a template-based \ref FuncSignature::build() function, which maps template types +//! into a function signature. +struct FuncSignature { + //! \name Constants + //! \{ + + //! Doesn't have variable number of arguments (`...`). + static constexpr uint8_t kNoVarArgs = 0xFFu; + + //! \} + + //! \name Members + //! \{ + + //! Calling convention id. + CallConvId _ccId = CallConvId::kCDecl; + //! Count of arguments. + uint8_t _argCount = 0; + //! Index of a first VA or `kNoVarArgs`. + uint8_t _vaIndex = kNoVarArgs; + //! Return value TypeId. + TypeId _ret = TypeId::kVoid; + //! Reserved for future use. + uint8_t _reserved[4] {}; + //! Function argument TypeIds. + TypeId _args[Globals::kMaxFuncArgs] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Default constructed function signature, initialized to \ref CallConvId::kCDecl, having no return value and no arguments. + ASMJIT_FORCE_INLINE constexpr FuncSignature() = default; + + //! Copy constructor, which is initialized to the same function signature as `other`. + ASMJIT_FORCE_INLINE constexpr FuncSignature(const FuncSignature& other) = default; + + //! 
Initializes the function signature with calling convention id `ccId` and variable argument's index `vaIndex`. + ASMJIT_FORCE_INLINE constexpr FuncSignature(CallConvId ccId, uint32_t vaIndex = kNoVarArgs) noexcept + : _ccId(ccId), + _vaIndex(uint8_t(vaIndex)) {} + + //! Initializes the function signature with calling convention id `ccId`, `vaIndex`, return value, and function arguments. + template + ASMJIT_FORCE_INLINE constexpr FuncSignature(CallConvId ccId, uint32_t vaIndex, TypeId ret, Args&&...args) noexcept + : _ccId(ccId), + _argCount(uint8_t(sizeof...(args))), + _vaIndex(uint8_t(vaIndex)), + _ret(ret), + _args{std::forward(args)...} {} + + //! Builds a function signature based on `RetValueAndArgs`. The first template argument is a function return type, + //! and function arguments follow. + //! + //! \note This function returns a new function signature, which can be passed to functions where it's required. It's + //! a convenience function that allows to build function signature statically based on types known at compile time, + //! which is common in JIT code generation. + template + static ASMJIT_INLINE_NODEBUG constexpr FuncSignature build(CallConvId ccId = CallConvId::kCDecl, uint32_t vaIndex = kNoVarArgs) noexcept { + return FuncSignature(ccId, vaIndex, (TypeId(TypeUtils::TypeIdOfT::kTypeId))... ); + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Copy assignment - function signature can be copied by value. + ASMJIT_FORCE_INLINE FuncSignature& operator=(const FuncSignature& other) noexcept = default; + + //! Compares this function signature with `other` for equality.. + ASMJIT_FORCE_INLINE bool operator==(const FuncSignature& other) const noexcept { return equals(other); } + //! Compares this function signature with `other` for inequality.. + ASMJIT_FORCE_INLINE bool operator!=(const FuncSignature& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Initialization & Reset + //! \{ + + //! 
Resets this function signature to a default constructed state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = FuncSignature{}; } + + //! \} + + //! \name Equality & Comparison + //! \{ + + //! Compares this function signature with `other` for equality.. + ASMJIT_INLINE_NODEBUG bool equals(const FuncSignature& other) const noexcept { + return _ccId == other._ccId && + _argCount == other._argCount && + _vaIndex == other._vaIndex && + _ret == other._ret && + memcmp(_args, other._args, sizeof(_args)) == 0; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the calling convention. + ASMJIT_INLINE_NODEBUG CallConvId callConvId() const noexcept { return _ccId; } + //! Sets the calling convention to `ccId`; + ASMJIT_INLINE_NODEBUG void setCallConvId(CallConvId ccId) noexcept { _ccId = ccId; } + + //! Tests whether the function signature has a return value. + ASMJIT_INLINE_NODEBUG bool hasRet() const noexcept { return _ret != TypeId::kVoid; } + //! Returns the type of the return value. + ASMJIT_INLINE_NODEBUG TypeId ret() const noexcept { return _ret; } + //! Sets the return type to `retType`. + ASMJIT_INLINE_NODEBUG void setRet(TypeId retType) noexcept { _ret = retType; } + //! Sets the return type based on `T`. + template + ASMJIT_INLINE_NODEBUG void setRetT() noexcept { setRet(TypeId(TypeUtils::TypeIdOfT::kTypeId)); } + + + //! Returns the array of function arguments' types. + ASMJIT_INLINE_NODEBUG const TypeId* args() const noexcept { return _args; } + //! Returns the number of function arguments. + ASMJIT_INLINE_NODEBUG uint32_t argCount() const noexcept { return _argCount; } + + //! Returns the type of the argument at index `i`. + inline TypeId arg(uint32_t i) const noexcept { + ASMJIT_ASSERT(i < _argCount); + return _args[i]; + } + + //! Sets the argument at index `index` to `argType`. + inline void setArg(uint32_t index, TypeId argType) noexcept { + ASMJIT_ASSERT(index < _argCount); + _args[index] = argType; + } + //! 
Sets the argument at index `i` to the type based on `T`. + template + inline void setArgT(uint32_t index) noexcept { setArg(index, TypeId(TypeUtils::TypeIdOfT::kTypeId)); } + + //! Tests whether an argument can be added to the signature, use before calling \ref addArg() and \ref addArgT(). + //! + //! \note If you know that you are not adding more arguments than \ref Globals::kMaxFuncArgs then it's not necessary + //! to use this function. However, if you are adding arguments based on user input, for example, then either check + //! the number of arguments before using function signature or use \ref canAddArg() before actually adding them to + //! the function signature. + inline bool canAddArg() const noexcept { return _argCount < Globals::kMaxFuncArgs; } + + //! Appends an argument of `type` to the function prototype. + inline void addArg(TypeId type) noexcept { + ASMJIT_ASSERT(_argCount < Globals::kMaxFuncArgs); + _args[_argCount++] = type; + } + + //! Appends an argument of type based on `T` to the function prototype. + template + inline void addArgT() noexcept { addArg(TypeId(TypeUtils::TypeIdOfT::kTypeId)); } + + //! Tests whether the function has variable number of arguments (...). + ASMJIT_INLINE_NODEBUG bool hasVarArgs() const noexcept { return _vaIndex != kNoVarArgs; } + //! Returns the variable arguments (...) index, `kNoVarArgs` if none. + ASMJIT_INLINE_NODEBUG uint32_t vaIndex() const noexcept { return _vaIndex; } + //! Sets the variable arguments (...) index to `index`. + ASMJIT_INLINE_NODEBUG void setVaIndex(uint32_t index) noexcept { _vaIndex = uint8_t(index); } + //! Resets the variable arguments index (making it a non-va function). + ASMJIT_INLINE_NODEBUG void resetVaIndex() noexcept { _vaIndex = kNoVarArgs; } + + //! 
\} +}; + +#if !defined(ASMJIT_NO_DEPRECATED) +template +class FuncSignatureT : public FuncSignature { +public: + ASMJIT_DEPRECATED("Use FuncSignature::build() instead") + ASMJIT_INLINE_NODEBUG constexpr FuncSignatureT(CallConvId ccId = CallConvId::kCDecl, uint32_t vaIndex = kNoVarArgs) noexcept + : FuncSignature(ccId, vaIndex, (TypeId(TypeUtils::TypeIdOfT::kTypeId))... ) {} +}; + +ASMJIT_DEPRECATED("Use FuncSignature instead of FuncSignatureBuilder") +typedef FuncSignature FuncSignatureBuilder; +#endif // !ASMJIT_NO_DEPRECATED + +//! Argument or return value (or its part) as defined by `FuncSignature`, but with register or stack address +//! (and other metadata) assigned. +struct FuncValue { + //! \name Constants + //! \{ + + enum Bits : uint32_t { + kTypeIdShift = 0, //!< TypeId shift. + kTypeIdMask = 0x000000FFu, //!< TypeId mask. + + kFlagIsReg = 0x00000100u, //!< Passed by register. + kFlagIsStack = 0x00000200u, //!< Passed by stack. + kFlagIsIndirect = 0x00000400u, //!< Passed indirectly by reference (internally a pointer). + kFlagIsDone = 0x00000800u, //!< Used internally by arguments allocator. + + kStackOffsetShift = 12, //!< Stack offset shift. + kStackOffsetMask = 0xFFFFF000u, //!< Stack offset mask (must occupy MSB bits). + + kRegIdShift = 16, //!< RegId shift. + kRegIdMask = 0x00FF0000u, //!< RegId mask. + + kRegTypeShift = 24, //!< RegType shift. + kRegTypeMask = 0xFF000000u //!< RegType mask. + }; + + //! \} + + //! \name Members + //! \{ + + uint32_t _data; + + //! \} + + //! \name Initialization & Reset + //! + //! These initialize the whole `FuncValue` to either register or stack. Useful when you know all of these + //! properties and wanna just set it up. + //! + //! \{ + + //! Initializes this `FuncValue` only to the `typeId` provided - the rest of the values will be cleared. + ASMJIT_INLINE_NODEBUG void initTypeId(TypeId typeId) noexcept { + _data = uint32_t(typeId) << kTypeIdShift; + } + + //! 
Initializes this `FuncValue` to a register of `regType`, `regId`, and assigns its `typeId` and `flags`. + ASMJIT_INLINE_NODEBUG void initReg(RegType regType, uint32_t regId, TypeId typeId, uint32_t flags = 0) noexcept { + _data = (uint32_t(regType) << kRegTypeShift) | (regId << kRegIdShift) | (uint32_t(typeId) << kTypeIdShift) | kFlagIsReg | flags; + } + + //! Initializes this `FuncValue` to a stack at the given `offset` and assigns its `typeId`. + ASMJIT_INLINE_NODEBUG void initStack(int32_t offset, TypeId typeId) noexcept { + _data = (uint32_t(offset) << kStackOffsetShift) | (uint32_t(typeId) << kTypeIdShift) | kFlagIsStack; + } + + //! Resets the value to its unassigned state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { _data = 0; } + + //! \} + + //! \name Assign + //! + //! These initialize only part of `FuncValue`, useful when building `FuncValue` incrementally. The caller + //! should first init the type-id by calling `initTypeId` and then continue building either register or stack. + //! + //! \{ + + //! Assigns a register of `regType` and `regId`. + inline void assignRegData(RegType regType, uint32_t regId) noexcept { + ASMJIT_ASSERT((_data & (kRegTypeMask | kRegIdMask)) == 0); + _data |= (uint32_t(regType) << kRegTypeShift) | (regId << kRegIdShift) | kFlagIsReg; + } + + //! Assigns a stack location at `offset`. + inline void assignStackOffset(int32_t offset) noexcept { + ASMJIT_ASSERT((_data & kStackOffsetMask) == 0); + _data |= (uint32_t(offset) << kStackOffsetShift) | kFlagIsStack; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns true if the value is initialized (explicit bool cast). + ASMJIT_INLINE_NODEBUG explicit operator bool() const noexcept { return _data != 0; } + + //! \cond INTERNAL + ASMJIT_INLINE_NODEBUG void _replaceValue(uint32_t mask, uint32_t value) noexcept { _data = (_data & ~mask) | value; } + //! \endcond + + //! Tests whether the `FuncValue` has a flag `flag` set. 
+ ASMJIT_INLINE_NODEBUG bool hasFlag(uint32_t flag) const noexcept { return Support::test(_data, flag); } + //! Adds `flags` to `FuncValue`. + ASMJIT_INLINE_NODEBUG void addFlags(uint32_t flags) noexcept { _data |= flags; } + //! Clears `flags` of `FuncValue`. + ASMJIT_INLINE_NODEBUG void clearFlags(uint32_t flags) noexcept { _data &= ~flags; } + + //! Tests whether the value is initialized (i.e. contains a valid data). + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _data != 0; } + //! Tests whether the argument is passed by register. + ASMJIT_INLINE_NODEBUG bool isReg() const noexcept { return hasFlag(kFlagIsReg); } + //! Tests whether the argument is passed by stack. + ASMJIT_INLINE_NODEBUG bool isStack() const noexcept { return hasFlag(kFlagIsStack); } + //! Tests whether the argument is passed by register. + ASMJIT_INLINE_NODEBUG bool isAssigned() const noexcept { return hasFlag(kFlagIsReg | kFlagIsStack); } + //! Tests whether the argument is passed through a pointer (used by WIN64 to pass XMM|YMM|ZMM). + ASMJIT_INLINE_NODEBUG bool isIndirect() const noexcept { return hasFlag(kFlagIsIndirect); } + + //! Tests whether the argument was already processed (used internally). + ASMJIT_INLINE_NODEBUG bool isDone() const noexcept { return hasFlag(kFlagIsDone); } + + //! Returns a register type of the register used to pass function argument or return value. + ASMJIT_INLINE_NODEBUG RegType regType() const noexcept { return RegType((_data & kRegTypeMask) >> kRegTypeShift); } + //! Sets a register type of the register used to pass function argument or return value. + ASMJIT_INLINE_NODEBUG void setRegType(RegType regType) noexcept { _replaceValue(kRegTypeMask, uint32_t(regType) << kRegTypeShift); } + + //! Returns a physical id of the register used to pass function argument or return value. + ASMJIT_INLINE_NODEBUG uint32_t regId() const noexcept { return (_data & kRegIdMask) >> kRegIdShift; } + //! 
Sets a physical id of the register used to pass function argument or return value. + ASMJIT_INLINE_NODEBUG void setRegId(uint32_t regId) noexcept { _replaceValue(kRegIdMask, regId << kRegIdShift); } + + //! Returns a stack offset of this argument. + ASMJIT_INLINE_NODEBUG int32_t stackOffset() const noexcept { return int32_t(_data & kStackOffsetMask) >> kStackOffsetShift; } + //! Sets a stack offset of this argument. + ASMJIT_INLINE_NODEBUG void setStackOffset(int32_t offset) noexcept { _replaceValue(kStackOffsetMask, uint32_t(offset) << kStackOffsetShift); } + + //! Tests whether the argument or return value has associated `TypeId`. + ASMJIT_INLINE_NODEBUG bool hasTypeId() const noexcept { return Support::test(_data, kTypeIdMask); } + //! Returns a TypeId of this argument or return value. + ASMJIT_INLINE_NODEBUG TypeId typeId() const noexcept { return TypeId((_data & kTypeIdMask) >> kTypeIdShift); } + //! Sets a TypeId of this argument or return value. + ASMJIT_INLINE_NODEBUG void setTypeId(TypeId typeId) noexcept { _replaceValue(kTypeIdMask, uint32_t(typeId) << kTypeIdShift); } + + //! \} +}; + +//! Contains multiple `FuncValue` instances in an array so functions that use multiple registers for arguments or +//! return values can represent all inputs and outputs. +struct FuncValuePack { +public: + //! \name Members + //! \{ + + //! Values of the pack. + FuncValue _values[Globals::kMaxValuePack]; + + //! \} + + //! \name Initialization & Reset + //! \{ + + //! Resets all values in the pack. + inline void reset() noexcept { + for (size_t i = 0; i < Globals::kMaxValuePack; i++) + _values[i].reset(); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Calculates how many values are in the pack, checking for non-values from the end. + inline uint32_t count() const noexcept { + uint32_t n = Globals::kMaxValuePack; + while (n && !_values[n - 1]) + n--; + return n; + } + + //! Returns values in this value in the pack. + //! + //! 
\note The returned array has exactly \ref Globals::kMaxValuePack elements. + ASMJIT_INLINE_NODEBUG FuncValue* values() noexcept { return _values; } + //! \overload + ASMJIT_INLINE_NODEBUG const FuncValue* values() const noexcept { return _values; } + + //! Resets a value at the given `index` in the pack, which makes it unassigned. + inline void resetValue(size_t index) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + _values[index].reset(); + } + + //! Tests whether the value at the given `index` in the pack is assigned. + inline bool hasValue(size_t index) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + return _values[index].isInitialized(); + } + + //! Assigns a register at the given `index` to `reg` and an optional `typeId`. + inline void assignReg(size_t index, const BaseReg& reg, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + ASMJIT_ASSERT(reg.isPhysReg()); + _values[index].initReg(reg.type(), reg.id(), typeId); + } + + //! Assigns a register at the given `index` to `regType`, `regId`, and an optional `typeId`. + inline void assignReg(size_t index, RegType regType, uint32_t regId, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + _values[index].initReg(regType, regId, typeId); + } + + //! Assigns a stack location at the given `index` to `offset` and an optional `typeId`. + inline void assignStack(size_t index, int32_t offset, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + _values[index].initStack(offset, typeId); + } + + //! Accesses the value in the pack at the given `index`. + //! + //! \note The maximum index value is `Globals::kMaxValuePack - 1`. + inline FuncValue& operator[](size_t index) { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + return _values[index]; + } + //! 
\overload + inline const FuncValue& operator[](size_t index) const { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + return _values[index]; + } + + //! \} +}; + +//! Attributes are designed in a way that all are initially false, and user or \ref FuncFrame finalizer adds +//! them when necessary. +enum class FuncAttributes : uint32_t { + //! No attributes. + kNoAttributes = 0, + + //! Function has variable number of arguments. + kHasVarArgs = 0x00000001u, + //! Preserve frame pointer (don't omit FP). + kHasPreservedFP = 0x00000010u, + //! Function calls other functions (is not leaf). + kHasFuncCalls = 0x00000020u, + //! Function has aligned save/restore of vector registers. + kAlignedVecSR = 0x00000040u, + //! Function must begin with an instruction that marks a start of a branch or function. + //! + //! * `ENDBR32/ENDBR64` instruction is inserted at the beginning of the function (X86, X86_64). + //! * `BTI` instruction is inserted at the beginning of the function (AArch64) + kIndirectBranchProtection = 0x00000080u, + //! FuncFrame is finalized and can be used by prolog/epilog inserter (PEI). + kIsFinalized = 0x00000800u, + + // X86 Specific Attributes + // ----------------------- + + //! Enables the use of AVX within the function's body, prolog, and epilog (X86). + //! + //! This flag instructs prolog and epilog emitter to use AVX instead of SSE for manipulating XMM registers. + kX86_AVXEnabled = 0x00010000u, + + //! Enables the use of AVX-512 within the function's body, prolog, and epilog (X86). + //! + //! This flag instructs Compiler register allocator to use additional 16 registers introduced by AVX-512. + //! Additionally, if the functions saves full width of ZMM registers (custom calling conventions only) then + //! the prolog/epilog inserter would use AVX-512 move instructions to emit the save and restore sequence. + kX86_AVX512Enabled = 0x00020000u, + + //! This flag instructs the epilog writer to emit EMMS instruction before RET (X86). 
+ kX86_MMXCleanup = 0x00040000u, + + //! This flag instructs the epilog writer to emit VZEROUPPER instruction before RET (X86). + kX86_AVXCleanup = 0x00080000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(FuncAttributes) + +//! Function detail - \ref CallConv and expanded \ref FuncSignature. +//! +//! Function detail is architecture and OS dependent representation of a function. It contains a materialized +//! calling convention and expanded function signature so all arguments have assigned either register type/id +//! or stack address. +class FuncDetail { +public: + //! \name Constants + //! \{ + + //! Function doesn't have a variable number of arguments (`...`). + static constexpr uint8_t kNoVarArgs = 0xFFu; + + //! \} + + //! \name Members + //! \{ + + //! Calling convention. + CallConv _callConv {}; + //! Number of function arguments. + uint8_t _argCount = 0; + //! Variable arguments index of `kNoVarArgs`. + uint8_t _vaIndex = 0; + //! Reserved for future use. + uint16_t _reserved = 0; + //! Registers that contain arguments. + Support::Array _usedRegs {}; + //! Size of arguments passed by stack. + uint32_t _argStackSize = 0; + //! Function return value(s). + FuncValuePack _rets {}; + //! Function arguments. + FuncValuePack _args[Globals::kMaxFuncArgs] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a default constructed \ref FuncDetail. + ASMJIT_INLINE_NODEBUG FuncDetail() noexcept {} + + //! Copy constructor. + //! + //! Function details are copyable. + ASMJIT_INLINE_NODEBUG FuncDetail(const FuncDetail& other) noexcept = default; + + //! Initializes this `FuncDetail` to the given signature. + ASMJIT_API Error init(const FuncSignature& signature, const Environment& environment) noexcept; + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Assignment operator, copies `other` to this \ref FuncDetail. + ASMJIT_INLINE_NODEBUG FuncDetail& operator=(const FuncDetail& other) noexcept = default; + + //! \} + + //! \name Reset + //! 
\{ + + //! Resets the function detail to its default constructed state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = FuncDetail{}; } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the function's calling convention, see `CallConv`. + ASMJIT_INLINE_NODEBUG const CallConv& callConv() const noexcept { return _callConv; } + + //! Returns the associated calling convention flags, see `CallConv::Flags`. + ASMJIT_INLINE_NODEBUG CallConvFlags flags() const noexcept { return _callConv.flags(); } + //! Checks whether a CallConv `flag` is set, see `CallConv::Flags`. + ASMJIT_INLINE_NODEBUG bool hasFlag(CallConvFlags ccFlag) const noexcept { return _callConv.hasFlag(ccFlag); } + + //! Tests whether the function has a return value. + ASMJIT_INLINE_NODEBUG bool hasRet() const noexcept { return bool(_rets[0]); } + //! Returns the number of function arguments. + ASMJIT_INLINE_NODEBUG uint32_t argCount() const noexcept { return _argCount; } + + //! Returns function return values. + ASMJIT_INLINE_NODEBUG FuncValuePack& retPack() noexcept { return _rets; } + //! Returns function return values. + ASMJIT_INLINE_NODEBUG const FuncValuePack& retPack() const noexcept { return _rets; } + + //! Returns a function return value associated with the given `valueIndex`. + ASMJIT_INLINE_NODEBUG FuncValue& ret(size_t valueIndex = 0) noexcept { return _rets[valueIndex]; } + //! Returns a function return value associated with the given `valueIndex` (const). + ASMJIT_INLINE_NODEBUG const FuncValue& ret(size_t valueIndex = 0) const noexcept { return _rets[valueIndex]; } + + //! Returns function argument packs array. + ASMJIT_INLINE_NODEBUG FuncValuePack* argPacks() noexcept { return _args; } + //! Returns function argument packs array (const). + ASMJIT_INLINE_NODEBUG const FuncValuePack* argPacks() const noexcept { return _args; } + + //! Returns function argument pack at the given `argIndex`. 
+ inline FuncValuePack& argPack(size_t argIndex) noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + return _args[argIndex]; + } + + //! Returns function argument pack at the given `argIndex` (const). + inline const FuncValuePack& argPack(size_t argIndex) const noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + return _args[argIndex]; + } + + //! Returns an argument at `valueIndex` from the argument pack at the given `argIndex`. + inline FuncValue& arg(size_t argIndex, size_t valueIndex = 0) noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + return _args[argIndex][valueIndex]; + } + + //! Returns an argument at `valueIndex` from the argument pack at the given `argIndex` (const). + inline const FuncValue& arg(size_t argIndex, size_t valueIndex = 0) const noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + return _args[argIndex][valueIndex]; + } + + //! Resets an argument at the given `argIndex`. + //! + //! If the argument is a parameter pack (has multiple values) all values are reset. + inline void resetArg(size_t argIndex) noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + _args[argIndex].reset(); + } + + //! Tests whether the function has variable arguments. + ASMJIT_INLINE_NODEBUG bool hasVarArgs() const noexcept { return _vaIndex != kNoVarArgs; } + //! Returns an index of a first variable argument. + ASMJIT_INLINE_NODEBUG uint32_t vaIndex() const noexcept { return _vaIndex; } + + //! Tests whether the function passes one or more argument by stack. + ASMJIT_INLINE_NODEBUG bool hasStackArgs() const noexcept { return _argStackSize != 0; } + //! Returns stack size needed for function arguments passed on the stack. + ASMJIT_INLINE_NODEBUG uint32_t argStackSize() const noexcept { return _argStackSize; } + + //! Returns red zone size. + ASMJIT_INLINE_NODEBUG uint32_t redZoneSize() const noexcept { return _callConv.redZoneSize(); } + //! Returns spill zone size. 
+ ASMJIT_INLINE_NODEBUG uint32_t spillZoneSize() const noexcept { return _callConv.spillZoneSize(); } + //! Returns natural stack alignment. + ASMJIT_INLINE_NODEBUG uint32_t naturalStackAlignment() const noexcept { return _callConv.naturalStackAlignment(); } + + //! Returns a mask of all passed registers of the given register `group`. + ASMJIT_INLINE_NODEBUG RegMask passedRegs(RegGroup group) const noexcept { return _callConv.passedRegs(group); } + //! Returns a mask of all preserved registers of the given register `group`. + ASMJIT_INLINE_NODEBUG RegMask preservedRegs(RegGroup group) const noexcept { return _callConv.preservedRegs(group); } + + //! Returns a mask of all used registers of the given register `group`. + inline RegMask usedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _usedRegs[size_t(group)]; + } + + //! Adds `regs` to the mask of used registers of the given register `group`. + inline void addUsedRegs(RegGroup group, RegMask regs) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _usedRegs[size_t(group)] |= regs; + } + + //! \} +}; + +//! Function frame. +//! +//! Function frame is used directly by prolog and epilog insertion (PEI) utils. It provides information necessary to +//! insert a proper and ABI conforming prolog and epilog. Function frame calculation is based on `CallConv` and +//! other function attributes. +//! +//! SSE vs AVX vs AVX-512 +//! --------------------- +//! +//! Function frame provides a way to tell prolog/epilog inserter to use AVX instructions instead of SSE. Use +//! `setAvxEnabled()` and `setAvx512Enabled()` to enable AVX and/or AVX-512, respectively. Enabling AVX-512 +//! is mostly for Compiler as it would use 32 SIMD registers instead of 16 when enabled. +//! +//! \note If your code uses AVX instructions and AVX is not enabled there would be a performance hit in case that +//! some registers had to be saved/restored in function's prolog/epilog, respectively. 
Thus, it's recommended to +//! always let the function frame know about the use of AVX. +//! +//! Function Frame Structure +//! ------------------------ +//! +//! Various properties can contribute to the size and structure of the function frame. The function frame in most +//! cases won't use all of the properties illustrated (for example Spill Zone and Red Zone are never used together). +//! +//! ``` +//! +-----------------------------+ +//! | Arguments Passed by Stack | +//! +-----------------------------+ +//! | Spill Zone | +//! +-----------------------------+ <- Stack offset (args) starts from here. +//! | Return Address, if Pushed | +//! +-----------------------------+ <- Stack pointer (SP) upon entry. +//! | Save/Restore Stack. | +//! +-----------------------------+-----------------------------+ +//! | Local Stack | | +//! +-----------------------------+ Final Stack | +//! | Call Stack | | +//! +-----------------------------+-----------------------------+ <- SP after prolog. +//! | Red Zone | +//! +-----------------------------+ +//! ``` +class FuncFrame { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + //! Tag used to inform that some offset is invalid. + kTagInvalidOffset = 0xFFFFFFFFu + }; + + //! \} + + //! \name Members + //! \{ + + //! Function attributes. + FuncAttributes _attributes {}; + + //! Target architecture. + Arch _arch {}; + //! SP register ID (to access call stack and local stack). + uint8_t _spRegId = uint8_t(BaseReg::kIdBad); + //! SA register ID (to access stack arguments). + uint8_t _saRegId = uint8_t(BaseReg::kIdBad); + + //! Red zone size (copied from CallConv). + uint8_t _redZoneSize = 0; + //! Spill zone size (copied from CallConv). + uint8_t _spillZoneSize = 0; + //! Natural stack alignment (copied from CallConv). + uint8_t _naturalStackAlignment = 0; + //! Minimum stack alignment to turn on dynamic alignment. + uint8_t _minDynamicAlignment = 0; + + //! Call stack alignment. + uint8_t _callStackAlignment = 0; + //! 
Local stack alignment. + uint8_t _localStackAlignment = 0; + //! Final stack alignment. + uint8_t _finalStackAlignment = 0; + + //! Adjustment of the stack before returning (X86-STDCALL). + uint16_t _calleeStackCleanup = 0; + + //! Call stack size. + uint32_t _callStackSize = 0; + //! Local stack size. + uint32_t _localStackSize = 0; + //! Final stack size (sum of call stack and local stack). + uint32_t _finalStackSize = 0; + + //! Local stack offset (non-zero only if call stack is used). + uint32_t _localStackOffset = 0; + //! Offset relative to SP that contains previous SP (before alignment). + uint32_t _daOffset = 0; + //! Offset of the first stack argument relative to SP. + uint32_t _saOffsetFromSP = 0; + //! Offset of the first stack argument relative to SA (_saRegId or FP). + uint32_t _saOffsetFromSA = 0; + + //! Local stack adjustment in prolog/epilog. + uint32_t _stackAdjustment = 0; + + //! Registers that are dirty. + Support::Array _dirtyRegs {}; + //! Registers that must be preserved (copied from CallConv). + Support::Array _preservedRegs {}; + //! Size to save/restore per register group. + Support::Array _saveRestoreRegSize {}; + //! Alignment of save/restore area per register group. + Support::Array _saveRestoreAlignment {}; + + //! Stack size required to save registers with push/pop. + uint16_t _pushPopSaveSize = 0; + //! Stack size required to save extra registers that cannot use push/pop. + uint16_t _extraRegSaveSize = 0; + //! Offset where registers saved/restored via push/pop are stored + uint32_t _pushPopSaveOffset = 0; + //! Offset where extra registers that cannot use push/pop are stored. + uint32_t _extraRegSaveOffset = 0; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a default constructed function frame, which has initialized all members to their default values. + ASMJIT_INLINE_NODEBUG FuncFrame() noexcept = default; + //! Creates a copy of `other` function frame. 
+ ASMJIT_INLINE_NODEBUG FuncFrame(const FuncFrame& other) noexcept = default; + + //! \} + + //! \name Initialization & Reset + //! \{ + + //! Initializes the function frame based on `func` detail. + ASMJIT_API Error init(const FuncDetail& func) noexcept; + //! Resets the function frame into its default constructed state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = FuncFrame{}; } + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Copy assignment - function frame is copy assignable. + ASMJIT_INLINE_NODEBUG FuncFrame& operator=(const FuncFrame& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the target architecture of the function frame. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _arch; } + + //! Returns function frame attributes, see `Attributes`. + ASMJIT_INLINE_NODEBUG FuncAttributes attributes() const noexcept { return _attributes; } + //! Checks whether the FuncFame contains an attribute `attr`. + ASMJIT_INLINE_NODEBUG bool hasAttribute(FuncAttributes attr) const noexcept { return Support::test(_attributes, attr); } + //! Adds attributes `attrs` to the FuncFrame. + ASMJIT_INLINE_NODEBUG void addAttributes(FuncAttributes attrs) noexcept { _attributes |= attrs; } + //! Clears attributes `attrs` from the FrameFrame. + ASMJIT_INLINE_NODEBUG void clearAttributes(FuncAttributes attrs) noexcept { _attributes &= ~attrs; } + + //! Tests whether the function has variable number of arguments. + ASMJIT_INLINE_NODEBUG bool hasVarArgs() const noexcept { return hasAttribute(FuncAttributes::kHasVarArgs); } + //! Sets the variable arguments flag. + ASMJIT_INLINE_NODEBUG void setVarArgs() noexcept { addAttributes(FuncAttributes::kHasVarArgs); } + //! Resets variable arguments flag. + ASMJIT_INLINE_NODEBUG void resetVarArgs() noexcept { clearAttributes(FuncAttributes::kHasVarArgs); } + + //! Tests whether the function preserves frame pointer (EBP|ESP on X86). 
+ ASMJIT_INLINE_NODEBUG bool hasPreservedFP() const noexcept { return hasAttribute(FuncAttributes::kHasPreservedFP); } + //! Enables preserved frame pointer. + ASMJIT_INLINE_NODEBUG void setPreservedFP() noexcept { addAttributes(FuncAttributes::kHasPreservedFP); } + //! Disables preserved frame pointer. + ASMJIT_INLINE_NODEBUG void resetPreservedFP() noexcept { clearAttributes(FuncAttributes::kHasPreservedFP); } + + //! Tests whether the function calls other functions. + ASMJIT_INLINE_NODEBUG bool hasFuncCalls() const noexcept { return hasAttribute(FuncAttributes::kHasFuncCalls); } + //! Sets `FuncAttributes::kHasFuncCalls` to true. + ASMJIT_INLINE_NODEBUG void setFuncCalls() noexcept { addAttributes(FuncAttributes::kHasFuncCalls); } + //! Sets `FuncAttributes::kHasFuncCalls` to false. + ASMJIT_INLINE_NODEBUG void resetFuncCalls() noexcept { clearAttributes(FuncAttributes::kHasFuncCalls); } + + //! Tests whether the function uses indirect branch protection, see \ref FuncAttributes::kIndirectBranchProtection. + ASMJIT_INLINE_NODEBUG bool hasIndirectBranchProtection() const noexcept { return hasAttribute(FuncAttributes::kIndirectBranchProtection); } + //! Enabled indirect branch protection (sets `FuncAttributes::kIndirectBranchProtection` attribute to true). + ASMJIT_INLINE_NODEBUG void setIndirectBranchProtection() noexcept { addAttributes(FuncAttributes::kIndirectBranchProtection); } + //! Disables indirect branch protection (sets `FuncAttributes::kIndirectBranchProtection` attribute to false). + ASMJIT_INLINE_NODEBUG void resetIndirectBranchProtection() noexcept { clearAttributes(FuncAttributes::kIndirectBranchProtection); } + + //! Tests whether the function has AVX enabled. + ASMJIT_INLINE_NODEBUG bool isAvxEnabled() const noexcept { return hasAttribute(FuncAttributes::kX86_AVXEnabled); } + //! Enables AVX use. + ASMJIT_INLINE_NODEBUG void setAvxEnabled() noexcept { addAttributes(FuncAttributes::kX86_AVXEnabled); } + //! Disables AVX use. 
+ ASMJIT_INLINE_NODEBUG void resetAvxEnabled() noexcept { clearAttributes(FuncAttributes::kX86_AVXEnabled); } + + //! Tests whether the function has AVX-512 enabled. + ASMJIT_INLINE_NODEBUG bool isAvx512Enabled() const noexcept { return hasAttribute(FuncAttributes::kX86_AVX512Enabled); } + //! Enables AVX-512 use. + ASMJIT_INLINE_NODEBUG void setAvx512Enabled() noexcept { addAttributes(FuncAttributes::kX86_AVX512Enabled); } + //! Disables AVX-512 use. + ASMJIT_INLINE_NODEBUG void resetAvx512Enabled() noexcept { clearAttributes(FuncAttributes::kX86_AVX512Enabled); } + + //! Tests whether the function has MMX cleanup - 'emms' instruction in epilog. + ASMJIT_INLINE_NODEBUG bool hasMmxCleanup() const noexcept { return hasAttribute(FuncAttributes::kX86_MMXCleanup); } + //! Enables MMX cleanup. + ASMJIT_INLINE_NODEBUG void setMmxCleanup() noexcept { addAttributes(FuncAttributes::kX86_MMXCleanup); } + //! Disables MMX cleanup. + ASMJIT_INLINE_NODEBUG void resetMmxCleanup() noexcept { clearAttributes(FuncAttributes::kX86_MMXCleanup); } + + //! Tests whether the function has AVX cleanup - 'vzeroupper' instruction in epilog. + ASMJIT_INLINE_NODEBUG bool hasAvxCleanup() const noexcept { return hasAttribute(FuncAttributes::kX86_AVXCleanup); } + //! Enables AVX cleanup. + ASMJIT_INLINE_NODEBUG void setAvxCleanup() noexcept { addAttributes(FuncAttributes::kX86_AVXCleanup); } + //! Disables AVX cleanup. + ASMJIT_INLINE_NODEBUG void resetAvxCleanup() noexcept { clearAttributes(FuncAttributes::kX86_AVXCleanup); } + + //! Tests whether the function uses call stack. + ASMJIT_INLINE_NODEBUG bool hasCallStack() const noexcept { return _callStackSize != 0; } + //! Tests whether the function uses local stack. + ASMJIT_INLINE_NODEBUG bool hasLocalStack() const noexcept { return _localStackSize != 0; } + //! Tests whether vector registers can be saved and restored by using aligned reads and writes. 
+ ASMJIT_INLINE_NODEBUG bool hasAlignedVecSR() const noexcept { return hasAttribute(FuncAttributes::kAlignedVecSR); } + //! Tests whether the function has to align stack dynamically. + ASMJIT_INLINE_NODEBUG bool hasDynamicAlignment() const noexcept { return _finalStackAlignment >= _minDynamicAlignment; } + + //! Tests whether the calling convention specifies 'RedZone'. + ASMJIT_INLINE_NODEBUG bool hasRedZone() const noexcept { return _redZoneSize != 0; } + //! Tests whether the calling convention specifies 'SpillZone'. + ASMJIT_INLINE_NODEBUG bool hasSpillZone() const noexcept { return _spillZoneSize != 0; } + + //! Returns the size of 'RedZone'. + ASMJIT_INLINE_NODEBUG uint32_t redZoneSize() const noexcept { return _redZoneSize; } + //! Returns the size of 'SpillZone'. + ASMJIT_INLINE_NODEBUG uint32_t spillZoneSize() const noexcept { return _spillZoneSize; } + + //! Resets the size of red zone, which would disable it entirely. + //! + //! \note Red zone is currently only used by an AMD64 SystemV calling convention, which expects 128 + //! bytes of stack to be accessible below stack pointer. These bytes are then accessible within the + //! function and Compiler can use this space as a spill area. However, sometimes it's better to + //! disallow the use of red zone in case that a user wants to use this stack for a custom purpose. + ASMJIT_INLINE_NODEBUG void resetRedZone() noexcept { _redZoneSize = 0; } + + //! Returns natural stack alignment (guaranteed stack alignment upon entry). + ASMJIT_INLINE_NODEBUG uint32_t naturalStackAlignment() const noexcept { return _naturalStackAlignment; } + //! Returns natural stack alignment (guaranteed stack alignment upon entry). + ASMJIT_INLINE_NODEBUG uint32_t minDynamicAlignment() const noexcept { return _minDynamicAlignment; } + + //! Tests whether the callee must adjust SP before returning (X86-STDCALL only) + ASMJIT_INLINE_NODEBUG bool hasCalleeStackCleanup() const noexcept { return _calleeStackCleanup != 0; } + //! 
Returns home many bytes of the stack the callee must adjust before returning (X86-STDCALL only) + ASMJIT_INLINE_NODEBUG uint32_t calleeStackCleanup() const noexcept { return _calleeStackCleanup; } + + //! Returns call stack alignment. + ASMJIT_INLINE_NODEBUG uint32_t callStackAlignment() const noexcept { return _callStackAlignment; } + //! Returns local stack alignment. + ASMJIT_INLINE_NODEBUG uint32_t localStackAlignment() const noexcept { return _localStackAlignment; } + //! Returns final stack alignment (the maximum value of call, local, and natural stack alignments). + ASMJIT_INLINE_NODEBUG uint32_t finalStackAlignment() const noexcept { return _finalStackAlignment; } + + //! Sets call stack alignment. + //! + //! \note This also updates the final stack alignment. + inline void setCallStackAlignment(uint32_t alignment) noexcept { + _callStackAlignment = uint8_t(alignment); + _finalStackAlignment = Support::max(_naturalStackAlignment, _callStackAlignment, _localStackAlignment); + } + + //! Sets local stack alignment. + //! + //! \note This also updates the final stack alignment. + inline void setLocalStackAlignment(uint32_t value) noexcept { + _localStackAlignment = uint8_t(value); + _finalStackAlignment = Support::max(_naturalStackAlignment, _callStackAlignment, _localStackAlignment); + } + + //! Combines call stack alignment with `alignment`, updating it to the greater value. + //! + //! \note This also updates the final stack alignment. + inline void updateCallStackAlignment(uint32_t alignment) noexcept { + _callStackAlignment = uint8_t(Support::max(_callStackAlignment, alignment)); + _finalStackAlignment = Support::max(_finalStackAlignment, _callStackAlignment); + } + + //! Combines local stack alignment with `alignment`, updating it to the greater value. + //! + //! \note This also updates the final stack alignment. 
+ inline void updateLocalStackAlignment(uint32_t alignment) noexcept { + _localStackAlignment = uint8_t(Support::max(_localStackAlignment, alignment)); + _finalStackAlignment = Support::max(_finalStackAlignment, _localStackAlignment); + } + + //! Returns call stack size. + ASMJIT_INLINE_NODEBUG uint32_t callStackSize() const noexcept { return _callStackSize; } + //! Returns local stack size. + ASMJIT_INLINE_NODEBUG uint32_t localStackSize() const noexcept { return _localStackSize; } + + //! Sets call stack size. + ASMJIT_INLINE_NODEBUG void setCallStackSize(uint32_t size) noexcept { _callStackSize = size; } + //! Sets local stack size. + ASMJIT_INLINE_NODEBUG void setLocalStackSize(uint32_t size) noexcept { _localStackSize = size; } + + //! Combines call stack size with `size`, updating it to the greater value. + ASMJIT_INLINE_NODEBUG void updateCallStackSize(uint32_t size) noexcept { _callStackSize = Support::max(_callStackSize, size); } + //! Combines local stack size with `size`, updating it to the greater value. + ASMJIT_INLINE_NODEBUG void updateLocalStackSize(uint32_t size) noexcept { _localStackSize = Support::max(_localStackSize, size); } + + //! Returns final stack size (only valid after the FuncFrame is finalized). + ASMJIT_INLINE_NODEBUG uint32_t finalStackSize() const noexcept { return _finalStackSize; } + + //! Returns an offset to access the local stack (non-zero only if call stack is used). + ASMJIT_INLINE_NODEBUG uint32_t localStackOffset() const noexcept { return _localStackOffset; } + + //! Tests whether the function prolog/epilog requires a memory slot for storing unaligned SP. + ASMJIT_INLINE_NODEBUG bool hasDAOffset() const noexcept { return _daOffset != kTagInvalidOffset; } + //! Returns a memory offset used to store DA (dynamic alignment) slot (relative to SP). 
+ ASMJIT_INLINE_NODEBUG uint32_t daOffset() const noexcept { return _daOffset; } + + ASMJIT_INLINE_NODEBUG uint32_t saOffset(uint32_t regId) const noexcept { + return regId == _spRegId ? saOffsetFromSP() + : saOffsetFromSA(); + } + + ASMJIT_INLINE_NODEBUG uint32_t saOffsetFromSP() const noexcept { return _saOffsetFromSP; } + ASMJIT_INLINE_NODEBUG uint32_t saOffsetFromSA() const noexcept { return _saOffsetFromSA; } + + //! Returns mask of registers of the given register `group` that are modified by the function. The engine would + //! then calculate which registers must be saved & restored by the function by using the data provided by the + //! calling convention. + inline RegMask dirtyRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _dirtyRegs[group]; + } + + //! Sets which registers (as a mask) are modified by the function. + //! + //! \remarks Please note that this will completely overwrite the existing register mask, use `addDirtyRegs()` + //! to modify the existing register mask. + inline void setDirtyRegs(RegGroup group, RegMask regs) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _dirtyRegs[group] = regs; + } + + //! Adds which registers (as a mask) are modified by the function. + inline void addDirtyRegs(RegGroup group, RegMask regs) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _dirtyRegs[group] |= regs; + } + + //! \overload + inline void addDirtyRegs(const BaseReg& reg) noexcept { + ASMJIT_ASSERT(reg.id() < Globals::kMaxPhysRegs); + addDirtyRegs(reg.group(), Support::bitMask(reg.id())); + } + + //! \overload + template + inline void addDirtyRegs(const BaseReg& reg, Args&&... args) noexcept { + addDirtyRegs(reg); + addDirtyRegs(std::forward(args)...); + } + + //! A helper function to set all registers from all register groups dirty. + //! + //! \note This should not be used in general as it's the most pessimistic case. However, it can be used for testing + //! 
or in cases in which all registers are considered clobbered. + ASMJIT_INLINE_NODEBUG void setAllDirty() noexcept { + for (size_t i = 0; i < ASMJIT_ARRAY_SIZE(_dirtyRegs); i++) + _dirtyRegs[i] = 0xFFFFFFFFu; + } + + //! A helper function to set all registers from the given register `group` dirty. + inline void setAllDirty(RegGroup group) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _dirtyRegs[group] = 0xFFFFFFFFu; + } + + //! Returns a calculated mask of registers of the given `group` that will be saved and restored in the function's + //! prolog and epilog, respectively. The register mask is calculated from both `dirtyRegs` (provided by user) and + //! `preservedMask` (provided by the calling convention). + inline RegMask savedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _dirtyRegs[group] & _preservedRegs[group]; + } + + //! Returns the mask of preserved registers of the given register `group`. + //! + //! Preserved registers are those that must survive the function call unmodified. The function can only modify + //! preserved registers it they are saved and restored in function's prolog and epilog, respectively. + inline RegMask preservedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _preservedRegs[group]; + } + + //! Returns the size of a save-restore are for the required register `group`. 
+ inline uint32_t saveRestoreRegSize(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _saveRestoreRegSize[group]; + } + + inline uint32_t saveRestoreAlignment(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _saveRestoreAlignment[group]; + } + + ASMJIT_INLINE_NODEBUG bool hasSARegId() const noexcept { return _saRegId != BaseReg::kIdBad; } + ASMJIT_INLINE_NODEBUG uint32_t saRegId() const noexcept { return _saRegId; } + ASMJIT_INLINE_NODEBUG void setSARegId(uint32_t regId) { _saRegId = uint8_t(regId); } + ASMJIT_INLINE_NODEBUG void resetSARegId() { setSARegId(BaseReg::kIdBad); } + + //! Returns stack size required to save/restore registers via push/pop. + ASMJIT_INLINE_NODEBUG uint32_t pushPopSaveSize() const noexcept { return _pushPopSaveSize; } + //! Returns an offset to the stack where registers are saved via push/pop. + ASMJIT_INLINE_NODEBUG uint32_t pushPopSaveOffset() const noexcept { return _pushPopSaveOffset; } + + //! Returns stack size required to save/restore extra registers that don't use push/pop/ + //! + //! \note On X86 this covers all registers except GP registers, on other architectures it can be always + //! zero (for example AArch64 saves all registers via push/pop like instructions, so this would be zero). + ASMJIT_INLINE_NODEBUG uint32_t extraRegSaveSize() const noexcept { return _extraRegSaveSize; } + //! Returns an offset to the stack where extra registers are saved. + ASMJIT_INLINE_NODEBUG uint32_t extraRegSaveOffset() const noexcept { return _extraRegSaveOffset; } + + //! Tests whether the functions contains stack adjustment. + ASMJIT_INLINE_NODEBUG bool hasStackAdjustment() const noexcept { return _stackAdjustment != 0; } + //! Returns function's stack adjustment used in function's prolog and epilog. + //! + //! If the returned value is zero it means that the stack is not adjusted. This can mean both that the stack + //! 
is not used and/or the stack is only adjusted by instructions that push/pop registers into/from stack.
+ ASMJIT_INLINE_NODEBUG const FuncDetail* funcDetail() const noexcept { return _funcDetail; } + //! Associates \ref FuncDetails with this `FuncArgsAssignment`. + ASMJIT_INLINE_NODEBUG void setFuncDetail(const FuncDetail* fd) noexcept { _funcDetail = fd; } + + ASMJIT_INLINE_NODEBUG bool hasSARegId() const noexcept { return _saRegId != BaseReg::kIdBad; } + ASMJIT_INLINE_NODEBUG uint32_t saRegId() const noexcept { return _saRegId; } + ASMJIT_INLINE_NODEBUG void setSARegId(uint32_t regId) { _saRegId = uint8_t(regId); } + ASMJIT_INLINE_NODEBUG void resetSARegId() { _saRegId = uint8_t(BaseReg::kIdBad); } + + //! Returns assigned argument at `argIndex` and `valueIndex`. + //! + //! \note `argIndex` refers to he function argument and `valueIndex` refers to a value pack (in case multiple + //! values are passed as a single argument). + inline FuncValue& arg(size_t argIndex, size_t valueIndex) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + return _argPacks[argIndex][valueIndex]; + } + //! \overload + inline const FuncValue& arg(size_t argIndex, size_t valueIndex) const noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + return _argPacks[argIndex][valueIndex]; + } + + //! Tests whether argument at `argIndex` and `valueIndex` has been assigned. + inline bool isAssigned(size_t argIndex, size_t valueIndex) const noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + return _argPacks[argIndex][valueIndex].isAssigned(); + } + + //! Assigns register at `argIndex` and value index of 0 to `reg` and an optional `typeId`. + inline void assignReg(size_t argIndex, const BaseReg& reg, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + ASMJIT_ASSERT(reg.isPhysReg()); + _argPacks[argIndex][0].initReg(reg.type(), reg.id(), typeId); + } + + //! Assigns register at `argIndex` and value index of 0 to `regType`, `regId`, and an optional `typeId`. 
+ inline void assignReg(size_t argIndex, RegType regType, uint32_t regId, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + _argPacks[argIndex][0].initReg(regType, regId, typeId); + } + + //! Assigns stack at `argIndex` and value index of 0 to `offset` and an optional `typeId`. + inline void assignStack(size_t argIndex, int32_t offset, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + _argPacks[argIndex][0].initStack(offset, typeId); + } + + //! Assigns register at `argIndex` and `valueIndex` to `reg` and an optional `typeId`. + inline void assignRegInPack(size_t argIndex, size_t valueIndex, const BaseReg& reg, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + ASMJIT_ASSERT(reg.isPhysReg()); + _argPacks[argIndex][valueIndex].initReg(reg.type(), reg.id(), typeId); + } + + //! Assigns register at `argIndex` and `valueIndex` to `regType`, `regId`, and an optional `typeId`. + inline void assignRegInPack(size_t argIndex, size_t valueIndex, RegType regType, uint32_t regId, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + _argPacks[argIndex][valueIndex].initReg(regType, regId, typeId); + } + + //! Assigns stack at `argIndex` and `valueIndex` to `offset` and an optional `typeId`. + inline void assignStackInPack(size_t argIndex, size_t valueIndex, int32_t offset, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + _argPacks[argIndex][valueIndex].initStack(offset, typeId); + } + + // NOTE: All `assignAll()` methods are shortcuts to assign all arguments at once, however, since registers are + // passed all at once these initializers don't provide any way to pass TypeId and/or to keep any argument between + // the arguments passed unassigned. 
+ inline void _assignAllInternal(size_t argIndex, const BaseReg& reg) noexcept { + assignReg(argIndex, reg); + } + + template + inline void _assignAllInternal(size_t argIndex, const BaseReg& reg, Args&&... args) noexcept { + assignReg(argIndex, reg); + _assignAllInternal(argIndex + 1, std::forward(args)...); + } + + //! Assigns all argument at once. + //! + //! \note This function can be only used if the arguments don't contain value packs (multiple values per argument). + template + inline void assignAll(Args&&... args) noexcept { + _assignAllInternal(0, std::forward(args)...); + } + + //! \} + + //! \name Utilities + //! \{ + + //! Update `FuncFrame` based on function's arguments assignment. + //! + //! \note This function must be called in order to use `BaseEmitter::emitArgsAssignment()`, otherwise the \ref FuncFrame + //! would not contain the information necessary to assign all arguments into the registers and/or stack specified. + ASMJIT_API Error updateFuncFrame(FuncFrame& frame) const noexcept; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_FUNC_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/globals.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/globals.h new file mode 100644 index 0000000000000000000000000000000000000000..db921cfc6e1c4535594cdc5e0a212fa0e6fb5be8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/globals.h @@ -0,0 +1,421 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_GLOBALS_H_INCLUDED +#define ASMJIT_CORE_GLOBALS_H_INCLUDED + +#include "../core/api-config.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \cond INTERNAL +//! \addtogroup asmjit_utilities +//! \{ +namespace Support { + +//! Cast designed to cast between function and void* pointers. 
+template +static inline Dst ptr_cast_impl(Src p) noexcept { return (Dst)p; } + +//! Helper to implement placement new/delete without relying on `` header. +struct PlacementNew { void* ptr; }; + +} // {Support} + +#if defined(ASMJIT_NO_STDCXX) +namespace Support { + ASMJIT_FORCE_INLINE void* operatorNew(size_t n) noexcept { return malloc(n); } + ASMJIT_FORCE_INLINE void operatorDelete(void* p) noexcept { if (p) free(p); } +} // {Support} + +#define ASMJIT_BASE_CLASS(TYPE) \ + ASMJIT_FORCE_INLINE void* operator new(size_t n) noexcept { return Support::operatorNew(n); } \ + ASMJIT_FORCE_INLINE void operator delete(void* ptr) noexcept { Support::operatorDelete(ptr); } \ + \ + ASMJIT_FORCE_INLINE void* operator new(size_t, void* ptr) noexcept { return ptr; } \ + ASMJIT_FORCE_INLINE void operator delete(void*, void*) noexcept {} \ + \ + ASMJIT_FORCE_INLINE void* operator new(size_t, Support::PlacementNew ptr) noexcept { return ptr.ptr; } \ + ASMJIT_FORCE_INLINE void operator delete(void*, Support::PlacementNew) noexcept {} +#else +#define ASMJIT_BASE_CLASS(TYPE) +#endif + +//! \} +//! \endcond + +//! \addtogroup asmjit_core +//! \{ + +//! Byte order. +enum class ByteOrder { + //! Little endian. + kLE = 0, + //! Big endian. + kBE = 1, + //! Native byte order of the target architecture. + kNative = ASMJIT_ARCH_LE ? kLE : kBE, + //! Swapped byte order of the target architecture. + kSwapped = ASMJIT_ARCH_LE ? kBE : kLE +}; + +//! A policy that can be used with some `reset()` member functions. +enum class ResetPolicy : uint32_t { + //! Soft reset, doesn't deallocate memory (default). + kSoft = 0, + //! Hard reset, releases all memory used, if any. + kHard = 1 +}; + +//! Contains typedefs, constants, and variables used globally by AsmJit. +namespace Globals { + +//! Host memory allocator overhead. +static constexpr uint32_t kAllocOverhead = uint32_t(sizeof(intptr_t) * 4); + +//! Host memory allocator alignment. +static constexpr uint32_t kAllocAlignment = 8; + +//! 
Aggressive growing strategy threshold. +static constexpr uint32_t kGrowThreshold = 1024 * 1024 * 16; + +//! Maximum depth of RB-Tree is: +//! +//! `2 * log2(n + 1)` +//! +//! Size of RB node is at least two pointers (without data), so a theoretical architecture limit would be: +//! +//! `2 * log2(addressableMemorySize / sizeof(Node) + 1)` +//! +//! Which yields 30 on 32-bit arch and 61 on 64-bit arch. The final value was adjusted by +1 for safety reasons. +static constexpr uint32_t kMaxTreeHeight = (ASMJIT_ARCH_BITS == 32 ? 30 : 61) + 1; + +//! Maximum number of operands per a single instruction. +static constexpr uint32_t kMaxOpCount = 6; + +//! Maximum arguments of a function supported by the Compiler / Function API. +static constexpr uint32_t kMaxFuncArgs = 32; + +//! The number of values that can be assigned to a single function argument or return value. +static constexpr uint32_t kMaxValuePack = 4; + +//! Maximum number of physical registers AsmJit can use per register group. +static constexpr uint32_t kMaxPhysRegs = 32; + +//! Maximum alignment. +static constexpr uint32_t kMaxAlignment = 64; + +//! Maximum label or symbol size in bytes. +static constexpr uint32_t kMaxLabelNameSize = 2048; + +//! Maximum section name size. +static constexpr uint32_t kMaxSectionNameSize = 35; + +//! Maximum size of comment. +static constexpr uint32_t kMaxCommentSize = 1024; + +//! Invalid identifier. +static constexpr uint32_t kInvalidId = 0xFFFFFFFFu; + +//! Returned by `indexOf()` and similar when working with containers that use 32-bit index/size. +static constexpr uint32_t kNotFound = 0xFFFFFFFFu; + +//! Invalid base address. +static constexpr uint64_t kNoBaseAddress = ~uint64_t(0); + +//! Number of virtual register groups. +static constexpr uint32_t kNumVirtGroups = 4; + +struct Init_ {}; +struct NoInit_ {}; + +//! A decorator used to initialize. +static const constexpr Init_ Init {}; +//! A decorator used to not initialize. 
+static const constexpr NoInit_ NoInit {}; + +} // {Globals} + +//! Casts a `void*` pointer `func` to a function pointer `Func`. +template +static ASMJIT_INLINE_NODEBUG Func ptr_as_func(void* func) noexcept { return Support::ptr_cast_impl(func); } + +//! Casts a function pointer `func` to a void pointer `void*`. +template +static ASMJIT_INLINE_NODEBUG void* func_as_ptr(Func func) noexcept { return Support::ptr_cast_impl(func); } + +//! \} + +//! \addtogroup asmjit_error_handling +//! \{ + +//! AsmJit error type (uint32_t). +typedef uint32_t Error; + +//! AsmJit error codes. +enum ErrorCode : uint32_t { + // @EnumValuesBegin{"enum": "ErrorCode"}@ + + //! No error (success). + kErrorOk = 0, + + //! Out of memory. + kErrorOutOfMemory, + + //! Invalid argument. + kErrorInvalidArgument, + + //! Invalid state. + //! + //! If this error is returned it means that either you are doing something wrong or AsmJit caught itself by + //! doing something wrong. This error should never be ignored. + kErrorInvalidState, + + //! Invalid or incompatible architecture. + kErrorInvalidArch, + + //! The object is not initialized. + kErrorNotInitialized, + //! The object is already initialized. + kErrorAlreadyInitialized, + + //! Either a built-in feature was disabled at compile time and it's not available or the feature is not + //! available on the target platform. + //! + //! For example trying to allocate large pages on unsupported platform would return this error. + kErrorFeatureNotEnabled, + + //! Too many handles (Windows) or file descriptors (Unix/Posix). + kErrorTooManyHandles, + //! Code generated is larger than allowed. + kErrorTooLarge, + + //! No code generated. + //! + //! Returned by runtime if the \ref CodeHolder contains no code. + kErrorNoCodeGenerated, + + //! Invalid directive. + kErrorInvalidDirective, + //! Attempt to use uninitialized label. + kErrorInvalidLabel, + //! 
Label index overflow - a single \ref BaseAssembler instance can hold almost 2^32 (4 billion) labels. If + //! there is an attempt to create more labels then this error is returned. + kErrorTooManyLabels, + //! Label is already bound. + kErrorLabelAlreadyBound, + //! Label is already defined (named labels). + kErrorLabelAlreadyDefined, + //! Label name is too long. + kErrorLabelNameTooLong, + //! Label must always be local if it's anonymous (without a name). + kErrorInvalidLabelName, + //! Parent id passed to \ref CodeHolder::newNamedLabelEntry() was either invalid or parent is not supported + //! by the requested `LabelType`. + kErrorInvalidParentLabel, + + //! Invalid section. + kErrorInvalidSection, + //! Too many sections (section index overflow). + kErrorTooManySections, + //! Invalid section name (most probably too long). + kErrorInvalidSectionName, + + //! Relocation index overflow (too many relocations). + kErrorTooManyRelocations, + //! Invalid relocation entry. + kErrorInvalidRelocEntry, + //! Reloc entry contains address that is out of range (unencodable). + kErrorRelocOffsetOutOfRange, + + //! Invalid assignment to a register, function argument, or function return value. + kErrorInvalidAssignment, + //! Invalid instruction. + kErrorInvalidInstruction, + //! Invalid register type. + kErrorInvalidRegType, + //! Invalid register group. + kErrorInvalidRegGroup, + //! Invalid physical register id. + kErrorInvalidPhysId, + //! Invalid virtual register id. + kErrorInvalidVirtId, + //! Invalid element index (ARM). + kErrorInvalidElementIndex, + //! Invalid prefix combination (X86|X64). + kErrorInvalidPrefixCombination, + //! Invalid LOCK prefix (X86|X64). + kErrorInvalidLockPrefix, + //! Invalid XACQUIRE prefix (X86|X64). + kErrorInvalidXAcquirePrefix, + //! Invalid XRELEASE prefix (X86|X64). + kErrorInvalidXReleasePrefix, + //! Invalid REP prefix (X86|X64). + kErrorInvalidRepPrefix, + //! Invalid REX prefix (X86|X64). + kErrorInvalidRexPrefix, + //! 
Invalid {...} register (X86|X64). + kErrorInvalidExtraReg, + //! Invalid {k} use (not supported by the instruction) (X86|X64). + kErrorInvalidKMaskUse, + //! Invalid {k}{z} use (not supported by the instruction) (X86|X64). + kErrorInvalidKZeroUse, + //! Invalid broadcast - Currently only related to invalid use of AVX-512 {1tox} (X86|X64). + kErrorInvalidBroadcast, + //! Invalid 'embedded-rounding' {er} or 'suppress-all-exceptions' {sae} (AVX-512) (X86|X64). + kErrorInvalidEROrSAE, + //! Invalid address used (not encodable). + kErrorInvalidAddress, + //! Invalid index register used in memory address (not encodable). + kErrorInvalidAddressIndex, + //! Invalid address scale (not encodable). + kErrorInvalidAddressScale, + //! Invalid use of 64-bit address. + kErrorInvalidAddress64Bit, + //! Invalid use of 64-bit address that require 32-bit zero-extension (X64). + kErrorInvalidAddress64BitZeroExtension, + //! Invalid displacement (not encodable). + kErrorInvalidDisplacement, + //! Invalid segment (X86). + kErrorInvalidSegment, + + //! Invalid immediate (out of bounds on X86 and invalid pattern on ARM). + kErrorInvalidImmediate, + + //! Invalid operand size. + kErrorInvalidOperandSize, + //! Ambiguous operand size (memory has zero size while it's required to determine the operation type. + kErrorAmbiguousOperandSize, + //! Mismatching operand size (size of multiple operands doesn't match the operation size). + kErrorOperandSizeMismatch, + + //! Invalid option. + kErrorInvalidOption, + //! Option already defined. + kErrorOptionAlreadyDefined, + + //! Invalid TypeId. + kErrorInvalidTypeId, + //! Invalid use of a 8-bit GPB-HIGH register. + kErrorInvalidUseOfGpbHi, + //! Invalid use of a 64-bit GPQ register in 32-bit mode. + kErrorInvalidUseOfGpq, + //! Invalid use of an 80-bit float (\ref TypeId::kFloat80). + kErrorInvalidUseOfF80, + //! Instruction requires the use of consecutive registers, but registers in operands weren't (AVX512, ASIMD load/store, etc...). 
+ kErrorNotConsecutiveRegs, + //! Failed to allocate consecutive registers - allocable registers either too restricted or a bug in RW info. + kErrorConsecutiveRegsAllocation, + + //! Illegal virtual register - reported by instruction validation. + kErrorIllegalVirtReg, + //! AsmJit cannot create more virtual registers. + kErrorTooManyVirtRegs, + + //! AsmJit requires a physical register, but no one is available. + kErrorNoMorePhysRegs, + //! A variable has been assigned more than once to a function argument (BaseCompiler). + kErrorOverlappedRegs, + //! Invalid register to hold stack arguments offset. + kErrorOverlappingStackRegWithRegArg, + + //! Unbound label cannot be evaluated by expression. + kErrorExpressionLabelNotBound, + //! Arithmetic overflow during expression evaluation. + kErrorExpressionOverflow, + + //! Failed to open anonymous memory handle or file descriptor. + kErrorFailedToOpenAnonymousMemory, + + //! Failed to open a file. + //! + //! \note This is a generic error that is used by internal filesystem API. + kErrorFailedToOpenFile, + + //! Protection failure can be returned from a virtual memory allocator or when trying to change memory access + //! permissions. + kErrorProtectionFailure, + + // @EnumValuesEnd@ + + //! Count of AsmJit error codes. + kErrorCount +}; + +//! Debugging utilities. +namespace DebugUtils { + +//! \cond INTERNAL +//! Used to silence warnings about unused arguments or variables. +template +static ASMJIT_INLINE_NODEBUG void unused(Args&&...) noexcept {} +//! \endcond + +//! Returns the error `err` passed. +//! +//! Provided for debugging purposes. Putting a breakpoint inside `errored` can help with tracing the origin of any +//! error reported / returned by AsmJit. +static constexpr Error errored(Error err) noexcept { return err; } + +//! Returns a printable version of `asmjit::Error` code. +ASMJIT_API const char* errorAsString(Error err) noexcept; + +//! Called to output debugging message(s). 
+ASMJIT_API void debugOutput(const char* str) noexcept; + +//! Called on assertion failure. +//! +//! \param file Source file name where it happened. +//! \param line Line in the source file. +//! \param msg Message to display. +//! +//! If you have problems with assertion failures a breakpoint can be put at \ref assertionFailed() function +//! (asmjit/core/globals.cpp). A call stack will be available when such assertion failure is triggered. AsmJit +//! always returns errors on failures, assertions are a last resort and usually mean unrecoverable state due to out +//! of range array access or totally invalid arguments like nullptr where a valid pointer should be provided, etc... +ASMJIT_API void ASMJIT_NORETURN assertionFailed(const char* file, int line, const char* msg) noexcept; + +} // {DebugUtils} + +//! \def ASMJIT_ASSERT(...) +//! +//! AsmJit's own assert macro used in AsmJit code-base. +#if defined(ASMJIT_BUILD_DEBUG) +#define ASMJIT_ASSERT(...) \ + do { \ + if (ASMJIT_LIKELY(__VA_ARGS__)) \ + break; \ + ::asmjit::DebugUtils::assertionFailed(__FILE__, __LINE__, #__VA_ARGS__); \ + } while (0) +#else +#define ASMJIT_ASSERT(...) ((void)0) +#endif + +//! \def ASMJIT_PROPAGATE(...) +//! +//! Propagates a possible `Error` produced by `...` to the caller by returning the error immediately. Used by AsmJit +//! internally, but kept public for users that want to use the same technique to propagate errors to the caller. +#define ASMJIT_PROPAGATE(...) \ + do { \ + ::asmjit::Error _err = __VA_ARGS__; \ + if (ASMJIT_UNLIKELY(_err)) \ + return _err; \ + } while (0) + +//! \} + +ASMJIT_END_NAMESPACE + +//! Implementation of a placement new so we don't have to depend on ``. +ASMJIT_INLINE_NODEBUG void* operator new(size_t, const asmjit::Support::PlacementNew& p) noexcept { +#if defined(_MSC_VER) && !defined(__clang__) + __assume(p.ptr != nullptr); // Otherwise MSVC would emit a nullptr check. 
+#endif + return p.ptr; +} + +ASMJIT_INLINE_NODEBUG void operator delete(void*, const asmjit::Support::PlacementNew&) noexcept {} + +#endif // ASMJIT_CORE_GLOBALS_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/inst.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/inst.h new file mode 100644 index 0000000000000000000000000000000000000000..a653fe8e4ccfc98a78b2e264ae0bf67bfbf466a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/inst.h @@ -0,0 +1,804 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_INST_H_INCLUDED +#define ASMJIT_CORE_INST_H_INCLUDED + +#include "../core/cpuinfo.h" +#include "../core/operand.h" +#include "../core/string.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_instruction_db +//! \{ + +//! Describes an instruction id and modifiers used together with the id. +//! +//! Each architecture has a set of valid instructions indexed from 0. Instruction with 0 id is, however, a special +//! instruction that describes a "no instruction" or "invalid instruction". Different architectures can assign a. +//! different instruction to the same id, each architecture typically has its own instructions indexed from 1. +//! +//! Instruction identifiers listed by architecture: +//! +//! - \ref x86::Inst (X86 and X86_64) +//! - \ref a64::Inst (AArch64) +typedef uint32_t InstId; + +//! Instruction id parts. +//! +//! A mask that specifies a bit-layout of \ref InstId. +enum class InstIdParts : uint32_t { + // Common Masks + // ------------ + + //! Real id without any modifiers (always 16 least significant bits). + kRealId = 0x0000FFFFu, + //! Instruction is abstract (or virtual, IR, etc...). + kAbstract = 0x80000000u, + + // ARM Specific + // ------------ + + //! 
AArch32 first data type, used by ASIMD instructions (`inst.dt.dt2`). + kA32_DT = 0x000F0000u, + //! AArch32 second data type, used by ASIMD instructions (`inst.dt.dt2`). + kA32_DT2 = 0x00F00000u, + //! AArch32/AArch64 condition code. + kARM_Cond = 0x78000000u +}; + +//! Instruction options. +//! +//! Instruction options complement instruction identifier and attributes. +enum class InstOptions : uint32_t { + //! No options. + kNone = 0, + + //! Used internally by emitters for handling errors and rare cases. + kReserved = 0x00000001u, + + //! Prevents following a jump during compilation (Compiler). + kUnfollow = 0x00000002u, + + //! Overwrite the destination operand(s) (Compiler). + //! + //! Hint that is important for register liveness analysis. It tells the compiler that the destination operand will + //! be overwritten now or by adjacent instructions. Compiler knows when a register is completely overwritten by a + //! single instruction, for example you don't have to mark "movaps" or "pxor x, x", however, if a pair of + //! instructions is used and the first of them doesn't completely overwrite the content of the destination, + //! Compiler fails to mark that register as dead. + //! + //! X86 Specific + //! ------------ + //! + //! - All instructions that always overwrite at least the size of the register the virtual-register uses, for + //! example "mov", "movq", "movaps" don't need the overwrite option to be used - conversion, shuffle, and + //! other miscellaneous instructions included. + //! + //! - All instructions that clear the destination register if all operands are the same, for example "xor x, x", + //! "pcmpeqb x x", etc... + //! + //! - Consecutive instructions that partially overwrite the variable until there is no old content require + //! `BaseCompiler::overwrite()` to be used. Some examples (not always the best use cases thought): + //! + //! - `movlps xmm0, ?` followed by `movhps xmm0, ?` and vice versa + //! 
- `movlpd xmm0, ?` followed by `movhpd xmm0, ?` and vice versa + //! - `mov al, ?` followed by `and ax, 0xFF` + //! - `mov al, ?` followed by `mov ah, al` + //! - `pinsrq xmm0, ?, 0` followed by `pinsrq xmm0, ?, 1` + //! + //! - If the allocated virtual register is used temporarily for scalar operations. For example if you allocate a + //! full vector like `x86::Compiler::newXmm()` and then use that vector for scalar operations you should use + //! `overwrite()` directive: + //! + //! - `sqrtss x, y` - only LO element of `x` is changed, if you don't + //! use HI elements, use `compiler.overwrite().sqrtss(x, y)`. + kOverwrite = 0x00000004u, + + //! Emit short-form of the instruction. + kShortForm = 0x00000010u, + //! Emit long-form of the instruction. + kLongForm = 0x00000020u, + + //! Conditional jump is likely to be taken. + kTaken = 0x00000040u, + //! Conditional jump is unlikely to be taken. + kNotTaken = 0x00000080u, + + // X86 & X64 Options + // ----------------- + + //! Use ModMR instead of ModRM if applicable. + kX86_ModMR = 0x00000100u, + //! Use ModRM instead of ModMR if applicable. + kX86_ModRM = 0x00000200u, + //! Use 3-byte VEX prefix if possible (AVX) (must be 0x00000400). + kX86_Vex3 = 0x00000400u, + //! Use VEX prefix when both VEX|EVEX prefixes are available (HINT: AVX_VNNI). + kX86_Vex = 0x00000800u, + //! Use 4-byte EVEX prefix if possible (AVX-512) (must be 0x00001000). + kX86_Evex = 0x00001000u, + + //! LOCK prefix (lock-enabled instructions only). + kX86_Lock = 0x00002000u, + //! REP prefix (string instructions only). + kX86_Rep = 0x00004000u, + //! REPNE prefix (string instructions only). + kX86_Repne = 0x00008000u, + + //! XACQUIRE prefix (only allowed instructions). + kX86_XAcquire = 0x00010000u, + //! XRELEASE prefix (only allowed instructions). + kX86_XRelease = 0x00020000u, + + //! AVX-512: embedded-rounding {er} and implicit {sae}. + kX86_ER = 0x00040000u, + //! AVX-512: suppress-all-exceptions {sae}. + kX86_SAE = 0x00080000u, + //! 
AVX-512: round-to-nearest (even) {rn-sae} (bits 00). + kX86_RN_SAE = 0x00000000u, + //! AVX-512: round-down (toward -inf) {rd-sae} (bits 01). + kX86_RD_SAE = 0x00200000u, + //! AVX-512: round-up (toward +inf) {ru-sae} (bits 10). + kX86_RU_SAE = 0x00400000u, + //! AVX-512: round-toward-zero (truncate) {rz-sae} (bits 11). + kX86_RZ_SAE = 0x00600000u, + //! AVX-512: Use zeroing {k}{z} instead of merging {k}. + kX86_ZMask = 0x00800000u, + + //! AVX-512: Mask to get embedded rounding bits (2 bits). + kX86_ERMask = kX86_RZ_SAE, + //! AVX-512: Mask of all possible AVX-512 options except EVEX prefix flag. + kX86_AVX512Mask = 0x00FC0000u, + + //! Force REX.B and/or VEX.B field (X64 only). + kX86_OpCodeB = 0x01000000u, + //! Force REX.X and/or VEX.X field (X64 only). + kX86_OpCodeX = 0x02000000u, + //! Force REX.R and/or VEX.R field (X64 only). + kX86_OpCodeR = 0x04000000u, + //! Force REX.W and/or VEX.W field (X64 only). + kX86_OpCodeW = 0x08000000u, + //! Force REX prefix (X64 only). + kX86_Rex = 0x40000000u, + //! Invalid REX prefix (set by X86 or when AH|BH|CH|DH regs are used on X64). + kX86_InvalidRex = 0x80000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(InstOptions) + +//! Instruction control flow. +enum class InstControlFlow : uint32_t { + //! Regular instruction. + kRegular = 0u, + //! Unconditional jump. + kJump = 1u, + //! Conditional jump (branch). + kBranch = 2u, + //! Function call. + kCall = 3u, + //! Function return. + kReturn = 4u, + + //! Maximum value of `InstType`. + kMaxValue = kReturn +}; + +//! Hint that is used when both input operands to the instruction are the same. +//! +//! Provides hints to the instruction RW query regarding special cases in which two or more operands are the same +//! registers. This is required by instructions such as XOR, AND, OR, SUB, etc... These hints will influence the +//! RW operations query. +enum class InstSameRegHint : uint8_t { + //! No special handling. + kNone = 0, + //! 
Operands become read-only, the operation doesn't change the content - `X & X` and similar. + kRO = 1, + //! Operands become write-only, the content of the input(s) don't matter - `X ^ X`, `X - X`, and similar. + kWO = 2 +}; + +//! Instruction id, options, and extraReg in a single structure. This structure exists mainly to simplify analysis +//! and validation API that requires `BaseInst` and `Operand[]` array. +class BaseInst { +public: + //! \name Members + //! \{ + + //! Instruction id with modifiers. + InstId _id; + //! Instruction options. + InstOptions _options; + //! Extra register used by the instruction (either REP register or AVX-512 selector). + RegOnly _extraReg; + + enum Id : uint32_t { + //! Invalid or uninitialized instruction id. + kIdNone = 0x00000000u, + //! Abstract instruction (BaseBuilder and BaseCompiler). + kIdAbstract = 0x80000000u + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new BaseInst instance with `id` and `options` set. + //! + //! Default values of `id` and `options` are zero, which means 'none' instruction. Such instruction is guaranteed + //! to never exist for any architecture supported by AsmJit. + ASMJIT_INLINE_NODEBUG explicit BaseInst(InstId instId = 0, InstOptions options = InstOptions::kNone) noexcept + : _id(instId), + _options(options), + _extraReg() {} + + ASMJIT_INLINE_NODEBUG BaseInst(InstId instId, InstOptions options, const RegOnly& extraReg) noexcept + : _id(instId), + _options(options), + _extraReg(extraReg) {} + + ASMJIT_INLINE_NODEBUG BaseInst(InstId instId, InstOptions options, const BaseReg& extraReg) noexcept + : _id(instId), + _options(options), + _extraReg { extraReg.signature(), extraReg.id() } {} + + //! \} + + //! \name Instruction id and modifiers + //! \{ + + //! Returns the instruction id with modifiers. + ASMJIT_INLINE_NODEBUG InstId id() const noexcept { return _id; } + //! Sets the instruction id and modiiers from `id`. 
+ ASMJIT_INLINE_NODEBUG void setId(InstId id) noexcept { _id = id; } + //! Resets the instruction id and modifiers to zero, see \ref kIdNone. + ASMJIT_INLINE_NODEBUG void resetId() noexcept { _id = 0; } + + //! Returns a real instruction id that doesn't contain any modifiers. + ASMJIT_INLINE_NODEBUG InstId realId() const noexcept { return _id & uint32_t(InstIdParts::kRealId); } + + template + ASMJIT_INLINE_NODEBUG uint32_t getInstIdPart() const noexcept { + return (uint32_t(_id) & uint32_t(kPart)) >> Support::ConstCTZ::value; + } + + template + ASMJIT_INLINE_NODEBUG void setInstIdPart(uint32_t value) noexcept { + _id = (_id & ~uint32_t(kPart)) | (value << Support::ConstCTZ::value); + } + + //! \} + + //! \name Instruction Options + //! \{ + + ASMJIT_INLINE_NODEBUG InstOptions options() const noexcept { return _options; } + ASMJIT_INLINE_NODEBUG bool hasOption(InstOptions option) const noexcept { return Support::test(_options, option); } + ASMJIT_INLINE_NODEBUG void setOptions(InstOptions options) noexcept { _options = options; } + ASMJIT_INLINE_NODEBUG void addOptions(InstOptions options) noexcept { _options |= options; } + ASMJIT_INLINE_NODEBUG void clearOptions(InstOptions options) noexcept { _options &= ~options; } + ASMJIT_INLINE_NODEBUG void resetOptions() noexcept { _options = InstOptions::kNone; } + + //! \} + + //! \name Extra Register + //! \{ + + ASMJIT_INLINE_NODEBUG bool hasExtraReg() const noexcept { return _extraReg.isReg(); } + ASMJIT_INLINE_NODEBUG RegOnly& extraReg() noexcept { return _extraReg; } + ASMJIT_INLINE_NODEBUG const RegOnly& extraReg() const noexcept { return _extraReg; } + ASMJIT_INLINE_NODEBUG void setExtraReg(const BaseReg& reg) noexcept { _extraReg.init(reg); } + ASMJIT_INLINE_NODEBUG void setExtraReg(const RegOnly& reg) noexcept { _extraReg.init(reg); } + ASMJIT_INLINE_NODEBUG void resetExtraReg() noexcept { _extraReg.reset(); } + + //! \} + + //! \name ARM Specific + //! 
\{ + + ASMJIT_INLINE_NODEBUG arm::CondCode armCondCode() const noexcept { return (arm::CondCode)getInstIdPart(); } + ASMJIT_INLINE_NODEBUG void setArmCondCode(arm::CondCode cc) noexcept { setInstIdPart(uint32_t(cc)); } + + ASMJIT_INLINE_NODEBUG a32::DataType armDt() const noexcept { return (a32::DataType)getInstIdPart(); } + ASMJIT_INLINE_NODEBUG a32::DataType armDt2() const noexcept { return (a32::DataType)getInstIdPart(); } + + //! \} + + //! \name Statics + //! \{ + + static ASMJIT_INLINE_NODEBUG constexpr InstId composeARMInstId(uint32_t id, arm::CondCode cc) noexcept { + return id | (uint32_t(cc) << Support::ConstCTZ::value); + } + + static ASMJIT_INLINE_NODEBUG constexpr InstId composeARMInstId(uint32_t id, a32::DataType dt, arm::CondCode cc = arm::CondCode::kAL) noexcept { + return id | (uint32_t(dt) << Support::ConstCTZ::value) + | (uint32_t(cc) << Support::ConstCTZ::value); + } + + static ASMJIT_INLINE_NODEBUG constexpr InstId composeARMInstId(uint32_t id, a32::DataType dt, a32::DataType dt2, arm::CondCode cc = arm::CondCode::kAL) noexcept { + return id | (uint32_t(dt) << Support::ConstCTZ::value) + | (uint32_t(dt2) << Support::ConstCTZ::value) + | (uint32_t(cc) << Support::ConstCTZ::value); + } + + static ASMJIT_INLINE_NODEBUG constexpr InstId extractRealId(uint32_t id) noexcept { + return id & uint32_t(InstIdParts::kRealId); + } + + static ASMJIT_INLINE_NODEBUG constexpr arm::CondCode extractARMCondCode(uint32_t id) noexcept { + return (arm::CondCode)((uint32_t(id) & uint32_t(InstIdParts::kARM_Cond)) >> Support::ConstCTZ::value); + } + + //! \} +}; + +//! CPU read/write flags used by \ref InstRWInfo. +//! +//! These flags can be used to get a basic overview about CPU specifics flags used by instructions. +enum class CpuRWFlags : uint32_t { + //! No flags. + kNone = 0x00000000u, + + // Common RW Flags (0x000000FF) + // ---------------------------- + + //! Signed overflow flag. + kOF = 0x00000001u, + //! Carry flag. + kCF = 0x00000002u, + //! 
Zero and/or equality flag (1 if zero/equal). + kZF = 0x00000004u, + //! Sign flag (negative/sign, if set). + kSF = 0x00000008u, + + // X86 Specific RW Flags + // ---------------------------------- + + //! Carry flag (X86, X86_64). + kX86_CF = kCF, + //! Overflow flag (X86, X86_64). + kX86_OF = kOF, + //! Sign flag (X86, X86_64). + kX86_SF = kSF, + //! Zero flag (X86, X86_64). + kX86_ZF = kZF, + + //! Adjust flag (X86, X86_64). + kX86_AF = 0x00000100u, + //! Parity flag (X86, X86_64). + kX86_PF = 0x00000200u, + //! Direction flag (X86, X86_64). + kX86_DF = 0x00000400u, + //! Interrupt enable flag (X86, X86_64). + kX86_IF = 0x00000800u, + + //! Alignment check flag (X86, X86_64). + kX86_AC = 0x00001000u, + + //! FPU C0 status flag (X86, X86_64). + kX86_C0 = 0x00010000u, + //! FPU C1 status flag (X86, X86_64). + kX86_C1 = 0x00020000u, + //! FPU C2 status flag (X86, X86_64). + kX86_C2 = 0x00040000u, + //! FPU C3 status flag (X86, X86_64). + kX86_C3 = 0x00080000u, + + // ARM Specific RW Flags + // ---------------------------------- + + kARM_V = kOF, + kARM_C = kCF, + kARM_Z = kZF, + kARM_N = kSF, + kARM_Q = 0x00000100u, + kARM_GE = 0x00000200u +}; +ASMJIT_DEFINE_ENUM_FLAGS(CpuRWFlags) + +//! Operand read/write flags describe how the operand is accessed and some additional features. +enum class OpRWFlags : uint32_t { + //! No flags. + kNone = 0, + + //! Operand is read. + kRead = 0x00000001u, + + //! Operand is written. + kWrite = 0x00000002u, + + //! Operand is both read and written. + kRW = 0x00000003u, + + //! Register operand can be replaced by a memory operand. + kRegMem = 0x00000004u, + + //! The register must be allocated to the index of the previous register + 1. + //! + //! This flag is used by all architectures to describe instructions that use consecutive registers, where only the + //! first one is encoded in the instruction, and the others are just a sequence that starts with the first one. On + //! 
X86/X86_64 architecture this is used by instructions such as V4FMADDPS, V4FMADDSS, V4FNMADDPS, V4FNMADDSS, + //! VP4DPWSSD, VP4DPWSSDS, VP2INTERSECTD, and VP2INTERSECTQ. On ARM/AArch64 this is used by vector load and store + //! instructions that can load or store multiple registers at once. + kConsecutive = 0x00000008u, + + //! The `extendByteMask()` represents a zero extension. + kZExt = 0x00000010u, + + //! The register must have assigned a unique physical ID, which cannot be assigned to any other register. + kUnique = 0x00000080u, + + //! Register operand must use \ref OpRWInfo::physId(). + kRegPhysId = 0x00000100u, + //! Base register of a memory operand must use \ref OpRWInfo::physId(). + kMemPhysId = 0x00000200u, + + //! This memory operand is only used to encode registers and doesn't access memory. + //! + //! X86 Specific + //! ------------ + //! + //! Instructions that use such feature include BNDLDX, BNDSTX, and LEA. + kMemFake = 0x000000400u, + + //! Base register of the memory operand will be read. + kMemBaseRead = 0x00001000u, + //! Base register of the memory operand will be written. + kMemBaseWrite = 0x00002000u, + //! Base register of the memory operand will be read & written. + kMemBaseRW = 0x00003000u, + + //! Index register of the memory operand will be read. + kMemIndexRead = 0x00004000u, + //! Index register of the memory operand will be written. + kMemIndexWrite = 0x00008000u, + //! Index register of the memory operand will be read & written. + kMemIndexRW = 0x0000C000u, + + //! Base register of the memory operand will be modified before the operation. + kMemBasePreModify = 0x00010000u, + //! Base register of the memory operand will be modified after the operation. + kMemBasePostModify = 0x00020000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(OpRWFlags) + +// Don't remove these asserts. Read/Write flags are used extensively +// by Compiler and they must always be compatible with constants below. 
+static_assert(uint32_t(OpRWFlags::kRead) == 0x1, "OpRWFlags::kRead flag must be 0x1"); +static_assert(uint32_t(OpRWFlags::kWrite) == 0x2, "OpRWFlags::kWrite flag must be 0x2"); +static_assert(uint32_t(OpRWFlags::kRegMem) == 0x4, "OpRWFlags::kRegMem flag must be 0x4"); + +//! Read/Write information related to a single operand, used by \ref InstRWInfo. +struct OpRWInfo { + //! \name Members + //! \{ + + //! Read/Write flags. + OpRWFlags _opFlags; + //! Physical register index, if required. + uint8_t _physId; + //! Size of a possible memory operand that can replace a register operand. + uint8_t _rmSize; + //! If non-zero, then this is a consecutive lead register, and the value describes how many registers follow. + uint8_t _consecutiveLeadCount; + //! Reserved for future use. + uint8_t _reserved[1]; + //! Read bit-mask where each bit represents one byte read from Reg/Mem. + uint64_t _readByteMask; + //! Write bit-mask where each bit represents one byte written to Reg/Mem. + uint64_t _writeByteMask; + //! Zero/Sign extend bit-mask where each bit represents one byte written to Reg/Mem. + uint64_t _extendByteMask; + + //! \} + + //! \name Reset + //! \{ + + //! Resets this operand information to all zeros. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = OpRWInfo{}; } + + //! Resets this operand info (resets all members) and set common information + //! to the given `opFlags`, `regSize`, and possibly `physId`. + inline void reset(OpRWFlags opFlags, uint32_t regSize, uint32_t physId = BaseReg::kIdBad) noexcept { + _opFlags = opFlags; + _physId = uint8_t(physId); + _rmSize = Support::test(opFlags, OpRWFlags::kRegMem) ? uint8_t(regSize) : uint8_t(0); + _consecutiveLeadCount = 0; + _resetReserved(); + + uint64_t mask = Support::lsbMask(Support::min(regSize, 64)); + + _readByteMask = Support::test(opFlags, OpRWFlags::kRead) ? mask : uint64_t(0); + _writeByteMask = Support::test(opFlags, OpRWFlags::kWrite) ? 
mask : uint64_t(0); + _extendByteMask = 0; + } + + ASMJIT_INLINE_NODEBUG void _resetReserved() noexcept { + _reserved[0] = 0; + } + + //! \} + + //! \name Operand Flags + //! \{ + + //! Returns operand flags. + ASMJIT_INLINE_NODEBUG OpRWFlags opFlags() const noexcept { return _opFlags; } + //! Tests whether operand flags contain the given `flag`. + ASMJIT_INLINE_NODEBUG bool hasOpFlag(OpRWFlags flag) const noexcept { return Support::test(_opFlags, flag); } + + //! Adds the given `flags` to operand flags. + ASMJIT_INLINE_NODEBUG void addOpFlags(OpRWFlags flags) noexcept { _opFlags |= flags; } + //! Removes the given `flags` from operand flags. + ASMJIT_INLINE_NODEBUG void clearOpFlags(OpRWFlags flags) noexcept { _opFlags &= ~flags; } + + //! Tests whether this operand is read from. + ASMJIT_INLINE_NODEBUG bool isRead() const noexcept { return hasOpFlag(OpRWFlags::kRead); } + //! Tests whether this operand is written to. + ASMJIT_INLINE_NODEBUG bool isWrite() const noexcept { return hasOpFlag(OpRWFlags::kWrite); } + //! Tests whether this operand is both read and write. + ASMJIT_INLINE_NODEBUG bool isReadWrite() const noexcept { return (_opFlags & OpRWFlags::kRW) == OpRWFlags::kRW; } + //! Tests whether this operand is read only. + ASMJIT_INLINE_NODEBUG bool isReadOnly() const noexcept { return (_opFlags & OpRWFlags::kRW) == OpRWFlags::kRead; } + //! Tests whether this operand is write only. + ASMJIT_INLINE_NODEBUG bool isWriteOnly() const noexcept { return (_opFlags & OpRWFlags::kRW) == OpRWFlags::kWrite; } + + //! Returns the type of a lead register, which is followed by consecutive registers. + ASMJIT_INLINE_NODEBUG uint32_t consecutiveLeadCount() const noexcept { return _consecutiveLeadCount; } + + //! Tests whether this operand is Reg/Mem + //! + //! Reg/Mem operands can use either register or memory. + ASMJIT_INLINE_NODEBUG bool isRm() const noexcept { return hasOpFlag(OpRWFlags::kRegMem); } + + //! Tests whether the operand will be zero extended. 
+ ASMJIT_INLINE_NODEBUG bool isZExt() const noexcept { return hasOpFlag(OpRWFlags::kZExt); } + + //! Tests whether the operand must have allocated a unique physical id that cannot be shared with other register + //! operands. + ASMJIT_INLINE_NODEBUG bool isUnique() const noexcept { return hasOpFlag(OpRWFlags::kUnique); } + + //! \} + + //! \name Memory Flags + //! \{ + + //! Tests whether this is a fake memory operand, which is only used, because of encoding. Fake memory operands do + //! not access any memory, they are only used to encode registers. + ASMJIT_INLINE_NODEBUG bool isMemFake() const noexcept { return hasOpFlag(OpRWFlags::kMemFake); } + + //! Tests whether the instruction's memory BASE register is used. + ASMJIT_INLINE_NODEBUG bool isMemBaseUsed() const noexcept { return hasOpFlag(OpRWFlags::kMemBaseRW); } + //! Tests whether the instruction reads from its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseRead() const noexcept { return hasOpFlag(OpRWFlags::kMemBaseRead); } + //! Tests whether the instruction writes to its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseWrite() const noexcept { return hasOpFlag(OpRWFlags::kMemBaseWrite); } + //! Tests whether the instruction reads and writes from/to its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseReadWrite() const noexcept { return (_opFlags & OpRWFlags::kMemBaseRW) == OpRWFlags::kMemBaseRW; } + //! Tests whether the instruction only reads from its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseReadOnly() const noexcept { return (_opFlags & OpRWFlags::kMemBaseRW) == OpRWFlags::kMemBaseRead; } + //! Tests whether the instruction only writes to its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseWriteOnly() const noexcept { return (_opFlags & OpRWFlags::kMemBaseRW) == OpRWFlags::kMemBaseWrite; } + + //! Tests whether the instruction modifies the BASE register before it uses it to calculate the target address. 
+ ASMJIT_INLINE_NODEBUG bool isMemBasePreModify() const noexcept { return hasOpFlag(OpRWFlags::kMemBasePreModify); } + //! Tests whether the instruction modifies the BASE register after it uses it to calculate the target address. + ASMJIT_INLINE_NODEBUG bool isMemBasePostModify() const noexcept { return hasOpFlag(OpRWFlags::kMemBasePostModify); } + + //! Tests whether the instruction's memory INDEX register is used. + ASMJIT_INLINE_NODEBUG bool isMemIndexUsed() const noexcept { return hasOpFlag(OpRWFlags::kMemIndexRW); } + //! Tests whether the instruction reads the INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexRead() const noexcept { return hasOpFlag(OpRWFlags::kMemIndexRead); } + //! Tests whether the instruction writes to its INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexWrite() const noexcept { return hasOpFlag(OpRWFlags::kMemIndexWrite); } + //! Tests whether the instruction reads and writes from/to its INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexReadWrite() const noexcept { return (_opFlags & OpRWFlags::kMemIndexRW) == OpRWFlags::kMemIndexRW; } + //! Tests whether the instruction only reads from its INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexReadOnly() const noexcept { return (_opFlags & OpRWFlags::kMemIndexRW) == OpRWFlags::kMemIndexRead; } + //! Tests whether the instruction only writes to its INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexWriteOnly() const noexcept { return (_opFlags & OpRWFlags::kMemIndexRW) == OpRWFlags::kMemIndexWrite; } + + //! \} + + //! \name Physical Register ID + //! \{ + + //! Returns a physical id of the register that is fixed for this operand. + //! + //! Returns \ref BaseReg::kIdBad if any register can be used. + ASMJIT_INLINE_NODEBUG uint32_t physId() const noexcept { return _physId; } + //! Tests whether \ref physId() would return a valid physical register id. + ASMJIT_INLINE_NODEBUG bool hasPhysId() const noexcept { return _physId != BaseReg::kIdBad; } + //! 
Sets physical register id, which would be fixed for this operand. + ASMJIT_INLINE_NODEBUG void setPhysId(uint32_t physId) noexcept { _physId = uint8_t(physId); } + + //! \} + + //! \name Reg/Mem Information + //! \{ + + //! Returns Reg/Mem size of the operand. + ASMJIT_INLINE_NODEBUG uint32_t rmSize() const noexcept { return _rmSize; } + //! Sets Reg/Mem size of the operand. + ASMJIT_INLINE_NODEBUG void setRmSize(uint32_t rmSize) noexcept { _rmSize = uint8_t(rmSize); } + + //! \} + + //! \name Read & Write Masks + //! \{ + + //! Returns read mask. + ASMJIT_INLINE_NODEBUG uint64_t readByteMask() const noexcept { return _readByteMask; } + //! Returns write mask. + ASMJIT_INLINE_NODEBUG uint64_t writeByteMask() const noexcept { return _writeByteMask; } + //! Returns extend mask. + ASMJIT_INLINE_NODEBUG uint64_t extendByteMask() const noexcept { return _extendByteMask; } + + //! Sets read mask. + ASMJIT_INLINE_NODEBUG void setReadByteMask(uint64_t mask) noexcept { _readByteMask = mask; } + //! Sets write mask. + ASMJIT_INLINE_NODEBUG void setWriteByteMask(uint64_t mask) noexcept { _writeByteMask = mask; } + //! Sets extend mask. + ASMJIT_INLINE_NODEBUG void setExtendByteMask(uint64_t mask) noexcept { _extendByteMask = mask; } + + //! \} +}; + +//! Flags used by \ref InstRWInfo. +enum class InstRWFlags : uint32_t { + //! No flags. + kNone = 0x00000000u, + + //! Describes a move operation. + //! + //! This flag is used by RA to eliminate moves that are guaranteed to be moves only. + kMovOp = 0x00000001u +}; +ASMJIT_DEFINE_ENUM_FLAGS(InstRWFlags) + +//! Read/Write information of an instruction. +struct InstRWInfo { + //! \name Members + //! \{ + + //! Instruction flags (there are no flags at the moment, this field is reserved). + InstRWFlags _instFlags; + //! CPU flags read. + CpuRWFlags _readFlags; + //! CPU flags written. + CpuRWFlags _writeFlags; + //! Count of operands. + uint8_t _opCount; + //! CPU feature required for replacing register operand with memory operand. 
+ uint8_t _rmFeature; + //! Reserved for future use. + uint8_t _reserved[18]; + //! Read/Write info of extra register (rep{} or kz{}). + OpRWInfo _extraReg; + //! Read/Write info of instruction operands. + OpRWInfo _operands[Globals::kMaxOpCount]; + + //! \} + + //! \name Commons + //! \{ + + //! Resets this RW information to all zeros. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = InstRWInfo{}; } + + //! \} + + //! \name Instruction Flags + //! \{ + + //! Returns flags associated with the instruction, see \ref InstRWFlags. + ASMJIT_INLINE_NODEBUG InstRWFlags instFlags() const noexcept { return _instFlags; } + + //! Tests whether the instruction flags contain `flag`. + ASMJIT_INLINE_NODEBUG bool hasInstFlag(InstRWFlags flag) const noexcept { return Support::test(_instFlags, flag); } + + //! Tests whether the instruction flags contain \ref InstRWFlags::kMovOp. + ASMJIT_INLINE_NODEBUG bool isMovOp() const noexcept { return hasInstFlag(InstRWFlags::kMovOp); } + + //! \} + + //! \name CPU Flags Information + //! \{ + + //! Returns a mask of CPU flags read. + ASMJIT_INLINE_NODEBUG CpuRWFlags readFlags() const noexcept { return _readFlags; } + //! Returns a mask of CPU flags written. + ASMJIT_INLINE_NODEBUG CpuRWFlags writeFlags() const noexcept { return _writeFlags; } + + //! \} + + //! \name Reg/Mem Information + //! \{ + + //! Returns the CPU feature required to replace a register operand with memory operand. If the returned feature is + //! zero (none) then this instruction either doesn't provide memory operand combination or there is no extra CPU + //! feature required. + //! + //! X86 Specific + //! ------------ + //! + //! Some AVX+ instructions may require extra features for replacing registers with memory operands, for example + //! VPSLLDQ instruction only supports `vpslldq reg, reg, imm` combination on AVX/AVX2 capable CPUs and requires + //! AVX-512 for `vpslldq reg, mem, imm` combination. 
+ ASMJIT_INLINE_NODEBUG uint32_t rmFeature() const noexcept { return _rmFeature; } + + //! \} + + //! \name Operand Read/Write Information + //! \{ + + //! Returns RW information of extra register operand (extraReg). + ASMJIT_INLINE_NODEBUG const OpRWInfo& extraReg() const noexcept { return _extraReg; } + + //! Returns RW information of all instruction's operands. + ASMJIT_INLINE_NODEBUG const OpRWInfo* operands() const noexcept { return _operands; } + + //! Returns RW information of the operand at the given `index`. + inline const OpRWInfo& operand(size_t index) const noexcept { + ASMJIT_ASSERT(index < Globals::kMaxOpCount); + return _operands[index]; + } + + //! Returns the number of operands this instruction has. + ASMJIT_INLINE_NODEBUG uint32_t opCount() const noexcept { return _opCount; } + + //! \} +}; + +//! Validation flags that can be used with \ref InstAPI::validate(). +enum class ValidationFlags : uint32_t { + //! No flags. + kNone = 0, + //! Allow virtual registers in the instruction. + kEnableVirtRegs = 0x01u +}; +ASMJIT_DEFINE_ENUM_FLAGS(ValidationFlags) + +//! Instruction API. +namespace InstAPI { + +#ifndef ASMJIT_NO_TEXT +//! Appends the name of the instruction specified by `instId` and `instOptions` into the `output` string. +//! +//! \note Instruction options would only affect instruction prefix & suffix, other options would be ignored. +//! If `instOptions` is zero then only raw instruction name (without any additional text) will be appended. +ASMJIT_API Error instIdToString(Arch arch, InstId instId, String& output) noexcept; + +//! Parses an instruction name in the given string `s`. Length is specified by `len` argument, which can be +//! `SIZE_MAX` if `s` is known to be null terminated. +//! +//! Returns the parsed instruction id or \ref BaseInst::kIdNone if no such instruction exists. +ASMJIT_API InstId stringToInstId(Arch arch, const char* s, size_t len) noexcept; +#endif // !ASMJIT_NO_TEXT + +#ifndef ASMJIT_NO_VALIDATION +//! 
Validates the given instruction considering the given `validationFlags`. +ASMJIT_API Error validate(Arch arch, const BaseInst& inst, const Operand_* operands, size_t opCount, ValidationFlags validationFlags = ValidationFlags::kNone) noexcept; +#endif // !ASMJIT_NO_VALIDATION + +#ifndef ASMJIT_NO_INTROSPECTION +//! Gets Read/Write information of the given instruction. +ASMJIT_API Error queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* operands, size_t opCount, InstRWInfo* out) noexcept; + +//! Gets CPU features required by the given instruction. +ASMJIT_API Error queryFeatures(Arch arch, const BaseInst& inst, const Operand_* operands, size_t opCount, CpuFeatures* out) noexcept; +#endif // !ASMJIT_NO_INTROSPECTION + +} // {InstAPI} + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_INST_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/jitallocator.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/jitallocator.h new file mode 100644 index 0000000000000000000000000000000000000000..b694f8cd535da78769fd246640376292de7141b9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/jitallocator.h @@ -0,0 +1,570 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_JITALLOCATOR_H_INCLUDED +#define ASMJIT_CORE_JITALLOCATOR_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_JIT + +#include "../core/globals.h" +#include "../core/support.h" +#include "../core/virtmem.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_virtual_memory +//! \{ + +//! Options used by \ref JitAllocator. +enum class JitAllocatorOptions : uint32_t { + //! No options. + kNone = 0, + + //! Enables the use of an anonymous memory-mapped memory that is mapped into two buffers having a different pointer. + //! 
The first buffer has read and execute permissions and the second buffer has read+write permissions. + //! + //! See \ref VirtMem::allocDualMapping() for more details about this feature. + //! + //! \remarks Dual mapping would be automatically turned on by \ref JitAllocator in case of hardened runtime that + //! enforces `W^X` policy, so specifying this flag is essentially forcing to use dual mapped pages even when RWX + //! pages can be allocated and dual mapping is not necessary. + kUseDualMapping = 0x00000001u, + + //! Enables the use of multiple pools with increasing granularity instead of a single pool. This flag would enable + //! 3 internal pools in total having 64, 128, and 256 bytes granularity. + //! + //! This feature is only recommended for users that generate a lot of code and would like to minimize the overhead + //! of `JitAllocator` itself by having blocks of different allocation granularities. Using this feature only for + //! few allocations won't pay off as the allocator may need to create more blocks initially before it can take the + //! advantage of variable block granularity. + kUseMultiplePools = 0x00000002u, + + //! Always fill reserved memory by a fill-pattern. + //! + //! Causes a new block to be cleared by the fill pattern and freshly released memory to be cleared before making + //! it ready for another use. + kFillUnusedMemory = 0x00000004u, + + //! When this flag is set the allocator would immediately release unused blocks during `release()` or `reset()`. + //! When this flag is not set the allocator would keep one empty block in each pool to prevent excessive virtual + //! memory allocations and deallocations in border cases, which involve constantly allocating and deallocating a + //! single block caused by repetitive calling `alloc()` and `release()` when the allocator has either no blocks + //! or have all blocks fully occupied. + kImmediateRelease = 0x00000008u, + + //! 
This flag enables placing functions (or allocating memory) at the very beginning of each memory mapped region. + //! + //! Initially, this was the default behavior. However, LLVM developers working on undefined behavior sanitizer + //! (UBSAN) decided that they want to store metadata before each function and to access such metadata before an + //! indirect function call. This means that the instrumented code always reads from `[fnPtr - 8]` to decode whether + //! the function has his metadata present. However, reading 8 bytes below a function means that if a function is + //! placed at the very beginning of a memory mapped region, it could try to read bytes that are inaccessible. And + //! since AsmJit can be compiled as a shared library and used by applications instrumented by UBSAN, it's not + //! possible to conditionally compile the support only when necessary. + //! + //! \remarks This flag controls a workaround to make it possible to use LLVM UBSAN with AsmJit's \ref JitAllocator. + //! There is no undefined behavior even when `kDisableInitialPadding` is used, however, that doesn't really matter + //! as LLVM's UBSAN introduces one, and according to LLVM developers it's a "trade-off". This flag is safe to use + //! when the code is not instrumented with LLVM's UBSAN. + kDisableInitialPadding = 0x00000010u, + + //! Enables the use of large pages, if they are supported and the process can actually allocate them. + //! + //! \remarks This flag is a hint - if large pages can be allocated, JitAllocator would try to allocate them. + //! However, if the allocation fails, it will still try to fallback to use regular pages as \ref JitAllocator + //! is designed to minimize allocation failures, so a regular page is better than no page at all. Also, if a + //! block \ref JitAllocator wants to allocate is too small to consume a whole large page, regular page(s) will + //! be allocated as well. + kUseLargePages = 0x00000020u, + + //! 
Forces \ref JitAllocator to always align block size to be at least as big as a large page, if large pages are + //! enabled. This option does nothing if large pages are disabled. + //! + //! \remarks If \ref kUseLargePages option is used, the allocator would prefer large pages only when allocating a + //! block that has a sufficient size. Usually the allocator first allocates smaller block and when more requests + //! come it will start increasing the block size of next allocations. This option makes it sure that even the first + //! allocation would be the same as a minimum large page when large pages are enabled and can be allocated. + kAlignBlockSizeToLargePage = 0x00000040u, + + //! Use a custom fill pattern, must be combined with `kFlagFillUnusedMemory`. + kCustomFillPattern = 0x10000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(JitAllocatorOptions) + +//! A simple implementation of memory manager that uses `asmjit::VirtMem` +//! functions to manage virtual memory for JIT compiled code. +//! +//! Implementation notes: +//! +//! - Granularity of allocated blocks is different than granularity for a typical C malloc. In addition, the allocator +//! can use several memory pools having a different granularity to minimize the maintenance overhead. Multiple pools +//! feature requires `kFlagUseMultiplePools` flag to be set. +//! +//! - The allocator doesn't store any information in executable memory, instead, the implementation uses two +//! bit-vectors to manage allocated memory of each allocator-block. The first bit-vector called 'used' is used to +//! track used memory (where each bit represents memory size defined by granularity) and the second bit vector called +//! 'stop' is used as a sentinel to mark where the allocated area ends. +//! +//! - Internally, the allocator also uses RB tree to keep track of all blocks across all pools. Each inserted block is +//! added to the tree so it can be matched fast during `release()` and `shrink()`. 
+class JitAllocator { +public: + ASMJIT_NONCOPYABLE(JitAllocator) + + //! Visible \ref JitAllocator implementation data. + struct Impl { + //! Allocator options. + JitAllocatorOptions options; + //! Base block size (0 if the allocator is not initialized). + uint32_t blockSize; + //! Base granularity (0 if the allocator is not initialized). + uint32_t granularity; + //! A pattern that is used to fill unused memory if secure mode is enabled. + uint32_t fillPattern; + }; + + //! \name Members + //! \{ + + //! Allocator implementation (private). + Impl* _impl; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Parameters that can be passed to `JitAllocator` constructor. + //! + //! Use it like this: + //! + //! ``` + //! // Zero initialize (zero means the default value) and change what you need. + //! JitAllocator::CreateParams params {}; + //! params.blockSize = 1024 * 1024; + //! + //! // Create the allocator. + //! JitAllocator allocator(¶ms); + //! ``` + struct CreateParams { + //! Allocator options. + //! + //! No options are used by default. + JitAllocatorOptions options = JitAllocatorOptions::kNone; + + //! Base size of a single block in bytes (default 64kB). + //! + //! \remarks Block size must be equal to or greater than page size and must be power of 2. If the input is not + //! valid then the default block size will be used instead. + uint32_t blockSize = 0; + + //! Base granularity (and also natural alignment) of allocations in bytes (default 64). + //! + //! Since the `JitAllocator` uses bit-arrays to mark used memory the granularity also specifies how many bytes + //! correspond to a single bit in such bit-array. Higher granularity means more waste of virtual memory (as it + //! increases the natural alignment), but smaller bit-arrays as less bits would be required per a single block. + uint32_t granularity = 0; + + //! Patter to use to fill unused memory. + //! + //! Only used if \ref JitAllocatorOptions::kCustomFillPattern is set. 
+ uint32_t fillPattern = 0; + + // Reset the content of `CreateParams`. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = CreateParams{}; } + }; + + //! Creates a `JitAllocator` instance. + ASMJIT_API explicit JitAllocator(const CreateParams* params = nullptr) noexcept; + //! Destroys the `JitAllocator` instance and release all blocks held. + ASMJIT_API ~JitAllocator() noexcept; + + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _impl->blockSize == 0; } + + //! Free all allocated memory - makes all pointers returned by `alloc()` invalid. + //! + //! \remarks This function is not thread-safe as it's designed to be used when nobody else is using allocator. + //! The reason is that there is no point of calling `reset()` when the allocator is still in use. + ASMJIT_API void reset(ResetPolicy resetPolicy = ResetPolicy::kSoft) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns allocator options, see `Flags`. + ASMJIT_INLINE_NODEBUG JitAllocatorOptions options() const noexcept { return _impl->options; } + //! Tests whether the allocator has the given `option` set. + ASMJIT_INLINE_NODEBUG bool hasOption(JitAllocatorOptions option) const noexcept { return uint32_t(_impl->options & option) != 0; } + + //! Returns a base block size (a minimum size of block that the allocator would allocate). + ASMJIT_INLINE_NODEBUG uint32_t blockSize() const noexcept { return _impl->blockSize; } + //! Returns granularity of the allocator. + ASMJIT_INLINE_NODEBUG uint32_t granularity() const noexcept { return _impl->granularity; } + //! Returns pattern that is used to fill unused memory if `kFlagUseFillPattern` is set. + ASMJIT_INLINE_NODEBUG uint32_t fillPattern() const noexcept { return _impl->fillPattern; } + + //! \} + + //! \name Alloc & Release + //! \{ + + //! A memory reference returned by \ref JitAllocator::alloc(). + //! + //! Span contains everything needed to actually write new code to the memory chunk it references. 
+ class Span { + public: + //! \name Constants + //! \{ + + //! Span flags + enum class Flags : uint32_t { + //! No flags. + kNone = 0u, + + //! The process has never executed the region of the span. + //! + //! If this flag is set on a \ref Span it would mean that the allocator can avoid flushing + //! instruction cache after a code has been written to it. + kInstructionCacheClean = 0x00000001u + }; + + //! \} + + //! \name Members + //! \{ + + //! Address of memory that has Read and Execute permissions. + void* _rx = nullptr; + + //! Address of memory that has Read and Write permissions. + void* _rw = nullptr; + + //! Size of the span in bytes (rounded up to the allocation granularity). + size_t _size = 0; + + //! Pointer that references a memory block maintained by \ref JitAllocator. + //! + //! This pointer is considered private and should never be used nor inspected outside of AsmJit. + void* _block = nullptr; + + //! Span flags. + Flags _flags = Flags::kNone; + + //! Reserved for future use. + uint32_t _reserved = 0; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns a pointer having Read & Execute permissions (references executable memory). + //! + //! This pointer is never NULL if the allocation succeeded, it points to an executable memory. + ASMJIT_INLINE_NODEBUG void* rx() const noexcept { return _rx; } + + //! Returns a pointer having Read & Write permissions (references writable memory). + //! + //! Depending on the type of the allocation strategy this could either be: + //! + //! - the same address as returned by `rx()` if the allocator uses RWX mapping (pages have all of Read, Write, + //! and Execute permissions) or MAP_JIT, which requires either \ref VirtMem::ProtectJitReadWriteScope or to + //! call \ref VirtMem::protectJitMemory() manually. + //! - a valid pointer, but not the same as `rx` - this would be valid if dual mapping is used. + //! - NULL pointer, in case that the allocation strategy doesn't use RWX, MAP_JIT, or dual mapping. 
In this + //! case only \ref JitAllocator can copy new code into the executable memory referenced by \ref Span. + //! + //! \note If `rw()` returns a non-null pointer it's important to use either VirtMem::protectJitMemory() or + //! \ref VirtMem::ProtectJitReadWriteScope to guard the write, because in case of `MAP_JIT` it would temporarily + //! switch the permissions of the pointer to RW (that's per thread permissions). + //! + //! If \ref VirtMem::ProtectJitReadWriteScope is not used it's important to clear the instruction cache via + //! \ref VirtMem::flushInstructionCache() after the write is done. + ASMJIT_INLINE_NODEBUG void* rw() const noexcept { return _rw; } + + //! Returns size of this span, aligned to the allocator granularity. + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + + //! Returns span flags. + ASMJIT_INLINE_NODEBUG Flags flags() const noexcept { return _flags; } + + //! Shrinks this span to `newSize`. + //! + //! \note This is the only function that is able to change the size of a span, and it's only use case is to + //! shrink the span size during \ref JitAllocator::write(). When the writer detects that the span size shrunk, + //! it will automatically shrink the memory used by the span, and propagate the new aligned size to the caller. + ASMJIT_INLINE_NODEBUG void shrink(size_t newSize) noexcept { _size = Support::min(_size, newSize); } + + //! Returns whether \ref rw() returns a non-null pointer. + ASMJIT_INLINE_NODEBUG bool isDirectlyWritable() const noexcept { return _rw != nullptr; } + + //! \} + }; + + //! Allocates a new memory span of the requested `size`. + ASMJIT_API Error alloc(Span& out, size_t size) noexcept; + + //! Releases a memory block returned by `alloc()`. + //! + //! \remarks This function is thread-safe. + ASMJIT_API Error release(void* rx) noexcept; + + //! Frees extra memory allocated with `rx` by shrinking it to the given `newSize`. + //! + //! \remarks This function is thread-safe. 
+ ASMJIT_API Error shrink(Span& span, size_t newSize) noexcept; + + //! Queries information about an allocated memory block that contains the given `rx`, and writes it to `out`. + //! + //! If the pointer is matched, the function returns `kErrorOk` and fills `out` with the corresponding span. + ASMJIT_API Error query(Span& out, void* rx) const noexcept; + +#if !defined(ASMJIT_NO_DEPRECATED) + //! Allocates a new memory block of the requested `size`. + ASMJIT_DEPRECATED("Use alloc(Span& out, size_t size) instead") + ASMJIT_FORCE_INLINE Error alloc(void** rxPtrOut, void** rwPtrOut, size_t size) noexcept { + Span span; + Error err = alloc(span, size); + *rwPtrOut = span.rw(); + *rxPtrOut = span.rx(); + return err; + } + + ASMJIT_DEPRECATED("Use shrink(Span& span, size_t newSize) instead") + ASMJIT_FORCE_INLINE Error shrink(void* rxPtr, size_t newSize) noexcept { + Span span; + ASMJIT_PROPAGATE(query(span, rxPtr)); + return (span.size() > newSize) ? shrink(span, newSize) : Error(kErrorOk); + } + + ASMJIT_DEPRECATED("Use query(Span& out, void* rx) instead") + ASMJIT_FORCE_INLINE Error query(void* rxPtr, void** rxPtrOut, void** rwPtrOut, size_t* sizeOut) const noexcept { + Span span; + Error err = query(span, rxPtr); + *rxPtrOut = span.rx(); + *rwPtrOut = span.rw(); + *sizeOut = span.size(); + return err; + } +#endif + + //! \} + + //! \name Write Operations + //! 
\{ + + typedef Error (ASMJIT_CDECL* WriteFunc)(Span& span, void* userData) ASMJIT_NOEXCEPT_TYPE; + + ASMJIT_API Error write( + Span& span, + size_t offset, + const void* src, + size_t size, + VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept; + + ASMJIT_API Error write( + Span& span, + WriteFunc writeFunc, + void* userData, + VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept; + + template + ASMJIT_FORCE_INLINE Error write( + Span& span, + Lambda&& lambdaFunc, + VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept { + + WriteFunc wrapperFunc = [](Span& span, void* userData) noexcept -> Error { + Lambda& lambdaFunc = *static_cast(userData); + return lambdaFunc(span); + }; + return write(span, wrapperFunc, (void*)(&lambdaFunc), policy); + } + + //! \} + + //! \name Write Operations with Scope + //! \{ + + //! \cond INTERNAL + + //! Write scope data. + //! + //! This is mostly for internal purposes, please use \ref WriteScope instead. + struct WriteScopeData { + //! \name Members + //! \{ + + //! Link to the allocator. + JitAllocator* _allocator; + //! Cache policy passed to \ref JitAllocator::beginWriteScope(). + VirtMem::CachePolicy _policy; + //! Internal flags used by the implementation. + uint32_t _flags; + //! Internal data used by the implementation. + size_t _data[64]; + + //! \} + }; + + //! Begins a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope constructor instead. + ASMJIT_API Error beginWriteScope(WriteScopeData& scope, VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept; + + //! Ends a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope destructor instead. + ASMJIT_API Error endWriteScope(WriteScopeData& scope) noexcept; + + //! Flushes accumulated changes in a write `scope`. + //! + //! 
This is mostly for internal purposes, please use \ref WriteScope destructor or \ref WriteScope::flush() instead. + ASMJIT_API Error flushWriteScope(WriteScopeData& scope) noexcept; + + //! Alternative to `JitAllocator::write(span, offset, src, size)`, but under a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope::write() instead. + ASMJIT_API Error scopedWrite(WriteScopeData& scope, Span& span, size_t offset, const void* src, size_t size) noexcept; + + //! Alternative to `JitAllocator::write(span, writeFunc, userData)`, but under a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope::write() instead. + ASMJIT_API Error scopedWrite(WriteScopeData& scope, Span& span, WriteFunc writeFunc, void* userData) noexcept; + + //! Alternative to `JitAllocator::write(span, [lambda])`, but under a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope::write() instead. + template + inline Error scopedWrite(WriteScopeData& scope, Span& span, Lambda&& lambdaFunc) noexcept { + WriteFunc wrapperFunc = [](Span& span, void* userData) noexcept -> Error { + Lambda& lambdaFunc = *static_cast(userData); + return lambdaFunc(span); + }; + return scopedWrite(scope, span, wrapperFunc, (void*)(&lambdaFunc)); + } + + //! \endcond + + //! Write scope can be used to create a single scope that is optimized for writing multiple spans. + class WriteScope : public WriteScopeData { + public: + ASMJIT_NONCOPYABLE(WriteScope) + + //! \name Construction & Destruction + //! \{ + + // Begins a write scope. + inline explicit WriteScope(JitAllocator* allocator, VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept { + allocator->beginWriteScope(*this, policy); + } + + // Ends a write scope. + inline ~WriteScope() noexcept { + if (_allocator) + _allocator->endWriteScope(*this); + } + + //! \} + + //! \name Accessors + //! 
\{ + + ASMJIT_INLINE_NODEBUG JitAllocator* allocator() const noexcept { return _allocator; } + ASMJIT_INLINE_NODEBUG VirtMem::CachePolicy policy() const noexcept { return _policy; } + + //! \} + + //! \name Operations + //! \{ + + //! Similar to `JitAllocator::write(span, offset, src, size)`, but under a write scope. + ASMJIT_INLINE_NODEBUG Error write(Span& span, size_t offset, const void* src, size_t size) noexcept { + return _allocator->scopedWrite(*this, span, offset, src, size); + } + + //! Similar to `JitAllocator::write(span, writeFunc, userData)`, but under a write scope. + ASMJIT_INLINE_NODEBUG Error write(Span& span, WriteFunc writeFunc, void* userData) noexcept { + return _allocator->scopedWrite(*this, span, writeFunc, userData); + } + + //! Similar to `JitAllocator::write(span, )`, but under a write scope. + template + ASMJIT_INLINE_NODEBUG Error write(Span& span, Lambda&& lambdaFunc) noexcept { + return _allocator->scopedWrite(*this, span, lambdaFunc); + } + + //! Flushes accumulated changes in this write scope. + ASMJIT_INLINE_NODEBUG Error flush() noexcept { + return _allocator->flushWriteScope(*this); + } + + //! \} + }; + + //! \} + + //! \name Statistics + //! \{ + + //! Statistics about `JitAllocator`. + struct Statistics { + //! Number of blocks `JitAllocator` maintains. + size_t _blockCount; + //! Number of active allocations. + size_t _allocationCount; + //! How many bytes are currently used / allocated. + size_t _usedSize; + //! How many bytes are currently reserved by the allocator. + size_t _reservedSize; + //! Allocation overhead (in bytes) required to maintain all blocks. + size_t _overheadSize; + + //! Resets the statistics to all zeros. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = Statistics{}; } + + //! Returns count of blocks managed by `JitAllocator` at the moment. + ASMJIT_INLINE_NODEBUG size_t blockCount() const noexcept { return _blockCount; } + //! Returns the number of active allocations. 
+ ASMJIT_INLINE_NODEBUG size_t allocationCount() const noexcept { return _allocationCount; } + + //! Returns how many bytes are currently used. + ASMJIT_INLINE_NODEBUG size_t usedSize() const noexcept { return _usedSize; } + //! Returns the number of bytes unused by the allocator at the moment. + ASMJIT_INLINE_NODEBUG size_t unusedSize() const noexcept { return _reservedSize - _usedSize; } + //! Returns the total number of bytes reserved by the allocator (sum of sizes of all blocks). + ASMJIT_INLINE_NODEBUG size_t reservedSize() const noexcept { return _reservedSize; } + //! Returns the number of bytes the allocator needs to manage the allocated memory. + ASMJIT_INLINE_NODEBUG size_t overheadSize() const noexcept { return _overheadSize; } + + ASMJIT_INLINE_NODEBUG double usedSizeAsPercent() const noexcept { + return (double(usedSize()) / (double(reservedSize()) + 1e-16)) * 100.0; + } + + ASMJIT_INLINE_NODEBUG double unusedSizeAsPercent() const noexcept { + return (double(unusedSize()) / (double(reservedSize()) + 1e-16)) * 100.0; + } + + ASMJIT_INLINE_NODEBUG double overheadSizeAsPercent() const noexcept { + return (double(overheadSize()) / (double(reservedSize()) + 1e-16)) * 100.0; + } + }; + + //! Returns JIT allocator statistics. + //! + //! \remarks This function is thread-safe. + ASMJIT_API Statistics statistics() const noexcept; + + //! \} +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/jitruntime.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/jitruntime.h new file mode 100644 index 0000000000000000000000000000000000000000..717a6b58d6bb2597989a449e95c00a8e56a8cbc6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/jitruntime.h @@ -0,0 +1,101 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_JITRUNTIME_H_INCLUDED +#define ASMJIT_CORE_JITRUNTIME_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_JIT + +#include "../core/codeholder.h" +#include "../core/jitallocator.h" +#include "../core/target.h" + +ASMJIT_BEGIN_NAMESPACE + +class CodeHolder; + +//! \addtogroup asmjit_virtual_memory +//! \{ + +//! JIT execution runtime is a special `Target` that is designed to store and execute a generated code. +//! +//! JIT runtime is the easiest way of using AsmJit as it abstracts allocation and deallocation of virtual memory +//! where executable code can be placed and from which it can be executed as well. +class ASMJIT_VIRTAPI JitRuntime : public Target { +public: + ASMJIT_NONCOPYABLE(JitRuntime) + + //! Virtual memory allocator. + JitAllocator _allocator; + + //! \name Construction & Destruction + //! \{ + + //! Creates a `JitRuntime` instance. + ASMJIT_API explicit JitRuntime(const JitAllocator::CreateParams* params = nullptr) noexcept; + //! Destroys the `JitRuntime` instance. + ASMJIT_API ~JitRuntime() noexcept override; + + //! \} + + //! \name Accessors + //! \{ + + //! Resets the \ref JitRuntime, freeing everything that was allocated by it. + //! + //! Depending on `resetPolicy` the currently held memory can be either freed entirely when ResetPolicy::kHard is used, + //! 
or the allocator can keep some of it for next allocations when ResetPolicy::kSoft is used, which is the default + //! behavior. + ASMJIT_INLINE_NODEBUG void reset(ResetPolicy resetPolicy = ResetPolicy::kSoft) noexcept { + _allocator.reset(resetPolicy); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the associated `JitAllocator`. + ASMJIT_INLINE_NODEBUG JitAllocator* allocator() const noexcept { return const_cast(&_allocator); } + + //! \} + + //! \name Utilities + //! \{ + + // NOTE: To allow passing function pointers to `add()` and `release()` the + // virtual methods are prefixed with `_` and called from templates instead. + + //! Allocates memory needed for a code stored in the `CodeHolder` and relocates the code to the pointer allocated. + //! + //! The beginning of the memory allocated for the function is returned in `dst`. If failed `Error` code is returned + //! and `dst` is explicitly set to `nullptr` (this means that you don't have to set it to null before calling `add()`). + template + ASMJIT_INLINE_NODEBUG Error add(Func* dst, CodeHolder* code) noexcept { + return _add(Support::ptr_cast_impl(dst), code); + } + + //! Releases `p` which was obtained by calling `add()`. + template + ASMJIT_INLINE_NODEBUG Error release(Func p) noexcept { + return _release(Support::ptr_cast_impl(p)); + } + + //! Type-unsafe version of `add()`. + ASMJIT_API virtual Error _add(void** dst, CodeHolder* code) noexcept; + + //! Type-unsafe version of `release()`. + ASMJIT_API virtual Error _release(void* p) noexcept; + + //! \} +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/logger.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/logger.h new file mode 100644 index 0000000000000000000000000000000000000000..54c169f52fed2284568a286eb3ea040d579380ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/logger.h @@ -0,0 +1,198 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_LOGGING_H_INCLUDED +#define ASMJIT_CORE_LOGGING_H_INCLUDED + +#include "../core/inst.h" +#include "../core/string.h" +#include "../core/formatter.h" + +#ifndef ASMJIT_NO_LOGGING + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_logging +//! \{ + +//! Logging interface. +//! +//! This class can be inherited and reimplemented to fit into your own logging needs. When reimplementing a logger +//! use \ref Logger::_log() method to log customize the output. +//! +//! There are two `Logger` implementations offered by AsmJit: +//! - \ref FileLogger - logs into a `FILE*`. +//! - \ref StringLogger - concatenates all logs into a \ref String. +class ASMJIT_VIRTAPI Logger { +public: + ASMJIT_BASE_CLASS(Logger) + ASMJIT_NONCOPYABLE(Logger) + + //! Format options. + FormatOptions _options; + + //! \name Construction & Destruction + //! \{ + + //! Creates a `Logger` instance. + ASMJIT_API Logger() noexcept; + //! Destroys the `Logger` instance. + ASMJIT_API virtual ~Logger() noexcept; + + //! \} + + //! \name Format Options + //! \{ + + //! Returns \ref FormatOptions of this logger. + ASMJIT_INLINE_NODEBUG FormatOptions& options() noexcept { return _options; } + //! \overload + ASMJIT_INLINE_NODEBUG const FormatOptions& options() const noexcept { return _options; } + //! Sets formatting options of this Logger to `options`. 
+ ASMJIT_INLINE_NODEBUG void setOptions(const FormatOptions& options) noexcept { _options = options; } + //! Resets formatting options of this Logger to defaults. + ASMJIT_INLINE_NODEBUG void resetOptions() noexcept { _options.reset(); } + + //! Returns formatting flags. + ASMJIT_INLINE_NODEBUG FormatFlags flags() const noexcept { return _options.flags(); } + //! Tests whether the logger has the given `flag` enabled. + ASMJIT_INLINE_NODEBUG bool hasFlag(FormatFlags flag) const noexcept { return _options.hasFlag(flag); } + //! Sets formatting flags to `flags`. + ASMJIT_INLINE_NODEBUG void setFlags(FormatFlags flags) noexcept { _options.setFlags(flags); } + //! Enables the given formatting `flags`. + ASMJIT_INLINE_NODEBUG void addFlags(FormatFlags flags) noexcept { _options.addFlags(flags); } + //! Disables the given formatting `flags`. + ASMJIT_INLINE_NODEBUG void clearFlags(FormatFlags flags) noexcept { _options.clearFlags(flags); } + + //! Returns indentation of a given indentation `group`. + ASMJIT_INLINE_NODEBUG uint32_t indentation(FormatIndentationGroup type) const noexcept { return _options.indentation(type); } + //! Sets indentation of the given indentation `group` to `n` spaces. + ASMJIT_INLINE_NODEBUG void setIndentation(FormatIndentationGroup type, uint32_t n) noexcept { _options.setIndentation(type, n); } + //! Resets indentation of the given indentation `group` to 0 spaces. + ASMJIT_INLINE_NODEBUG void resetIndentation(FormatIndentationGroup type) noexcept { _options.resetIndentation(type); } + + //! Returns padding of a given padding `group`. + ASMJIT_INLINE_NODEBUG size_t padding(FormatPaddingGroup type) const noexcept { return _options.padding(type); } + //! Sets padding of a given padding `group` to `n`. + ASMJIT_INLINE_NODEBUG void setPadding(FormatPaddingGroup type, uint32_t n) noexcept { _options.setPadding(type, n); } + //! Resets padding of a given padding `group` to 0, which means that a default will be used. 
+ ASMJIT_INLINE_NODEBUG void resetPadding(FormatPaddingGroup type) noexcept { _options.resetPadding(type); } + + //! \} + + //! \name Logging Interface + //! \{ + + //! Logs `str` - must be reimplemented. + //! + //! The function can accept either a null terminated string if `size` is `SIZE_MAX` or a non-null terminated + //! string of the given `size`. The function cannot assume that the data is null terminated and must handle + //! non-null terminated inputs. + ASMJIT_API virtual Error _log(const char* data, size_t size) noexcept; + + //! Logs string `str`, which is either null terminated or having size `size`. + ASMJIT_INLINE_NODEBUG Error log(const char* data, size_t size = SIZE_MAX) noexcept { return _log(data, size); } + //! Logs content of a string `str`. + ASMJIT_INLINE_NODEBUG Error log(const String& str) noexcept { return _log(str.data(), str.size()); } + + //! Formats the message by using `snprintf()` and then passes the formatted string to \ref _log(). + ASMJIT_API Error logf(const char* fmt, ...) noexcept; + + //! Formats the message by using `vsnprintf()` and then passes the formatted string to \ref _log(). + ASMJIT_API Error logv(const char* fmt, va_list ap) noexcept; + + //! \} +}; + +//! Logger that can log to a `FILE*`. +class ASMJIT_VIRTAPI FileLogger : public Logger { +public: + ASMJIT_NONCOPYABLE(FileLogger) + + FILE* _file; + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `FileLogger` that logs to `FILE*`. + ASMJIT_API FileLogger(FILE* file = nullptr) noexcept; + //! Destroys the `FileLogger`. + ASMJIT_API ~FileLogger() noexcept override; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the logging output stream or null if the logger has no output stream. + ASMJIT_INLINE_NODEBUG FILE* file() const noexcept { return _file; } + + //! Sets the logging output stream to `stream` or null. + //! + //! \note If the `file` is null the logging will be disabled. When a logger is attached to `CodeHolder` or any + //! 
emitter the logging API will always be called regardless of the output file. This means that if you really + //! want to disable logging at emitter level you must not attach a logger to it. + ASMJIT_INLINE_NODEBUG void setFile(FILE* file) noexcept { _file = file; } + + //! \} + + ASMJIT_API Error _log(const char* data, size_t size = SIZE_MAX) noexcept override; +}; + +//! Logger that stores everything in an internal string buffer. +class ASMJIT_VIRTAPI StringLogger : public Logger { +public: + ASMJIT_NONCOPYABLE(StringLogger) + + //! Logger data as string. + String _content; + + //! \name Construction & Destruction + //! \{ + + //! Create new `StringLogger`. + ASMJIT_API StringLogger() noexcept; + //! Destroys the `StringLogger`. + ASMJIT_API ~StringLogger() noexcept override; + + //! \} + + //! \name Logger Data Accessors + //! \{ + + //! Returns the content of the logger as \ref String. + //! + //! It can be moved, if desired. + ASMJIT_INLINE_NODEBUG String& content() noexcept { return _content; } + //! \overload + ASMJIT_INLINE_NODEBUG const String& content() const noexcept { return _content; } + + //! Returns aggregated logger data as `char*` pointer. + //! + //! The pointer is owned by `StringLogger`, it can't be modified or freed. + ASMJIT_INLINE_NODEBUG const char* data() const noexcept { return _content.data(); } + //! Returns size of the data returned by `data()`. + ASMJIT_INLINE_NODEBUG size_t dataSize() const noexcept { return _content.size(); } + + //! \} + + //! \name Logger Data Manipulation + //! \{ + + //! Clears the accumulated logger data. + ASMJIT_INLINE_NODEBUG void clear() noexcept { _content.clear(); } + + //! \} + + ASMJIT_API Error _log(const char* data, size_t size = SIZE_MAX) noexcept override; +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif + +#endif // ASMJIT_CORE_LOGGER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/operand.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/operand.h new file mode 100644 index 0000000000000000000000000000000000000000..8a1eee79b2ff529657fd44dfe12f7f5a1d4600f6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/operand.h @@ -0,0 +1,1889 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_OPERAND_H_INCLUDED +#define ASMJIT_CORE_OPERAND_H_INCLUDED + +#include "../core/archcommons.h" +#include "../core/support.h" +#include "../core/type.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_assembler +//! \{ + +//! Operand type used by \ref Operand_. +enum class OperandType : uint32_t { + //! Not an operand or not initialized. + kNone = 0, + //! Operand is a register. + kReg = 1, + //! Operand is a memory. + kMem = 2, + //! Operand is a register-list. + kRegList = 3, + //! Operand is an immediate value. + kImm = 4, + //! Operand is a label. + kLabel = 5, + + //! Maximum value of `OperandType`. + kMaxValue = kRegList +}; + +static_assert(uint32_t(OperandType::kMem) == uint32_t(OperandType::kReg) + 1, + "AsmJit requires that `OperandType::kMem` equals to `OperandType::kReg + 1`"); + +//! Register mask is a convenience typedef that describes a mask where each bit describes a physical register id +//! in the same \ref RegGroup. At the moment 32 bits are enough as AsmJit doesn't support any architecture that +//! would provide more than 32 registers for a register group. +typedef uint32_t RegMask; + +//! Register type. +//! +//! Provides a unique type that can be used to identify a register or its view. +enum class RegType : uint8_t { + //! No register - unused, invalid, multiple meanings. + kNone = 0, + + //! This is not a register type. 
This value is reserved for a \ref Label that's used in \ref BaseMem as a base. + //! + //! Label tag is used as a sub-type, forming a unique signature across all operand types as 0x1 is never associated + //! with any register type. This means that a memory operand's BASE register can be constructed from virtually any + //! operand (register vs. label) by just assigning its type (register type or label-tag) and operand id. + kLabelTag = 1, + + //! Universal type describing program counter (PC) or instruction pointer (IP) register, if the target architecture + //! actually exposes it as a separate register type, which most modern architectures do. + kPC = 2, + + //! 8-bit low general purpose register (X86). + kGp8Lo = 3, + //! 8-bit high general purpose register (X86). + kGp8Hi = 4, + //! 16-bit general purpose register (X86). + kGp16 = 5, + //! 32-bit general purpose register (X86|AArch32|AArch64). + kGp32 = 6, + //! 64-bit general purpose register (X86|AArch64). + kGp64 = 7, + //! 8-bit view of a vector register (AArch64). + kVec8 = 8, + //! 16-bit view of a vector register (AArch64). + kVec16 = 9, + //! 32-bit view of a vector register (AArch32|AArch64). + kVec32 = 10, + //! 64-bit view of a vector register (AArch32|AArch64). + //! + //! \note This is never used for MMX registers on X86, MMX registers have its own category. + kVec64 = 11, + //! 128-bit view of a vector register (X86|AArch32|AArch64). + kVec128 = 12, + //! 256-bit view of a vector register (X86). + kVec256 = 13, + //! 512-bit view of a vector register (X86). + kVec512 = 14, + //! 1024-bit view of a vector register (future). + kVec1024 = 15, + //! View of a vector register, which width is implementation specific (AArch64). + kVecNLen = 16, + + //! Mask register (X86). + kMask = 17, + + //! Start of architecture dependent register types. + kExtra = 18, + + // X86 Specific Register Types + // --------------------------- + + //! 
Instruction pointer (RIP), only addressable in \ref x86::Mem in 64-bit targets. + kX86_Rip = kPC, + //! Low GPB register (AL, BL, CL, DL, ...). + kX86_GpbLo = kGp8Lo, + //! High GPB register (AH, BH, CH, DH only). + kX86_GpbHi = kGp8Hi, + //! GPW register. + kX86_Gpw = kGp16, + //! GPD register. + kX86_Gpd = kGp32, + //! GPQ register (64-bit). + kX86_Gpq = kGp64, + //! XMM register (SSE+). + kX86_Xmm = kVec128, + //! YMM register (AVX+). + kX86_Ymm = kVec256, + //! ZMM register (AVX512+). + kX86_Zmm = kVec512, + //! K register (AVX512+). + kX86_KReg = kMask, + //! MMX register. + kX86_Mm = kExtra + 0, + //! Segment register (None, ES, CS, SS, DS, FS, GS). + kX86_SReg = kExtra + 1, + //! Control register (CR). + kX86_CReg = kExtra + 2, + //! Debug register (DR). + kX86_DReg = kExtra + 3, + //! FPU (x87) register. + kX86_St = kExtra + 4, + //! Bound register (BND). + kX86_Bnd = kExtra + 5, + //! TMM register (AMX_TILE) + kX86_Tmm = kExtra + 6, + + // ARM Specific Register Types + // --------------------------- + + //! Program pointer (PC) register (AArch64). + kARM_PC = kPC, + //! 32-bit general purpose register (R or W). + kARM_GpW = kGp32, + //! 64-bit general purpose register (X). + kARM_GpX = kGp64, + //! 8-bit view of VFP/ASIMD register (B). + kARM_VecB = kVec8, + //! 16-bit view of VFP/ASIMD register (H). + kARM_VecH = kVec16, + //! 32-bit view of VFP/ASIMD register (S). + kARM_VecS = kVec32, + //! 64-bit view of VFP/ASIMD register (D). + kARM_VecD = kVec64, + //! 128-bit view of VFP/ASIMD register (Q). + kARM_VecQ = kVec128, + //! 128-bit view of VFP/ASIMD register (V). + kARM_VecV = kVec128, + + //! Maximum value of `RegType`. + kMaxValue = 31 +}; +ASMJIT_DEFINE_ENUM_COMPARE(RegType) + +//! Register group. +//! +//! Provides a unique value that identifies groups of registers and their views. +enum class RegGroup : uint8_t { + //! General purpose register group compatible with all backends. + kGp = 0, + //! Vector register group compatible with all backends. 
+ //! + //! Describes X86 XMM|YMM|ZMM registers ARM/AArch64 V registers. + kVec = 1, + + //! Mask register group compatible with all backends that can use masking. + kMask = 2, + //! Extra virtual group #3 that can be used by Compiler for register allocation. + kExtraVirt3 = 3, + + //! Program counter group. + kPC = 4, + + //! Extra non-virtual group that can be used by registers not managed by Compiler. + kExtraNonVirt = 5, + + // X86 Specific Register Groups + // ---------------------------- + + //! K register group (KReg) - maps to \ref RegGroup::kMask (X86, X86_64). + kX86_K = kMask, + //! MMX register group (MM) - maps to \ref RegGroup::kExtraVirt3 (X86, X86_64). + kX86_MM = kExtraVirt3, + + //! Instruction pointer (X86, X86_64). + kX86_Rip = kPC, + //! Segment register group (X86, X86_64). + kX86_SReg = kExtraNonVirt + 0, + //! CR register group (X86, X86_64). + kX86_CReg = kExtraNonVirt + 1, + //! DR register group (X86, X86_64). + kX86_DReg = kExtraNonVirt + 2, + //! FPU register group (X86, X86_64). + kX86_St = kExtraNonVirt + 3, + //! BND register group (X86, X86_64). + kX86_Bnd = kExtraNonVirt + 4, + //! TMM register group (X86, X86_64). + kX86_Tmm = kExtraNonVirt + 5, + + //! First group - only used in loops. + k0 = 0, + //! Last value of a virtual register that is managed by \ref BaseCompiler. + kMaxVirt = Globals::kNumVirtGroups - 1, + //! Maximum value of `RegGroup`. + kMaxValue = 15 +}; +ASMJIT_DEFINE_ENUM_COMPARE(RegGroup) + +typedef Support::EnumValues RegGroupVirtValues; + +//! Operand signature is a 32-bit number describing \ref Operand and some of its payload. +//! +//! In AsmJit operand signature is used to store additional payload of register, memory, and immediate operands. +//! In practice the biggest pressure on OperandSignature is from \ref BaseMem and architecture specific memory +//! operands that need to store additional payload that cannot be stored elsewhere as values of all other members +//! are fully specified by \ref BaseMem. 
+struct OperandSignature { + //! \name Constants + //! \{ + + enum : uint32_t { + // Operand type (3 least significant bits). + // |........|........|........|.....XXX| + kOpTypeShift = 0, + kOpTypeMask = 0x07u << kOpTypeShift, + + // Register type (5 bits). + // |........|........|........|XXXXX...| + kRegTypeShift = 3, + kRegTypeMask = 0x1Fu << kRegTypeShift, + + // Register group (4 bits). + // |........|........|....XXXX|........| + kRegGroupShift = 8, + kRegGroupMask = 0x0Fu << kRegGroupShift, + + // Memory base type (5 bits). + // |........|........|........|XXXXX...| + kMemBaseTypeShift = 3, + kMemBaseTypeMask = 0x1Fu << kMemBaseTypeShift, + + // Memory index type (5 bits). + // |........|........|...XXXXX|........| + kMemIndexTypeShift = 8, + kMemIndexTypeMask = 0x1Fu << kMemIndexTypeShift, + + // Memory base+index combined (10 bits). + // |........|........|...XXXXX|XXXXX...| + kMemBaseIndexShift = 3, + kMemBaseIndexMask = 0x3FFu << kMemBaseIndexShift, + + // This memory operand represents a home-slot or stack (Compiler) (1 bit). + // |........|........|..X.....|........| + kMemRegHomeShift = 13, + kMemRegHomeFlag = 0x01u << kMemRegHomeShift, + + // Immediate type (1 bit). + // |........|........|........|....X...| + kImmTypeShift = 3, + kImmTypeMask = 0x01u << kImmTypeShift, + + // Predicate used by either registers or immediate values (4 bits). + // |........|XXXX....|........|........| + kPredicateShift = 20, + kPredicateMask = 0x0Fu << kPredicateShift, + + // Operand size (8 most significant bits). + // |XXXXXXXX|........|........|........| + kSizeShift = 24, + kSizeMask = 0xFFu << kSizeShift + }; + + //! \} + + //! \name Members + //! \{ + + uint32_t _bits; + + //! \} + + //! \name Overloaded Operators + //! + //! Overloaded operators make `OperandSignature` behave like regular integer. + //! + //! 
\{ + + ASMJIT_INLINE_NODEBUG constexpr bool operator!() const noexcept { return _bits == 0; } + ASMJIT_INLINE_NODEBUG constexpr explicit operator bool() const noexcept { return _bits != 0; } + + ASMJIT_INLINE_NODEBUG OperandSignature& operator|=(uint32_t x) noexcept { _bits |= x; return *this; } + ASMJIT_INLINE_NODEBUG OperandSignature& operator&=(uint32_t x) noexcept { _bits &= x; return *this; } + ASMJIT_INLINE_NODEBUG OperandSignature& operator^=(uint32_t x) noexcept { _bits ^= x; return *this; } + + ASMJIT_INLINE_NODEBUG OperandSignature& operator|=(const OperandSignature& other) noexcept { return operator|=(other._bits); } + ASMJIT_INLINE_NODEBUG OperandSignature& operator&=(const OperandSignature& other) noexcept { return operator&=(other._bits); } + ASMJIT_INLINE_NODEBUG OperandSignature& operator^=(const OperandSignature& other) noexcept { return operator^=(other._bits); } + + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator~() const noexcept { return OperandSignature{~_bits}; } + + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator|(uint32_t x) const noexcept { return OperandSignature{_bits | x}; } + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator&(uint32_t x) const noexcept { return OperandSignature{_bits & x}; } + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator^(uint32_t x) const noexcept { return OperandSignature{_bits ^ x}; } + + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator|(const OperandSignature& other) const noexcept { return OperandSignature{_bits | other._bits}; } + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator&(const OperandSignature& other) const noexcept { return OperandSignature{_bits & other._bits}; } + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator^(const OperandSignature& other) const noexcept { return OperandSignature{_bits ^ other._bits}; } + + ASMJIT_INLINE_NODEBUG constexpr bool operator==(uint32_t x) const noexcept { return _bits == x; } + ASMJIT_INLINE_NODEBUG 
constexpr bool operator!=(uint32_t x) const noexcept { return _bits != x; } + + ASMJIT_INLINE_NODEBUG constexpr bool operator==(const OperandSignature& other) const noexcept { return _bits == other._bits; } + ASMJIT_INLINE_NODEBUG constexpr bool operator!=(const OperandSignature& other) const noexcept { return _bits != other._bits; } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG void reset() noexcept { _bits = 0; } + + ASMJIT_INLINE_NODEBUG constexpr uint32_t bits() const noexcept { return _bits; } + ASMJIT_INLINE_NODEBUG void setBits(uint32_t bits) noexcept { _bits = bits; } + + template + ASMJIT_INLINE_NODEBUG constexpr bool hasField() const noexcept { + return (_bits & kFieldMask) != 0; + } + + template + ASMJIT_INLINE_NODEBUG constexpr bool hasField(uint32_t value) const noexcept { + return (_bits & kFieldMask) != value << Support::ConstCTZ::value; + } + + template + ASMJIT_INLINE_NODEBUG constexpr uint32_t getField() const noexcept { + return (_bits >> Support::ConstCTZ::value) & (kFieldMask >> Support::ConstCTZ::value); + } + + template + ASMJIT_INLINE_NODEBUG void setField(uint32_t value) noexcept { + ASMJIT_ASSERT(((value << Support::ConstCTZ::value) & ~kFieldMask) == 0); + _bits = (_bits & ~kFieldMask) | (value << Support::ConstCTZ::value); + } + + ASMJIT_INLINE_NODEBUG constexpr OperandSignature subset(uint32_t mask) const noexcept { return OperandSignature{_bits & mask}; } + + template::value> + ASMJIT_INLINE_NODEBUG constexpr OperandSignature replacedValue(uint32_t value) const noexcept { return OperandSignature{(_bits & ~kFieldMask) | (value << kFieldShift)}; } + + template + ASMJIT_INLINE_NODEBUG constexpr bool matchesSignature(const OperandSignature& signature) const noexcept { + return (_bits & kFieldMask) == signature._bits; + } + + template + ASMJIT_INLINE_NODEBUG constexpr bool matchesFields(uint32_t bits) const noexcept { + return (_bits & kFieldMask) == bits; + } + + template + ASMJIT_INLINE_NODEBUG constexpr bool 
matchesFields(const OperandSignature& fields) const noexcept { + return (_bits & kFieldMask) == fields._bits; + } + + ASMJIT_INLINE_NODEBUG constexpr bool isValid() const noexcept { return _bits != 0; } + + ASMJIT_INLINE_NODEBUG constexpr OperandType opType() const noexcept { return (OperandType)getField(); } + + ASMJIT_INLINE_NODEBUG constexpr RegType regType() const noexcept { return (RegType)getField(); } + ASMJIT_INLINE_NODEBUG constexpr RegGroup regGroup() const noexcept { return (RegGroup)getField(); } + + ASMJIT_INLINE_NODEBUG constexpr RegType memBaseType() const noexcept { return (RegType)getField(); } + ASMJIT_INLINE_NODEBUG constexpr RegType memIndexType() const noexcept { return (RegType)getField(); } + + ASMJIT_INLINE_NODEBUG constexpr uint32_t predicate() const noexcept { return getField(); } + ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return getField(); } + + ASMJIT_INLINE_NODEBUG void setOpType(OperandType opType) noexcept { setField(uint32_t(opType)); } + ASMJIT_INLINE_NODEBUG void setRegType(RegType regType) noexcept { setField(uint32_t(regType)); } + ASMJIT_INLINE_NODEBUG void setRegGroup(RegGroup regGroup) noexcept { setField(uint32_t(regGroup)); } + + ASMJIT_INLINE_NODEBUG void setMemBaseType(RegType baseType) noexcept { setField(uint32_t(baseType)); } + ASMJIT_INLINE_NODEBUG void setMemIndexType(RegType indexType) noexcept { setField(uint32_t(indexType)); } + + ASMJIT_INLINE_NODEBUG void setPredicate(uint32_t predicate) noexcept { setField(predicate); } + ASMJIT_INLINE_NODEBUG void setSize(uint32_t size) noexcept { setField(size); } + + //! \} + + //! \name Static Constructors + //! 
\{ + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromBits(uint32_t bits) noexcept { + return OperandSignature{bits}; + } + + template + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromValue(const T& value) noexcept { + return OperandSignature{uint32_t(value) << Support::ConstCTZ::value}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromOpType(OperandType opType) noexcept { + return OperandSignature{uint32_t(opType) << kOpTypeShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromRegType(RegType regType) noexcept { + return OperandSignature{uint32_t(regType) << kRegTypeShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromRegGroup(RegGroup regGroup) noexcept { + return OperandSignature{uint32_t(regGroup) << kRegGroupShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromRegTypeAndGroup(RegType regType, RegGroup regGroup) noexcept { + return fromRegType(regType) | fromRegGroup(regGroup); + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromMemBaseType(RegType baseType) noexcept { + return OperandSignature{uint32_t(baseType) << kMemBaseTypeShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromMemIndexType(RegType indexType) noexcept { + return OperandSignature{uint32_t(indexType) << kMemIndexTypeShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromPredicate(uint32_t predicate) noexcept { + return OperandSignature{predicate << kPredicateShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromSize(uint32_t size) noexcept { + return OperandSignature{size << kSizeShift}; + } + + //! \} +}; + +//! Base class representing an operand in AsmJit (non-default constructed version). +//! +//! Contains no initialization code and can be used safely to define an array of operands that won't be initialized. +//! 
This is a \ref Operand base structure designed to be statically initialized, static const, or to be used by user +//! code to define an array of operands without having them default initialized at construction time. +//! +//! The key difference between \ref Operand and \ref Operand_ is: +//! +//! ``` +//! Operand_ xArray[10]; // Not initialized, contains garbage. +//! Operand_ yArray[10] {}; // All operands initialized to none explicitly (zero initialized). +//! Operand yArray[10]; // All operands initialized to none implicitly (zero initialized). +//! ``` +struct Operand_ { + //! \name Types + //! \{ + + typedef OperandSignature Signature; + + //! \} + + //! \name Constants + //! \{ + + // Indexes to `_data` array. + enum DataIndex : uint32_t { + kDataMemIndexId = 0, + kDataMemOffsetLo = 1, + + kDataImmValueLo = ASMJIT_ARCH_LE ? 0 : 1, + kDataImmValueHi = ASMJIT_ARCH_LE ? 1 : 0 + }; + + //! Constants useful for VirtId <-> Index translation. + enum VirtIdConstants : uint32_t { + //! Minimum valid packed-id. + kVirtIdMin = 256, + //! Maximum valid packed-id, excludes Globals::kInvalidId. + kVirtIdMax = Globals::kInvalidId - 1, + //! Count of valid packed-ids. + kVirtIdCount = uint32_t(kVirtIdMax - kVirtIdMin + 1) + }; + + //! \} + + //! \name Members + //! \{ + + //! Provides operand type and additional payload. + Signature _signature; + //! Either base id as used by memory operand or any id as used by others. + uint32_t _baseId; + + //! Data specific to the operand type. + //! + //! The reason we don't use union is that we have `constexpr` constructors that construct operands and other + //!`constexpr` functions that return whether another Operand or something else. These cannot generally work with + //! unions so we also cannot use `union` if we want to be standard compliant. + uint32_t _data[2]; + + //! \} + + //! Tests whether the given `id` is a valid virtual register id. Since AsmJit supports both physical and virtual + //! 
registers it must be able to distinguish between these two. The idea is that physical registers are always + //! limited in size, so virtual identifiers start from `kVirtIdMin` and end at `kVirtIdMax`. + static ASMJIT_INLINE_NODEBUG bool isVirtId(uint32_t id) noexcept { return id - kVirtIdMin < uint32_t(kVirtIdCount); } + //! Converts a real-id into a packed-id that can be stored in Operand. + static ASMJIT_INLINE_NODEBUG uint32_t indexToVirtId(uint32_t id) noexcept { return id + kVirtIdMin; } + //! Converts a packed-id back to real-id. + static ASMJIT_INLINE_NODEBUG uint32_t virtIdToIndex(uint32_t id) noexcept { return id - kVirtIdMin; } + + //! \name Construction & Destruction + //! \{ + + //! \cond INTERNAL + //! Initializes a `BaseReg` operand from `signature` and register `id`. + ASMJIT_INLINE_NODEBUG void _initReg(const Signature& signature, uint32_t id) noexcept { + _signature = signature; + _baseId = id; + _data[0] = 0; + _data[1] = 0; + } + //! \endcond + + //! Initializes the operand from `other` operand (used by operator overloads). + ASMJIT_INLINE_NODEBUG void copyFrom(const Operand_& other) noexcept { + _signature._bits = other._signature._bits; + _baseId = other._baseId; + _data[0] = other._data[0]; + _data[1] = other._data[1]; + } + + //! Resets the `Operand` to none. + //! + //! None operand is defined the following way: + //! - Its signature is zero (OperandType::kNone, and the rest zero as well). + //! - Its id is `0`. + //! - The reserved8_4 field is set to `0`. + //! - The reserved12_4 field is set to zero. + //! + //! In other words, reset operands have all members set to zero. Reset operand must match the Operand state + //! right after its construction. Alternatively, if you have an array of operands, you can simply use `memset()`. + //! + //! ``` + //! using namespace asmjit; + //! + //! Operand a; + //! Operand b; + //! assert(a == b); + //! + //! b = x86::eax; + //! assert(a != b); + //! + //! b.reset(); + //! assert(a == b); + //! + //! 
memset(&b, 0, sizeof(Operand)); + //! assert(a == b); + //! ``` + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _signature.reset(); + _baseId = 0; + _data[0] = 0; + _data[1] = 0; + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Tests whether this operand is the same as `other`. + ASMJIT_INLINE_NODEBUG constexpr bool operator==(const Operand_& other) const noexcept { return equals(other); } + //! Tests whether this operand is not the same as `other`. + ASMJIT_INLINE_NODEBUG constexpr bool operator!=(const Operand_& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Cast + //! \{ + + //! Casts this operand to `T` type. + template + ASMJIT_INLINE_NODEBUG T& as() noexcept { return static_cast(*this); } + + //! Casts this operand to `T` type (const). + template + ASMJIT_INLINE_NODEBUG const T& as() const noexcept { return static_cast(*this); } + + //! \} + + //! \name Equality + //! \{ + + //! Tests whether the operand is 100% equal to `other` operand. + //! + //! \note This basically performs a binary comparison, if aby bit is + //! different the operands are not equal. + ASMJIT_INLINE_NODEBUG constexpr bool equals(const Operand_& other) const noexcept { + return bool(unsigned(_signature == other._signature) & + unsigned(_baseId == other._baseId ) & + unsigned(_data[0] == other._data[0] ) & + unsigned(_data[1] == other._data[1] )); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the operand's signature matches the signature of the `other` operand. + ASMJIT_INLINE_NODEBUG constexpr bool hasSignature(const Operand_& other) const noexcept { return _signature == other._signature; } + //! Tests whether the operand's signature matches the given signature `sign`. + ASMJIT_INLINE_NODEBUG constexpr bool hasSignature(const Signature& other) const noexcept { return _signature == other; } + + //! Returns operand signature as unsigned 32-bit integer. + //! + //! Signature is first 4 bytes of the operand data. 
It's used mostly for operand checking as it's + //! much faster to check packed 4 bytes at once than having to check these bytes individually. + ASMJIT_INLINE_NODEBUG constexpr Signature signature() const noexcept { return _signature; } + + //! Sets the operand signature, see `signature()`. + //! + //! \note Improper use of `setSignature()` can lead to hard-to-debug errors. + ASMJIT_INLINE_NODEBUG void setSignature(const Signature& signature) noexcept { _signature = signature; } + //! \overload + ASMJIT_INLINE_NODEBUG void setSignature(uint32_t signature) noexcept { _signature._bits = signature; } + + //! Returns the type of the operand, see `OpType`. + ASMJIT_INLINE_NODEBUG constexpr OperandType opType() const noexcept { return _signature.opType(); } + //! Tests whether the operand is none (`OperandType::kNone`). + ASMJIT_INLINE_NODEBUG constexpr bool isNone() const noexcept { return _signature == Signature::fromBits(0); } + //! Tests whether the operand is a register (`OperandType::kReg`). + ASMJIT_INLINE_NODEBUG constexpr bool isReg() const noexcept { return opType() == OperandType::kReg; } + //! Tests whether the operand is a register-list. + //! + //! \note Register-list is currently only used by 32-bit ARM architecture. + ASMJIT_INLINE_NODEBUG constexpr bool isRegList() const noexcept { return opType() == OperandType::kRegList; } + //! Tests whether the operand is a memory location (`OperandType::kMem`). + ASMJIT_INLINE_NODEBUG constexpr bool isMem() const noexcept { return opType() == OperandType::kMem; } + //! Tests whether the operand is an immediate (`OperandType::kImm`). + ASMJIT_INLINE_NODEBUG constexpr bool isImm() const noexcept { return opType() == OperandType::kImm; } + //! Tests whether the operand is a label (`OperandType::kLabel`). + ASMJIT_INLINE_NODEBUG constexpr bool isLabel() const noexcept { return opType() == OperandType::kLabel; } + + //! Tests whether the operand is a physical register. 
+ ASMJIT_INLINE_NODEBUG constexpr bool isPhysReg() const noexcept { return isReg() && _baseId < 0xFFu; } + //! Tests whether the operand is a virtual register. + ASMJIT_INLINE_NODEBUG constexpr bool isVirtReg() const noexcept { return isReg() && _baseId > 0xFFu; } + + //! Returns the operand id. + //! + //! The value returned should be interpreted accordingly to the operand type: + //! * None - Should be `0`. + //! * Reg - Physical or virtual register id. + //! * Mem - Multiple meanings - BASE address (register or label id), or high value of a 64-bit absolute address. + //! * Imm - Should be `0`. + //! * Label - Label id if it was created by using `newLabel()` or `Globals::kInvalidId` if the label is invalid or + //! not initialized. + ASMJIT_INLINE_NODEBUG constexpr uint32_t id() const noexcept { return _baseId; } + + //! Tests whether the operand is a register matching the given register `type`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegType type) const noexcept { + return _signature.subset(Signature::kOpTypeMask | Signature::kRegTypeMask) == (Signature::fromOpType(OperandType::kReg) | Signature::fromRegType(type)); + } + + //! Tests whether the operand is a register of the provided register group `regGroup`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegGroup regGroup) const noexcept { + return _signature.subset(Signature::kOpTypeMask | Signature::kRegGroupMask) == (Signature::fromOpType(OperandType::kReg) | Signature::fromRegGroup(regGroup)); + } + + //! Tests whether the operand is register and of register type `regType` and `regId`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegType regType, uint32_t regId) const noexcept { return isReg(regType) && _baseId == regId; } + //! Tests whether the operand is register and of register group `regGroup` and `regId`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegGroup regGroup, uint32_t regId) const noexcept { return isReg(regGroup) && _baseId == regId; } + + //! 
Tests whether the register is a general purpose register (any size). + ASMJIT_INLINE_NODEBUG constexpr bool isGp() const noexcept { return isReg(RegGroup::kGp); } + //! Tests whether the register is a 32-bit general purpose register. + ASMJIT_INLINE_NODEBUG constexpr bool isGp32() const noexcept { return isReg(RegType::kGp32); } + //! Tests whether the register is a 64-bit general purpose register. + ASMJIT_INLINE_NODEBUG constexpr bool isGp64() const noexcept { return isReg(RegType::kGp64); } + + //! Tests whether the register is a vector register of any size. + ASMJIT_INLINE_NODEBUG constexpr bool isVec() const noexcept { return isReg(RegGroup::kVec); } + //! Tests whether the register is an 8-bit vector register or view (AArch64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec8() const noexcept { return isReg(RegType::kVec8); } + //! Tests whether the register is a 16-bit vector register or view (AArch64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec16() const noexcept { return isReg(RegType::kVec16); } + //! Tests whether the register is a 32-bit vector register or view (AArch32, AArch64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec32() const noexcept { return isReg(RegType::kVec32); } + //! Tests whether the register is a 64-bit vector register or view (AArch32, AArch64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec64() const noexcept { return isReg(RegType::kVec64); } + //! Tests whether the register is a 128-bit vector register or view (AArch32, AArch64, X86, X86_64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec128() const noexcept { return isReg(RegType::kVec128); } + //! Tests whether the register is a 256-bit vector register or view (X86, X86_64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec256() const noexcept { return isReg(RegType::kVec256); } + //! Tests whether the register is a 512-bit vector register or view (X86, X86_64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec512() const noexcept { return isReg(RegType::kVec512); } + + //! 
Tests whether the register is a mask register of any size. + ASMJIT_INLINE_NODEBUG constexpr bool isMask() const noexcept { return isReg(RegGroup::kMask); } + + //! Tests whether the operand is a register matching the given register `type`. + ASMJIT_INLINE_NODEBUG constexpr bool isRegList(RegType type) const noexcept { + return _signature.subset(Signature::kOpTypeMask | Signature::kRegTypeMask) == (Signature::fromOpType(OperandType::kRegList) | Signature::fromRegType(type)); + } + + //! Tests whether the operand is a register or memory. + //! + //! \note This is useful on X86 and X86_64 architectures as many instructions support Reg/Mem operand combination. + //! So if the user code works with just \ref Operand, it's possible to check whether the operand is either a register + //! or memory location with a single check. + ASMJIT_INLINE_NODEBUG constexpr bool isRegOrMem() const noexcept { + return Support::isBetween(uint32_t(opType()), uint32_t(OperandType::kReg), uint32_t(OperandType::kMem)); + } + + //! Tests whether the operand is a register, register-list, or memory. + //! + //! \note This is useful on 32-bit ARM architecture to check whether an operand references a register. It can be + //! used in other architectures too, but it would work identically to \ref isRegOrMem() as other architectures + //! don't provide register lists. + ASMJIT_INLINE_NODEBUG constexpr bool isRegOrRegListOrMem() const noexcept { + return Support::isBetween(uint32_t(opType()), uint32_t(OperandType::kReg), uint32_t(OperandType::kRegList)); + } + + //! \} + + //! \name Accessors (X86 Specific) + //! \{ + + //! Returns a size of a register or an X86 memory operand. + //! + //! At the moment only X86 and X86_64 memory operands have a size - other memory operands can use bits that represent + //! size as an additional payload. This means that memory size is architecture specific and should be accessed via + //! \ref x86::Mem::size(). 
Sometimes when the user knows that the operand is either a register or memory operand this + //! function can be helpful as it avoids casting. + ASMJIT_INLINE_NODEBUG constexpr uint32_t x86RmSize() const noexcept { + return _signature.size(); + } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("hasSize() is no longer portable - use x86RmSize() instead, if your target is X86/X86_64") + ASMJIT_INLINE_NODEBUG constexpr bool hasSize() const noexcept { return x86RmSize() != 0u; } + + ASMJIT_DEPRECATED("hasSize() is no longer portable - use x86RmSize() instead, if your target is X86/X86_64") + ASMJIT_INLINE_NODEBUG constexpr bool hasSize(uint32_t s) const noexcept { return x86RmSize() == s; } + + ASMJIT_DEPRECATED("size() is no longer portable - use x86RmSize() instead, if your target is X86/X86_64") + ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return _signature.getField(); } +#endif + + //! \} +}; + +//! Base class representing an operand in AsmJit (default constructed version). +class Operand : public Operand_ { +public: + //! \name Construction & Destruction + //! \{ + + //! Creates `kOpNone` operand having all members initialized to zero. + ASMJIT_INLINE_NODEBUG constexpr Operand() noexcept + : Operand_{ Signature::fromOpType(OperandType::kNone), 0u, { 0u, 0u }} {} + + //! Creates a cloned `other` operand. + ASMJIT_INLINE_NODEBUG constexpr Operand(const Operand& other) noexcept = default; + + //! Creates a cloned `other` operand. + ASMJIT_INLINE_NODEBUG constexpr explicit Operand(const Operand_& other) + : Operand_(other) {} + + //! Creates an operand initialized to raw `[u0, u1, u2, u3]` values. + ASMJIT_INLINE_NODEBUG constexpr Operand(Globals::Init_, const Signature& u0, uint32_t u1, uint32_t u2, uint32_t u3) noexcept + : Operand_{{u0._bits}, u1, {u2, u3}} {} + + //! Creates an uninitialized operand (dangerous). + ASMJIT_INLINE_NODEBUG explicit Operand(Globals::NoInit_) noexcept {} + + //! \} + + //! \name Overloaded Operators + //! 
\{ + + ASMJIT_INLINE_NODEBUG Operand& operator=(const Operand& other) noexcept = default; + ASMJIT_INLINE_NODEBUG Operand& operator=(const Operand_& other) noexcept { return operator=(static_cast(other)); } + + //! \} + + //! \name Clone + //! \{ + + //! Clones this operand and returns its copy. + ASMJIT_INLINE_NODEBUG constexpr Operand clone() const noexcept { return Operand(*this); } + + //! \} +}; + +static_assert(sizeof(Operand) == 16, "asmjit::Operand must be exactly 16 bytes long"); + +//! Label (jump target or data location). +//! +//! Label represents a location in code typically used as a jump target, but may be also a reference to some data or +//! a static variable. Label has to be explicitly created by BaseEmitter. +//! +//! Example of using labels: +//! +//! ``` +//! // Create some emitter (for example x86::Assembler). +//! x86::Assembler a; +//! +//! // Create Label instance. +//! Label L1 = a.newLabel(); +//! +//! // ... your code ... +//! +//! // Using label. +//! a.jump(L1); +//! +//! // ... your code ... +//! +//! // Bind label to the current position, see `BaseEmitter::bind()`. +//! a.bind(L1); +//! ``` +class Label : public Operand { +public: + //! \name Construction & Destruction + //! \{ + + //! Creates a label operand without ID (you must set the ID to make it valid). + ASMJIT_INLINE_NODEBUG constexpr Label() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kLabel), Globals::kInvalidId, 0, 0) {} + + //! Creates a cloned label operand of `other`. + ASMJIT_INLINE_NODEBUG constexpr Label(const Label& other) noexcept + : Operand(other) {} + + //! Creates a label operand of the given `id`. + ASMJIT_INLINE_NODEBUG constexpr explicit Label(uint32_t id) noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kLabel), id, 0, 0) {} + + ASMJIT_INLINE_NODEBUG explicit Label(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! 
Resets the label, will reset all properties and set its ID to `Globals::kInvalidId`. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _signature = Signature::fromOpType(OperandType::kLabel); + _baseId = Globals::kInvalidId; + _data[0] = 0; + _data[1] = 0; + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Label& operator=(const Label& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the label was created by CodeHolder and/or an attached emitter. + ASMJIT_INLINE_NODEBUG constexpr bool isValid() const noexcept { return _baseId != Globals::kInvalidId; } + //! Sets the label `id`. + ASMJIT_INLINE_NODEBUG void setId(uint32_t id) noexcept { _baseId = id; } + + //! \} +}; + +//! \cond INTERNAL +//! Default register traits. +struct BaseRegTraits { + enum : uint32_t { + //! \ref TypeId representing this register type, could be \ref TypeId::kVoid if such type doesn't exist. + kTypeId = uint32_t(TypeId::kVoid), + //! RegType is not valid by default. + kValid = 0, + + //! Zero type by default (defaults to None). + kType = uint32_t(RegType::kNone), + //! Zero group by default (defaults to GP). + kGroup = uint32_t(RegGroup::kGp), + //! No size by default. + kSize = 0, + + //! Empty signature by default (not even having operand type set to register). + kSignature = 0 + }; +}; +//! \endcond + +//! Physical or virtual register operand (base). +class BaseReg : public Operand { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + //! None or any register (mostly internal). + kIdBad = 0xFFu, + + kBaseSignatureMask = + Signature::kOpTypeMask | + Signature::kRegTypeMask | + Signature::kRegGroupMask | + Signature::kSizeMask, + + kTypeNone = uint32_t(RegType::kNone), + kSignature = Signature::fromOpType(OperandType::kReg).bits() + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a dummy register operand. 
+ ASMJIT_INLINE_NODEBUG constexpr BaseReg() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kReg), kIdBad, 0, 0) {} + + //! Creates a new register operand which is the same as `other` . + ASMJIT_INLINE_NODEBUG constexpr BaseReg(const BaseReg& other) noexcept + : Operand(other) {} + + //! Creates a new register operand compatible with `other`, but with a different `id`. + ASMJIT_INLINE_NODEBUG constexpr BaseReg(const BaseReg& other, uint32_t id) noexcept + : Operand(Globals::Init, other._signature, id, 0, 0) {} + + //! Creates a register initialized to the given `signature` and `id`. + ASMJIT_INLINE_NODEBUG constexpr BaseReg(const Signature& signature, uint32_t id) noexcept + : Operand(Globals::Init, signature, id, 0, 0) {} + + ASMJIT_INLINE_NODEBUG explicit BaseReg(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG BaseReg& operator=(const BaseReg& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns base signature of the register associated with each register type. + //! + //! Base signature only contains the operand type, register type, register group, and register size. It doesn't + //! contain element type, predicate, or other architecture-specific data. Base signature is a signature that is + //! provided by architecture-specific `RegTraits`, like \ref x86::RegTraits. + ASMJIT_INLINE_NODEBUG constexpr OperandSignature baseSignature() const noexcept { return _signature & kBaseSignatureMask; } + + //! Tests whether the operand's base signature matches the given signature `sign`. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseSignature(uint32_t signature) const noexcept { return baseSignature() == signature; } + //! Tests whether the operand's base signature matches the given signature `sign`. 
+ ASMJIT_INLINE_NODEBUG constexpr bool hasBaseSignature(const OperandSignature& signature) const noexcept { return baseSignature() == signature; } + //! Tests whether the operand's base signature matches the base signature of the `other` operand. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseSignature(const BaseReg& other) const noexcept { return baseSignature() == other.baseSignature(); } + + //! Tests whether this register is the same as `other`. + //! + //! This is just an optimization. Registers by default only use the first 8 bytes of Operand data, so this method + //! takes advantage of this knowledge and only compares these 8 bytes. If both operands were created correctly + //! both \ref equals() and \ref isSame() should give the same answer, however, if any of these two contains garbage + //! or other metadata in the upper 8 bytes then \ref isSame() may return `true` in cases in which \ref equals() + //! returns false. + ASMJIT_INLINE_NODEBUG constexpr bool isSame(const BaseReg& other) const noexcept { + return (_signature == other._signature) & (_baseId == other._baseId); + } + + //! Tests whether the register is valid (either virtual or physical). + ASMJIT_INLINE_NODEBUG constexpr bool isValid() const noexcept { return bool(unsigned(_signature != 0) & unsigned(_baseId != kIdBad)); } + + //! Tests whether this is a physical register. + ASMJIT_INLINE_NODEBUG constexpr bool isPhysReg() const noexcept { return _baseId < kIdBad; } + //! Tests whether this is a virtual register. + ASMJIT_INLINE_NODEBUG constexpr bool isVirtReg() const noexcept { return _baseId > kIdBad; } + + //! Tests whether the register type matches `type` - same as `isReg(type)`, provided for convenience. + ASMJIT_INLINE_NODEBUG constexpr bool isType(RegType type) const noexcept { return _signature.subset(Signature::kRegTypeMask) == Signature::fromRegType(type); } + //! Tests whether the register group matches `group`. 
+ ASMJIT_INLINE_NODEBUG constexpr bool isGroup(RegGroup group) const noexcept { return _signature.subset(Signature::kRegGroupMask) == Signature::fromRegGroup(group); } + + //! Tests whether the register is a general purpose register (any size). + ASMJIT_INLINE_NODEBUG constexpr bool isGp() const noexcept { return isGroup(RegGroup::kGp); } + //! Tests whether the register is a vector register of any size. + ASMJIT_INLINE_NODEBUG constexpr bool isVec() const noexcept { return isGroup(RegGroup::kVec); } + //! Tests whether the register is a mask register of any size. + ASMJIT_INLINE_NODEBUG constexpr bool isMask() const noexcept { return isGroup(RegGroup::kMask); } + + using Operand_::isReg; + + //! Same as `isType()`, provided for convenience. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegType rType) const noexcept { return isType(rType); } + //! Tests whether the register type matches `type` and register id matches `id`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegType rType, uint32_t id) const noexcept { return isType(rType) && this->id() == id; } + + //! Returns the register type. + ASMJIT_INLINE_NODEBUG constexpr RegType type() const noexcept { return _signature.regType(); } + //! Returns the register group. + ASMJIT_INLINE_NODEBUG constexpr RegGroup group() const noexcept { return _signature.regGroup(); } + + //! Tests whether the register specifies a size (i.e. the size is not zero). + ASMJIT_INLINE_NODEBUG constexpr bool hasSize() const noexcept { return _signature.hasField(); } + //! Tests whether the register size matches size `s`. + ASMJIT_INLINE_NODEBUG constexpr bool hasSize(uint32_t s) const noexcept { return size() == s; } + + //! Returns the size of the register in bytes. If the register size depends on architecture (like `x86::CReg` and + //! `x86::DReg`) the size returned should be the greatest possible (so it should return 64-bit size in such case). 
+ ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return _signature.getField(); } + + //! Returns operation predicate of the register (ARM/AArch64). + //! + //! The meaning depends on architecture, for example on ARM hardware this describes \ref arm::ShiftOp + //! of the register. + ASMJIT_INLINE_NODEBUG constexpr uint32_t predicate() const noexcept { return _signature.getField(); } + + //! Sets operation predicate of the register to `predicate` (ARM/AArch64). + //! + //! The meaning depends on architecture, for example on ARM hardware this describes \ref arm::ShiftOp + //! of the register. + ASMJIT_INLINE_NODEBUG void setPredicate(uint32_t predicate) noexcept { _signature.setField(predicate); } + + //! Resets shift operation type of the register to the default value (ARM/AArch64). + ASMJIT_INLINE_NODEBUG void resetPredicate() noexcept { _signature.setField(0); } + + //! Clones the register operand. + ASMJIT_INLINE_NODEBUG constexpr BaseReg clone() const noexcept { return BaseReg(*this); } + + //! Casts this register to `RegT` by also changing its signature. + //! + //! \note Improper use of `cloneAs()` can lead to hard-to-debug errors. + template + ASMJIT_INLINE_NODEBUG constexpr RegT cloneAs() const noexcept { return RegT(Signature(RegT::kSignature), id()); } + + //! Casts this register to `other` by also changing its signature. + //! + //! \note Improper use of `cloneAs()` can lead to hard-to-debug errors. + template + ASMJIT_INLINE_NODEBUG constexpr RegT cloneAs(const RegT& other) const noexcept { return RegT(other.signature(), id()); } + + //! Sets the register id to `id`. + ASMJIT_INLINE_NODEBUG void setId(uint32_t id) noexcept { _baseId = id; } + + //! Sets a 32-bit operand signature based on traits of `RegT`. + template + ASMJIT_INLINE_NODEBUG void setSignatureT() noexcept { _signature = RegT::kSignature; } + + //! Sets the register `signature` and `id`. 
+ ASMJIT_INLINE_NODEBUG void setSignatureAndId(const OperandSignature& signature, uint32_t id) noexcept { + _signature = signature; + _baseId = id; + } + + //! \} + + //! \name Static Functions + //! \{ + + //! Tests whether the `op` operand is a general purpose register. + static ASMJIT_INLINE_NODEBUG bool isGp(const Operand_& op) noexcept { + // Check operand type and register group. Not interested in register type and size. + return op.signature().subset(Signature::kOpTypeMask | Signature::kRegGroupMask) == (Signature::fromOpType(OperandType::kReg) | Signature::fromRegGroup(RegGroup::kGp)); + } + + //! Tests whether the `op` operand is a vector register. + static ASMJIT_INLINE_NODEBUG bool isVec(const Operand_& op) noexcept { + // Check operand type and register group. Not interested in register type and size. + return op.signature().subset(Signature::kOpTypeMask | Signature::kRegGroupMask) == (Signature::fromOpType(OperandType::kReg) | Signature::fromRegGroup(RegGroup::kVec)); + } + + //! Tests whether the `op` is a general purpose register of the given `id`. + static ASMJIT_INLINE_NODEBUG bool isGp(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isGp(op)) & unsigned(op.id() == id)); } + //! Tests whether the `op` is a vector register of the given `id`. + static ASMJIT_INLINE_NODEBUG bool isVec(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVec(op)) & unsigned(op.id() == id)); } + + //! \} +}; + +//! RegOnly is 8-byte version of `BaseReg` that allows to store either register or nothing. +//! +//! It's designed to decrease the space consumed by an extra "operand" in \ref BaseEmitter and \ref InstNode. +struct RegOnly { + //! \name Types + //! \{ + + typedef OperandSignature Signature; + + //! \} + + //! Operand signature - only \ref OperandType::kNone and \ref OperandType::kReg are supported. + Signature _signature; + //! Physical or virtual register id. + uint32_t _id; + + //! \name Construction & Destruction + //! 
\{ + + //! Initializes the `RegOnly` instance to hold register `signature` and `id`. + ASMJIT_INLINE_NODEBUG void init(const OperandSignature& signature, uint32_t id) noexcept { + _signature = signature; + _id = id; + } + + ASMJIT_INLINE_NODEBUG void init(const BaseReg& reg) noexcept { init(reg.signature(), reg.id()); } + ASMJIT_INLINE_NODEBUG void init(const RegOnly& reg) noexcept { init(reg.signature(), reg.id()); } + + //! Resets the `RegOnly` members to zeros (none). + ASMJIT_INLINE_NODEBUG void reset() noexcept { init(Signature::fromBits(0), 0); } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether this ExtraReg is none (same as calling `Operand_::isNone()`). + ASMJIT_INLINE_NODEBUG constexpr bool isNone() const noexcept { return _signature == 0; } + //! Tests whether the register is valid (either virtual or physical). + ASMJIT_INLINE_NODEBUG constexpr bool isReg() const noexcept { return _signature != 0; } + + //! Tests whether this is a physical register. + ASMJIT_INLINE_NODEBUG constexpr bool isPhysReg() const noexcept { return _id < BaseReg::kIdBad; } + //! Tests whether this is a virtual register (used by `BaseCompiler`). + ASMJIT_INLINE_NODEBUG constexpr bool isVirtReg() const noexcept { return _id > BaseReg::kIdBad; } + + //! Returns the register signature or 0 if no register is assigned. + ASMJIT_INLINE_NODEBUG constexpr OperandSignature signature() const noexcept { return _signature; } + //! Returns the register id. + //! + //! \note Always check whether the register is assigned before using the returned identifier as + //! non-assigned `RegOnly` instance would return zero id, which is still a valid register id. + ASMJIT_INLINE_NODEBUG constexpr uint32_t id() const noexcept { return _id; } + + //! Sets the register id. + ASMJIT_INLINE_NODEBUG void setId(uint32_t id) noexcept { _id = id; } + + //! Returns the register type. + ASMJIT_INLINE_NODEBUG constexpr RegType type() const noexcept { return _signature.regType(); } + //! 
Returns the register group. + ASMJIT_INLINE_NODEBUG constexpr RegGroup group() const noexcept { return _signature.regGroup(); } + + //! \} + + //! \name Utilities + //! \{ + + //! Converts this ExtraReg to a real `RegT` operand. + template + ASMJIT_INLINE_NODEBUG constexpr RegT toReg() const noexcept { return RegT(_signature, _id); } + + //! \} +}; + +//! \cond INTERNAL +//! Adds a template specialization for `REG_TYPE` into the local `RegTraits`. +#define ASMJIT_DEFINE_REG_TRAITS(REG_TYPE, GROUP, SIZE, TYPE_ID) \ +template<> \ +struct RegTraits { \ + static constexpr uint32_t kValid = 1; \ + static constexpr RegType kType = REG_TYPE; \ + static constexpr RegGroup kGroup = GROUP; \ + static constexpr uint32_t kSize = SIZE; \ + static constexpr TypeId kTypeId = TYPE_ID; \ + \ + static constexpr uint32_t kSignature = \ + (OperandSignature::fromOpType(OperandType::kReg) | \ + OperandSignature::fromRegType(kType) | \ + OperandSignature::fromRegGroup(kGroup) | \ + OperandSignature::fromSize(kSize)).bits(); \ + \ +} + +//! Adds constructors and member functions to a class that implements abstract register. Abstract register is register +//! that doesn't have type or signature yet, it's a base class like `x86::Reg` or `arm::Reg`. +#define ASMJIT_DEFINE_ABSTRACT_REG(REG, BASE) \ +public: \ + /*! Default constructor that only setups basics. */ \ + ASMJIT_INLINE_NODEBUG constexpr REG() noexcept \ + : BASE(Signature{kSignature}, kIdBad) {} \ + \ + /*! Makes a copy of the `other` register operand. */ \ + ASMJIT_INLINE_NODEBUG constexpr REG(const REG& other) noexcept \ + : BASE(other) {} \ + \ + /*! Makes a copy of the `other` register having id set to `id` */ \ + ASMJIT_INLINE_NODEBUG constexpr REG(const BaseReg& other, uint32_t id) noexcept \ + : BASE(other, id) {} \ + \ + /*! Creates a register based on `signature` and `id`. */ \ + ASMJIT_INLINE_NODEBUG constexpr REG(const OperandSignature& sgn, uint32_t id) noexcept \ + : BASE(sgn, id) {} \ + \ + /*! 
Creates a completely uninitialized REG register operand (garbage). */ \ + ASMJIT_INLINE_NODEBUG explicit REG(Globals::NoInit_) noexcept \ + : BASE(Globals::NoInit) {} \ + \ + /*! Creates a new register from register type and id. */ \ + static ASMJIT_INLINE_NODEBUG REG fromTypeAndId(RegType type, uint32_t id) noexcept { \ + return REG(signatureOf(type), id); \ + } \ + \ + /*! Clones the register operand. */ \ + ASMJIT_INLINE_NODEBUG constexpr REG clone() const noexcept { return REG(*this); } \ + \ + ASMJIT_INLINE_NODEBUG REG& operator=(const REG& other) noexcept = default; + +//! Adds constructors and member functions to a class that implements final register. Final registers MUST HAVE a valid +//! signature. +#define ASMJIT_DEFINE_FINAL_REG(REG, BASE, TRAITS) \ +public: \ + static constexpr RegType kThisType = TRAITS::kType; \ + static constexpr RegGroup kThisGroup = TRAITS::kGroup; \ + static constexpr uint32_t kThisSize = TRAITS::kSize; \ + static constexpr uint32_t kSignature = TRAITS::kSignature; \ + \ + ASMJIT_DEFINE_ABSTRACT_REG(REG, BASE) \ + \ + /*! Creates a register operand having its id set to `id`. */ \ + ASMJIT_INLINE_NODEBUG constexpr explicit REG(uint32_t id) noexcept \ + : BASE(Signature{kSignature}, id) {} +//! \endcond + +//! List of physical registers (base). +//! +//! \note List of registers is only used by some ARM instructions at the moment. +class BaseRegList : public Operand { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + kSignature = Signature::fromOpType(OperandType::kRegList).bits() + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a dummy register operand. + ASMJIT_INLINE_NODEBUG constexpr BaseRegList() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kRegList), 0, 0, 0) {} + + //! Creates a new register operand which is the same as `other` . + ASMJIT_INLINE_NODEBUG constexpr BaseRegList(const BaseRegList& other) noexcept + : Operand(other) {} + + //! 
Creates a new register operand compatible with `other`, but with a different `id`. + ASMJIT_INLINE_NODEBUG constexpr BaseRegList(const BaseRegList& other, RegMask regMask) noexcept + : Operand(Globals::Init, other._signature, regMask, 0, 0) {} + + //! Creates a register initialized to the given `signature` and `id`. + ASMJIT_INLINE_NODEBUG constexpr BaseRegList(const Signature& signature, RegMask regMask) noexcept + : Operand(Globals::Init, signature, regMask, 0, 0) {} + + ASMJIT_INLINE_NODEBUG explicit BaseRegList(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG BaseRegList& operator=(const BaseRegList& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the register-list is valid, which means it has a type and at least a single register in the list. + ASMJIT_INLINE_NODEBUG constexpr bool isValid() const noexcept { return bool(unsigned(_signature != 0u) & unsigned(_baseId != 0u)); } + + //! Tests whether the register type matches `type` - same as `isReg(type)`, provided for convenience. + ASMJIT_INLINE_NODEBUG constexpr bool isType(RegType type) const noexcept { return _signature.subset(Signature::kRegTypeMask) == Signature::fromRegType(type); } + //! Tests whether the register group matches `group`. + ASMJIT_INLINE_NODEBUG constexpr bool isGroup(RegGroup group) const noexcept { return _signature.subset(Signature::kRegGroupMask) == Signature::fromRegGroup(group); } + + //! Tests whether the register is a general purpose register (any size). + ASMJIT_INLINE_NODEBUG constexpr bool isGp() const noexcept { return isGroup(RegGroup::kGp); } + //! Tests whether the register is a vector register. + ASMJIT_INLINE_NODEBUG constexpr bool isVec() const noexcept { return isGroup(RegGroup::kVec); } + + //! Returns the register type. + ASMJIT_INLINE_NODEBUG constexpr RegType type() const noexcept { return _signature.regType(); } + //! 
Returns the register group. + ASMJIT_INLINE_NODEBUG constexpr RegGroup group() const noexcept { return _signature.regGroup(); } + //! Returns the size of a single register in this register-list or 0 if unspecified. + ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return _signature.getField(); } + + //! Returns the register list as a mask, where each bit represents one physical register. + ASMJIT_INLINE_NODEBUG constexpr RegMask list() const noexcept { return _baseId; } + //! Sets the register list to `mask`. + ASMJIT_INLINE_NODEBUG void setList(RegMask mask) noexcept { _baseId = mask; } + //! Removes all registers from the register-list by making the underlying register-mask zero. + ASMJIT_INLINE_NODEBUG void resetList() noexcept { _baseId = 0; } + + //! Adds registers passed by a register `mask` to the register-list. + ASMJIT_INLINE_NODEBUG void addList(RegMask mask) noexcept { _baseId |= mask; } + //! Removes registers passed by a register `mask` from the register-list. + ASMJIT_INLINE_NODEBUG void clearList(RegMask mask) noexcept { _baseId &= ~mask; } + //! Uses AND operator to combine the current register-list with other register `mask`. + ASMJIT_INLINE_NODEBUG void andList(RegMask mask) noexcept { _baseId &= mask; } + //! Uses XOR operator to combine the current register-list with other register `mask`. + ASMJIT_INLINE_NODEBUG void xorList(RegMask mask) noexcept { _baseId ^= mask; } + + //! Checks whether a physical register `physId` is in the register-list. + ASMJIT_INLINE_NODEBUG bool hasReg(uint32_t physId) const noexcept { return physId < 32u ? (_baseId & (1u << physId)) != 0 : false; } + //! Adds a physical register `physId` to the register-list. + ASMJIT_INLINE_NODEBUG void addReg(uint32_t physId) noexcept { addList(1u << physId); } + //! Removes a physical register `physId` from the register-list. + ASMJIT_INLINE_NODEBUG void clearReg(uint32_t physId) noexcept { clearList(1u << physId); } + + //! Clones the register-list operand. 
+ ASMJIT_INLINE_NODEBUG constexpr BaseRegList clone() const noexcept { return BaseRegList(*this); } + + //! Casts this register to `RegT` by also changing its signature. + //! + //! \note Improper use of `cloneAs()` can lead to hard-to-debug errors. + template + ASMJIT_INLINE_NODEBUG constexpr RegListT cloneAs() const noexcept { return RegListT(Signature(RegListT::kSignature), list()); } + + //! Casts this register to `other` by also changing its signature. + //! + //! \note Improper use of `cloneAs()` can lead to hard-to-debug errors. + template + ASMJIT_INLINE_NODEBUG constexpr RegListT cloneAs(const RegListT& other) const noexcept { return RegListT(other.signature(), list()); } + + //! \} +}; + +template +class RegListT : public BaseRegList { +public: + //! \name Construction & Destruction + //! \{ + + //! Creates a dummy register operand. + ASMJIT_INLINE_NODEBUG constexpr RegListT() noexcept + : BaseRegList() {} + + //! Creates a new register operand which is the same as `other` . + ASMJIT_INLINE_NODEBUG constexpr RegListT(const RegListT& other) noexcept + : BaseRegList(other) {} + + //! Creates a new register operand compatible with `other`, but with a different `id`. + ASMJIT_INLINE_NODEBUG constexpr RegListT(const RegListT& other, RegMask regMask) noexcept + : BaseRegList(other, regMask) {} + + //! Creates a register initialized to the given `signature` and `id`. + ASMJIT_INLINE_NODEBUG constexpr RegListT(const Signature& signature, RegMask regMask) noexcept + : BaseRegList(signature, regMask) {} + + //! Creates a register initialized to the given `signature` and `regs`. + ASMJIT_INLINE_NODEBUG RegListT(const Signature& signature, std::initializer_list regs) noexcept + : BaseRegList(signature, RegMask(0)) { addRegs(regs); } + + ASMJIT_INLINE_NODEBUG explicit RegListT(Globals::NoInit_) noexcept + : BaseRegList(Globals::NoInit) {} + + //! \} + + //! \name Overloaded Operators + //! 
\{ + + ASMJIT_INLINE_NODEBUG RegListT& operator=(const RegListT& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + using BaseRegList::addList; + using BaseRegList::clearList; + using BaseRegList::andList; + using BaseRegList::xorList; + + //! Adds registers to this register-list as provided by `other` register-list. + ASMJIT_INLINE_NODEBUG void addList(const RegListT& other) noexcept { addList(other.list()); } + //! Removes registers contained in `other` register-list. + ASMJIT_INLINE_NODEBUG void clearList(const RegListT& other) noexcept { clearList(other.list()); } + //! Uses AND operator to combine the current register-list with `other` register-list. + ASMJIT_INLINE_NODEBUG void andList(const RegListT& other) noexcept { andList(other.list()); } + //! Uses XOR operator to combine the current register-list with `other` register-list. + ASMJIT_INLINE_NODEBUG void xorList(const RegListT& other) noexcept { xorList(other.list()); } + + using BaseRegList::addReg; + using BaseRegList::clearReg; + + ASMJIT_INLINE_NODEBUG void addReg(const RegT& reg) noexcept { + if (reg.id() < 32u) + addReg(reg.id()); + } + + ASMJIT_INLINE_NODEBUG void addRegs(std::initializer_list regs) noexcept { + for (const RegT& reg : regs) + addReg(reg); + } + + ASMJIT_INLINE_NODEBUG void clearReg(const RegT& reg) noexcept { + if (reg.id() < 32u) + clearReg(reg.id()); + } + + ASMJIT_INLINE_NODEBUG void clearRegs(std::initializer_list regs) noexcept { + for (const RegT& reg : regs) + clearReg(reg); + } + + //! \} +}; + +//! Base class for all memory operands. +//! +//! The data is split into the following parts: +//! +//! - BASE - Base register or label - requires 36 bits total. 4 bits are used to encode the type of the BASE operand +//! (label vs. register type) and the remaining 32 bits define the BASE id, which can be a physical or virtual +//! register index. If BASE type is zero, which is never used as a register type and label doesn't use it as well +//! 
then BASE field contains a high DWORD of a possible 64-bit absolute address, which is possible on X64. +//! +//! - INDEX - Index register (or theoretically Label, which doesn't make sense). Encoding is similar to BASE - it +//! also requires 36 bits and splits the encoding to INDEX type (4 bits defining the register type) and 32-bit id. +//! +//! - OFFSET - A relative offset of the address. Basically if BASE is specified the relative displacement adjusts +//! BASE and an optional INDEX. if BASE is not specified then the OFFSET should be considered as ABSOLUTE address +//! (at least on X86). In that case its low 32 bits are stored in DISPLACEMENT field and the remaining high 32 +//! bits are stored in BASE. +//! +//! - OTHER - There is rest 8 bits that can be used for whatever purpose. For example \ref x86::Mem operand uses +//! these bits to store segment override prefix and index shift (or scale). +class BaseMem : public Operand { +public: + //! \name Construction & Destruction + //! \{ + + //! Creates a default `BaseMem` operand, that points to [0]. + ASMJIT_INLINE_NODEBUG constexpr BaseMem() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kMem), 0, 0, 0) {} + + //! Creates a `BaseMem` operand that is a clone of `other`. + ASMJIT_INLINE_NODEBUG constexpr BaseMem(const BaseMem& other) noexcept + : Operand(other) {} + + //! Creates a `BaseMem` operand from `baseReg` and `offset`. + //! + //! \note This is an architecture independent constructor that can be used to create an architecture + //! independent memory operand to be used in portable code that can handle multiple architectures. + ASMJIT_INLINE_NODEBUG constexpr explicit BaseMem(const BaseReg& baseReg, int32_t offset = 0) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kMem) | Signature::fromMemBaseType(baseReg.type()), + baseReg.id(), + 0, + uint32_t(offset)) {} + + //! \cond INTERNAL + //! 
Creates a `BaseMem` operand from 4 integers as used by `Operand_` struct. + ASMJIT_INLINE_NODEBUG constexpr BaseMem(const OperandSignature& u0, uint32_t baseId, uint32_t indexId, int32_t offset) noexcept + : Operand(Globals::Init, u0, baseId, indexId, uint32_t(offset)) {} + //! \endcond + + //! Creates a completely uninitialized `BaseMem` operand. + ASMJIT_INLINE_NODEBUG explicit BaseMem(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! Resets the memory operand - after the reset the memory points to [0]. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _signature = Signature::fromOpType(OperandType::kMem); + _baseId = 0; + _data[0] = 0; + _data[1] = 0; + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG BaseMem& operator=(const BaseMem& other) noexcept { copyFrom(other); return *this; } + + //! \} + + //! \name Accessors + //! \{ + + //! Clones the memory operand. + ASMJIT_INLINE_NODEBUG constexpr BaseMem clone() const noexcept { return BaseMem(*this); } + + //! Creates a new copy of this memory operand adjusted by `off`. + ASMJIT_INLINE_NODEBUG BaseMem cloneAdjusted(int64_t off) const noexcept { + BaseMem result(*this); + result.addOffset(off); + return result; + } + + //! Tests whether this memory operand is a register home (only used by \ref asmjit_compiler) + ASMJIT_INLINE_NODEBUG constexpr bool isRegHome() const noexcept { return _signature.hasField(); } + //! Mark this memory operand as register home (only used by \ref asmjit_compiler). + ASMJIT_INLINE_NODEBUG void setRegHome() noexcept { _signature |= Signature::kMemRegHomeFlag; } + //! Marks this operand to not be a register home (only used by \ref asmjit_compiler). + ASMJIT_INLINE_NODEBUG void clearRegHome() noexcept { _signature &= ~Signature::kMemRegHomeFlag; } + + //! Tests whether the memory operand has a BASE register or label specified. 
+ ASMJIT_INLINE_NODEBUG constexpr bool hasBase() const noexcept { + return (_signature & Signature::kMemBaseTypeMask) != 0; + } + + //! Tests whether the memory operand has an INDEX register specified. + ASMJIT_INLINE_NODEBUG constexpr bool hasIndex() const noexcept { + return (_signature & Signature::kMemIndexTypeMask) != 0; + } + + //! Tests whether the memory operand has BASE or INDEX register. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseOrIndex() const noexcept { + return (_signature & Signature::kMemBaseIndexMask) != 0; + } + + //! Tests whether the memory operand has BASE and INDEX register. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseAndIndex() const noexcept { + return (_signature & Signature::kMemBaseTypeMask) != 0 && (_signature & Signature::kMemIndexTypeMask) != 0; + } + + //! Tests whether the BASE operand is a label. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseLabel() const noexcept { + return _signature.subset(Signature::kMemBaseTypeMask) == Signature::fromMemBaseType(RegType::kLabelTag); + } + + //! Tests whether the BASE operand is a register (registers start after `RegType::kLabelTag`). + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseReg() const noexcept { + return _signature.subset(Signature::kMemBaseTypeMask).bits() > Signature::fromMemBaseType(RegType::kLabelTag).bits(); + } + + //! Tests whether the INDEX operand is a register (registers start after `RegType::kLabelTag`). + ASMJIT_INLINE_NODEBUG constexpr bool hasIndexReg() const noexcept { + return _signature.subset(Signature::kMemIndexTypeMask).bits() > Signature::fromMemIndexType(RegType::kLabelTag).bits(); + } + + //! Returns the type of the BASE register (0 if this memory operand doesn't use the BASE register). + //! + //! \note If the returned type is one (a value never associated to a register type) the BASE is not register, but it + //! is a label. One equals to `kLabelTag`. You should always check `hasBaseLabel()` before using `baseId()` result. 
+ ASMJIT_INLINE_NODEBUG constexpr RegType baseType() const noexcept { return _signature.memBaseType(); } + + //! Returns the type of an INDEX register (0 if this memory operand doesn't + //! use the INDEX register). + ASMJIT_INLINE_NODEBUG constexpr RegType indexType() const noexcept { return _signature.memIndexType(); } + + //! This is used internally for BASE+INDEX validation. + ASMJIT_INLINE_NODEBUG constexpr uint32_t baseAndIndexTypes() const noexcept { return _signature.getField(); } + + //! Returns both BASE (4:0 bits) and INDEX (9:5 bits) types combined into a single value. + //! + //! \remarks Returns id of the BASE register or label (if the BASE was specified as label). + ASMJIT_INLINE_NODEBUG constexpr uint32_t baseId() const noexcept { return _baseId; } + + //! Returns the id of the INDEX register. + ASMJIT_INLINE_NODEBUG constexpr uint32_t indexId() const noexcept { return _data[kDataMemIndexId]; } + + //! Sets the id of the BASE register (without modifying its type). + ASMJIT_INLINE_NODEBUG void setBaseId(uint32_t id) noexcept { _baseId = id; } + //! Sets the register type of the BASE register (without modifying its id). + ASMJIT_INLINE_NODEBUG void setBaseType(RegType regType) noexcept { _signature.setMemBaseType(regType); } + + //! Sets the id of the INDEX register (without modifying its type). + ASMJIT_INLINE_NODEBUG void setIndexId(uint32_t id) noexcept { _data[kDataMemIndexId] = id; } + //! Sets the register type of the INDEX register (without modifying its id). + ASMJIT_INLINE_NODEBUG void setIndexType(RegType regType) noexcept { _signature.setMemIndexType(regType); } + + //! Sets the base register to type and id of the given `base` operand. + ASMJIT_INLINE_NODEBUG void setBase(const BaseReg& base) noexcept { return _setBase(base.type(), base.id()); } + //! Sets the index register to type and id of the given `index` operand. 
+ ASMJIT_INLINE_NODEBUG void setIndex(const BaseReg& index) noexcept { return _setIndex(index.type(), index.id()); } + + //! \cond INTERNAL + ASMJIT_INLINE_NODEBUG void _setBase(RegType type, uint32_t id) noexcept { + _signature.setField(uint32_t(type)); + _baseId = id; + } + + ASMJIT_INLINE_NODEBUG void _setIndex(RegType type, uint32_t id) noexcept { + _signature.setField(uint32_t(type)); + _data[kDataMemIndexId] = id; + } + //! \endcond + + //! Resets the memory operand's BASE register or label. + ASMJIT_INLINE_NODEBUG void resetBase() noexcept { _setBase(RegType::kNone, 0); } + //! Resets the memory operand's INDEX register. + ASMJIT_INLINE_NODEBUG void resetIndex() noexcept { _setIndex(RegType::kNone, 0); } + + //! Sets the memory operand size (in bytes). + ASMJIT_INLINE_NODEBUG void setSize(uint32_t size) noexcept { _signature.setField(size); } + + //! Tests whether the memory operand has a 64-bit offset or absolute address. + //! + //! If this is true then `hasBase()` must always report false. + ASMJIT_INLINE_NODEBUG constexpr bool isOffset64Bit() const noexcept { return baseType() == RegType::kNone; } + + //! Tests whether the memory operand has a non-zero offset or absolute address. + ASMJIT_INLINE_NODEBUG constexpr bool hasOffset() const noexcept { + return (_data[kDataMemOffsetLo] | uint32_t(_baseId & Support::bitMaskFromBool(isOffset64Bit()))) != 0; + } + + //! Returns either relative offset or absolute address as 64-bit integer. + ASMJIT_INLINE_NODEBUG constexpr int64_t offset() const noexcept { + return isOffset64Bit() ? int64_t(uint64_t(_data[kDataMemOffsetLo]) | (uint64_t(_baseId) << 32)) + : int64_t(int32_t(_data[kDataMemOffsetLo])); // Sign extend 32-bit offset. + } + + //! Returns a 32-bit low part of a 64-bit offset or absolute address. + ASMJIT_INLINE_NODEBUG constexpr int32_t offsetLo32() const noexcept { return int32_t(_data[kDataMemOffsetLo]); } + //! Returns a 32-bit high part of a 64-bit offset or absolute address. + //! + //! 
\note This function is UNSAFE and returns garbage if `isOffset64Bit()` + //! returns false. Never use it blindly without checking it first. + ASMJIT_INLINE_NODEBUG constexpr int32_t offsetHi32() const noexcept { return int32_t(_baseId); } + + //! Sets a 64-bit offset or an absolute address to `offset`. + //! + //! \note This function attempts to set both high and low parts of a 64-bit offset, however, if the operand has + //! a BASE register it will store only the low 32 bits of the offset / address as there is no way to store both + //! BASE and 64-bit offset, and there is currently no architecture that has such capability targeted by AsmJit. + inline void setOffset(int64_t offset) noexcept { + uint32_t lo = uint32_t(uint64_t(offset) & 0xFFFFFFFFu); + uint32_t hi = uint32_t(uint64_t(offset) >> 32); + uint32_t hiMsk = Support::bitMaskFromBool(isOffset64Bit()); + + _data[kDataMemOffsetLo] = lo; + _baseId = (hi & hiMsk) | (_baseId & ~hiMsk); + } + //! Sets a low 32-bit offset to `offset` (don't use without knowing how BaseMem works). + inline void setOffsetLo32(int32_t offset) noexcept { _data[kDataMemOffsetLo] = uint32_t(offset); } + + //! Adjusts the offset by `offset`. + //! + //! \note This is a fast function that doesn't use the HI 32-bits of a 64-bit offset. Use it only if you know that + //! there is a BASE register and the offset is only 32 bits anyway. + + //! Adjusts the memory operand offset by `offset`. + inline void addOffset(int64_t offset) noexcept { + if (isOffset64Bit()) { + int64_t result = offset + int64_t(uint64_t(_data[kDataMemOffsetLo]) | (uint64_t(_baseId) << 32)); + _data[kDataMemOffsetLo] = uint32_t(uint64_t(result) & 0xFFFFFFFFu); + _baseId = uint32_t(uint64_t(result) >> 32); + } + else { + _data[kDataMemOffsetLo] += uint32_t(uint64_t(offset) & 0xFFFFFFFFu); + } + } + + //! Adds `offset` to a low 32-bit offset part (don't use without knowing how BaseMem works). 
+ ASMJIT_INLINE_NODEBUG void addOffsetLo32(int32_t offset) noexcept { _data[kDataMemOffsetLo] += uint32_t(offset); } + + //! Resets the memory offset to zero. + ASMJIT_INLINE_NODEBUG void resetOffset() noexcept { setOffset(0); } + + //! Resets the lo part of the memory offset to zero (don't use without knowing how BaseMem works). + ASMJIT_INLINE_NODEBUG void resetOffsetLo32() noexcept { setOffsetLo32(0); } + + //! \} +}; + +//! Type of the an immediate value. +enum class ImmType : uint32_t { + //! Immediate is integer. + kInt = 0, + //! Immediate is a floating point stored as double-precision. + kDouble = 1 +}; + +//! Immediate operands are encoded with instruction data. +class Imm : public Operand { +public: + //! \cond INTERNAL + template + struct IsConstexprConstructibleAsImmType + : public std::integral_constant::value || + std::is_pointer::value || + std::is_integral::value || + std::is_function::value> {}; + + template + struct IsConvertibleToImmType + : public std::integral_constant::value || + std::is_floating_point::value> {}; + //! \endcond + + //! \name Construction & Destruction + //! \{ + + //! Creates a new immediate value (initial value is 0). + ASMJIT_INLINE_NODEBUG constexpr Imm() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kImm), 0, 0, 0) {} + + //! Creates a new immediate value from `other`. + ASMJIT_INLINE_NODEBUG constexpr Imm(const Imm& other) noexcept + : Operand(other) {} + + //! Creates a new immediate value from ARM/AArch64 specific `shift`. + ASMJIT_INLINE_NODEBUG constexpr Imm(const arm::Shift& shift) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kImm) | Signature::fromPredicate(uint32_t(shift.op())), + 0, + Support::unpackU32At0(shift.value()), + Support::unpackU32At1(shift.value())) {} + + //! Creates a new signed immediate value, assigning the value to `val` and an architecture-specific predicate + //! to `predicate`. + //! + //! 
\note Predicate is currently only used by ARM architectures. + template::type>::value>::type> + ASMJIT_INLINE_NODEBUG constexpr Imm(const T& val, const uint32_t predicate = 0) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kImm) | Signature::fromPredicate(predicate), + 0, + Support::unpackU32At0(int64_t(val)), + Support::unpackU32At1(int64_t(val))) {} + + ASMJIT_INLINE_NODEBUG Imm(const float& val, const uint32_t predicate = 0) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kImm) | Signature::fromPredicate(predicate), + 0, + 0, + 0) { setValue(val); } + + ASMJIT_INLINE_NODEBUG Imm(const double& val, const uint32_t predicate = 0) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kImm) | Signature::fromPredicate(predicate), + 0, + 0, + 0) { setValue(val); } + + ASMJIT_INLINE_NODEBUG explicit Imm(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Assigns the value of the `other` operand to this immediate. + ASMJIT_INLINE_NODEBUG Imm& operator=(const Imm& other) noexcept { copyFrom(other); return *this; } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns immediate type. + ASMJIT_INLINE_NODEBUG constexpr ImmType type() const noexcept { return (ImmType)_signature.getField(); } + //! Sets the immediate type to `type`. + ASMJIT_INLINE_NODEBUG void setType(ImmType type) noexcept { _signature.setField(uint32_t(type)); } + //! Resets immediate type to \ref ImmType::kInt. + ASMJIT_INLINE_NODEBUG void resetType() noexcept { setType(ImmType::kInt); } + + //! Returns operation predicate of the immediate. + //! + //! The meaning depends on architecture, for example on ARM hardware this describes \ref arm::ShiftOp + //! of the immediate. + ASMJIT_INLINE_NODEBUG constexpr uint32_t predicate() const noexcept { return _signature.getField(); } + + //! Sets operation predicate of the immediate to `predicate`. + //! + //! 
The meaning depends on architecture, for example on ARM hardware this describes \ref arm::ShiftOp + //! of the immediate. + ASMJIT_INLINE_NODEBUG void setPredicate(uint32_t predicate) noexcept { _signature.setField(predicate); } + + //! Resets the shift operation type of the immediate to the default value (no operation). + ASMJIT_INLINE_NODEBUG void resetPredicate() noexcept { _signature.setField(0); } + + //! Returns the immediate value as `int64_t`, which is the internal format Imm uses. + ASMJIT_INLINE_NODEBUG constexpr int64_t value() const noexcept { + return int64_t((uint64_t(_data[kDataImmValueHi]) << 32) | _data[kDataImmValueLo]); + } + + //! Tests whether this immediate value is integer of any size. + ASMJIT_INLINE_NODEBUG constexpr uint32_t isInt() const noexcept { return type() == ImmType::kInt; } + //! Tests whether this immediate value is a double precision floating point value. + ASMJIT_INLINE_NODEBUG constexpr uint32_t isDouble() const noexcept { return type() == ImmType::kDouble; } + + //! Tests whether the immediate can be casted to 8-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr bool isInt8() const noexcept { return type() == ImmType::kInt && Support::isInt8(value()); } + //! Tests whether the immediate can be casted to 8-bit unsigned integer. + ASMJIT_INLINE_NODEBUG constexpr bool isUInt8() const noexcept { return type() == ImmType::kInt && Support::isUInt8(value()); } + //! Tests whether the immediate can be casted to 16-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr bool isInt16() const noexcept { return type() == ImmType::kInt && Support::isInt16(value()); } + //! Tests whether the immediate can be casted to 16-bit unsigned integer. + ASMJIT_INLINE_NODEBUG constexpr bool isUInt16() const noexcept { return type() == ImmType::kInt && Support::isUInt16(value()); } + //! Tests whether the immediate can be casted to 32-bit signed integer. 
+ ASMJIT_INLINE_NODEBUG constexpr bool isInt32() const noexcept { return type() == ImmType::kInt && Support::isInt32(value()); } + //! Tests whether the immediate can be casted to 32-bit unsigned integer. + ASMJIT_INLINE_NODEBUG constexpr bool isUInt32() const noexcept { return type() == ImmType::kInt && _data[kDataImmValueHi] == 0; } + + //! Returns the immediate value casted to `T`. + //! + //! The value is masked before it's casted to `T` so the returned value is simply the representation of `T` + //! considering the original value's lowest bits. + template + ASMJIT_INLINE_NODEBUG T valueAs() const noexcept { return Support::immediateToT(value()); } + + //! Returns low 32-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr int32_t int32Lo() const noexcept { return int32_t(_data[kDataImmValueLo]); } + //! Returns high 32-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr int32_t int32Hi() const noexcept { return int32_t(_data[kDataImmValueHi]); } + //! Returns low 32-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr uint32_t uint32Lo() const noexcept { return _data[kDataImmValueLo]; } + //! Returns high 32-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr uint32_t uint32Hi() const noexcept { return _data[kDataImmValueHi]; } + + //! Sets immediate value to `val`, the value is casted to a signed 64-bit integer. + template + ASMJIT_INLINE_NODEBUG void setValue(const T& val) noexcept { + _setValueInternal(Support::immediateFromT(val), std::is_floating_point::value ? ImmType::kDouble : ImmType::kInt); + } + + ASMJIT_INLINE_NODEBUG void _setValueInternal(int64_t val, ImmType type) noexcept { + setType(type); + _data[kDataImmValueHi] = uint32_t(uint64_t(val) >> 32); + _data[kDataImmValueLo] = uint32_t(uint64_t(val) & 0xFFFFFFFFu); + } + + //! \} + + //! \name Utilities + //! \{ + + //! Clones the immediate operand. 
+ ASMJIT_INLINE_NODEBUG constexpr Imm clone() const noexcept { return Imm(*this); } + + ASMJIT_INLINE_NODEBUG void signExtend8Bits() noexcept { setValue(int64_t(valueAs())); } + ASMJIT_INLINE_NODEBUG void signExtend16Bits() noexcept { setValue(int64_t(valueAs())); } + ASMJIT_INLINE_NODEBUG void signExtend32Bits() noexcept { setValue(int64_t(valueAs())); } + + ASMJIT_INLINE_NODEBUG void zeroExtend8Bits() noexcept { setValue(valueAs()); } + ASMJIT_INLINE_NODEBUG void zeroExtend16Bits() noexcept { setValue(valueAs()); } + ASMJIT_INLINE_NODEBUG void zeroExtend32Bits() noexcept { _data[kDataImmValueHi] = 0u; } + + //! \} +}; + +//! Creates a new immediate operand. +template +static ASMJIT_INLINE_NODEBUG constexpr Imm imm(const T& val) noexcept { return Imm(val); } + +//! \} + +namespace Globals { + //! \ingroup asmjit_assembler + //! + //! A default-constructed operand of `Operand_::kOpNone` type. + static constexpr const Operand none; +} + +//! \cond INTERNAL +namespace Support { + +template +struct ForwardOpImpl { + static ASMJIT_INLINE_NODEBUG const T& forward(const T& value) noexcept { return value; } +}; + +template +struct ForwardOpImpl { + static ASMJIT_INLINE_NODEBUG Imm forward(const T& value) noexcept { return Imm(value); } +}; + +//! Either forwards operand T or returns a new operand that wraps it if T is a type convertible to operand. +//! At the moment this is only used to convert integers, floats, and enumarations to \ref Imm operands. +template +struct ForwardOp : public ForwardOpImpl::type>::value> {}; + +} // {Support} +//! 
\endcond + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_OPERAND_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/osutils.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/osutils.h new file mode 100644 index 0000000000000000000000000000000000000000..c6588373425f749e601259c384d2d3a46ee64ed6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/osutils.h @@ -0,0 +1,54 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_OSUTILS_H_INCLUDED +#define ASMJIT_CORE_OSUTILS_H_INCLUDED + +#include "../core/globals.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_utilities +//! \{ + +//! \cond INTERNAL +//! Lock. +//! +//! Lock is internal, it cannot be used outside of AsmJit, however, its internal +//! layout is exposed as it's used by some other classes, which are public. +class Lock { +public: + ASMJIT_NONCOPYABLE(Lock) + +#if defined(_WIN32) +#pragma pack(push, 8) + struct ASMJIT_MAY_ALIAS Handle { + void* DebugInfo; + long LockCount; + long RecursionCount; + void* OwningThread; + void* LockSemaphore; + unsigned long* SpinCount; + }; + Handle _handle; +#pragma pack(pop) +#elif !defined(__EMSCRIPTEN__) + typedef pthread_mutex_t Handle; + Handle _handle; +#endif + + ASMJIT_INLINE_NODEBUG Lock() noexcept; + ASMJIT_INLINE_NODEBUG ~Lock() noexcept; + + ASMJIT_INLINE_NODEBUG void lock() noexcept; + ASMJIT_INLINE_NODEBUG void unlock() noexcept; +}; +//! \endcond + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_OSUTILS_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/string.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/string.h new file mode 100644 index 0000000000000000000000000000000000000000..c4dee14b00efa5364f2a33e0d4d26d3bf4058355 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/string.h @@ -0,0 +1,383 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_STRING_H_INCLUDED +#define ASMJIT_CORE_STRING_H_INCLUDED + +#include "../core/support.h" +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_utilities +//! \{ + +//! Format flags used by \ref String API. +enum class StringFormatFlags : uint32_t { + //! No flags. + kNone = 0x00000000u, + //! Show sign. + kShowSign = 0x00000001u, + //! Show space. + kShowSpace = 0x00000002u, + //! Alternate form (use 0x when formatting HEX number). + kAlternate = 0x00000004u, + //! The input is signed. + kSigned = 0x80000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(StringFormatFlags) + +//! Fixed string - only useful for strings that would never exceed `N - 1` characters; always null-terminated. +template +union FixedString { + //! \name Constants + //! \{ + + // This cannot be constexpr as GCC 4.8 refuses constexpr members of unions. + enum : uint32_t { + kNumUInt32Words = uint32_t((N + sizeof(uint32_t) - 1) / sizeof(uint32_t)) + }; + + //! \} + + //! \name Members + //! \{ + + char str[kNumUInt32Words * sizeof(uint32_t)]; + uint32_t u32[kNumUInt32Words]; + + //! \} + + //! \name Utilities + //! 
\{ + + inline bool equals(const char* other) const noexcept { return strcmp(str, other) == 0; } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use FixedString::equals() instead") + inline bool eq(const char* other) const noexcept { return equals(other); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} +}; + +//! A simple non-reference counted string that uses small string optimization (SSO). +//! +//! This string has 3 allocation possibilities: +//! +//! 1. Small - embedded buffer is used for up to `kSSOCapacity` characters. This should handle most small +//! strings and thus avoid dynamic memory allocation for most use-cases. +//! +//! 2. Large - string that doesn't fit into an embedded buffer (or string that was truncated from a larger +//! buffer) and is owned by AsmJit. When you destroy the string AsmJit would automatically +//! release the large buffer. +//! +//! 3. External - like Large (2), however, the large buffer is not owned by AsmJit and won't be released when +//! the string is destroyed or reallocated. This is mostly useful for working with larger temporary +//! strings allocated on stack or with immutable strings. +class String { +public: + ASMJIT_NONCOPYABLE(String) + + //! String operation. + enum class ModifyOp : uint32_t { + //! Assignment - a new content replaces the current one. + kAssign = 0, + //! Append - a new content is appended to the string. + kAppend = 1 + }; + + //! \cond INTERNAL + enum : uint32_t { + kLayoutSize = 32, + kSSOCapacity = kLayoutSize - 2 + }; + + //! String type. + enum Type : uint8_t { + //! Large string (owned by String). + kTypeLarge = 0x1Fu, + //! External string (zone allocated or not owned by String). 
+ kTypeExternal = 0x20u + }; + + union Raw { + uint8_t u8[kLayoutSize]; + uint64_t u64[kLayoutSize / sizeof(uint64_t)]; + uintptr_t uptr[kLayoutSize / sizeof(uintptr_t)]; + }; + + struct Small { + uint8_t type; + char data[kSSOCapacity + 1u]; + }; + + struct Large { + uint8_t type; + uint8_t reserved[sizeof(uintptr_t) - 1]; + size_t size; + size_t capacity; + char* data; + }; + + union { + uint8_t _type; + Raw _raw; + Small _small; + Large _large; + }; + //! \endcond + + //! \name Construction & Destruction + //! \{ + + //! Creates a default-initialized string if zero length. + ASMJIT_INLINE_NODEBUG String() noexcept + : _small {} {} + + //! Creates a string that takes ownership of the content of the `other` string. + ASMJIT_INLINE_NODEBUG String(String&& other) noexcept { + _raw = other._raw; + other._resetInternal(); + } + + ASMJIT_INLINE_NODEBUG ~String() noexcept { + reset(); + } + + //! Reset the string into a construction state. + ASMJIT_API Error reset() noexcept; + + //! \} + + //! \name Overloaded Operators + //! \{ + + inline String& operator=(String&& other) noexcept { + swap(other); + other.reset(); + return *this; + } + + ASMJIT_INLINE_NODEBUG bool operator==(const char* other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const char* other) const noexcept { return !equals(other); } + + ASMJIT_INLINE_NODEBUG bool operator==(const String& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const String& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool isExternal() const noexcept { return _type == kTypeExternal; } + ASMJIT_INLINE_NODEBUG bool isLargeOrExternal() const noexcept { return _type >= kTypeLarge; } + + //! Tests whether the string is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return size() == 0; } + //! Returns the size of the string. 
+ ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return isLargeOrExternal() ? size_t(_large.size) : size_t(_type); } + //! Returns the capacity of the string. + ASMJIT_INLINE_NODEBUG size_t capacity() const noexcept { return isLargeOrExternal() ? _large.capacity : size_t(kSSOCapacity); } + + //! Returns the data of the string. + ASMJIT_INLINE_NODEBUG char* data() noexcept { return isLargeOrExternal() ? _large.data : _small.data; } + //! \overload + ASMJIT_INLINE_NODEBUG const char* data() const noexcept { return isLargeOrExternal() ? _large.data : _small.data; } + + ASMJIT_INLINE_NODEBUG char* start() noexcept { return data(); } + ASMJIT_INLINE_NODEBUG const char* start() const noexcept { return data(); } + + ASMJIT_INLINE_NODEBUG char* end() noexcept { return data() + size(); } + ASMJIT_INLINE_NODEBUG const char* end() const noexcept { return data() + size(); } + + //! \} + + //! \name String Operations + //! \{ + + //! Swaps the content of this string with `other`. + ASMJIT_INLINE_NODEBUG void swap(String& other) noexcept { + std::swap(_raw, other._raw); + } + + //! Clears the content of the string. + ASMJIT_API Error clear() noexcept; + + ASMJIT_API char* prepare(ModifyOp op, size_t size) noexcept; + + ASMJIT_API Error _opString(ModifyOp op, const char* str, size_t size = SIZE_MAX) noexcept; + ASMJIT_API Error _opChar(ModifyOp op, char c) noexcept; + ASMJIT_API Error _opChars(ModifyOp op, char c, size_t n) noexcept; + ASMJIT_API Error _opNumber(ModifyOp op, uint64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept; + ASMJIT_API Error _opHex(ModifyOp op, const void* data, size_t size, char separator = '\0') noexcept; + ASMJIT_API Error _opFormat(ModifyOp op, const char* fmt, ...) noexcept; + ASMJIT_API Error _opVFormat(ModifyOp op, const char* fmt, va_list ap) noexcept; + + //! Replaces the current of the string with `data` of the given `size`. + //! + //! 
Null terminated strings can set `size` to `SIZE_MAX`. + ASMJIT_API Error assign(const char* data, size_t size = SIZE_MAX) noexcept; + + //! Replaces the current of the string with `other` string. + ASMJIT_INLINE_NODEBUG Error assign(const String& other) noexcept { + return assign(other.data(), other.size()); + } + + //! Replaces the current of the string by a single `c` character. + ASMJIT_INLINE_NODEBUG Error assign(char c) noexcept { + return _opChar(ModifyOp::kAssign, c); + } + + //! Replaces the current of the string by a `c` character, repeated `n` times. + ASMJIT_INLINE_NODEBUG Error assignChars(char c, size_t n) noexcept { + return _opChars(ModifyOp::kAssign, c, n); + } + + //! Replaces the current of the string by a formatted integer `i` (signed). + ASMJIT_INLINE_NODEBUG Error assignInt(int64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept { + return _opNumber(ModifyOp::kAssign, uint64_t(i), base, width, flags | StringFormatFlags::kSigned); + } + + //! Replaces the current of the string by a formatted integer `i` (unsigned). + ASMJIT_INLINE_NODEBUG Error assignUInt(uint64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept { + return _opNumber(ModifyOp::kAssign, i, base, width, flags); + } + + //! Replaces the current of the string by the given `data` converted to a HEX string. + ASMJIT_INLINE_NODEBUG Error assignHex(const void* data, size_t size, char separator = '\0') noexcept { + return _opHex(ModifyOp::kAssign, data, size, separator); + } + + //! Replaces the current of the string by a formatted string `fmt`. + template + ASMJIT_INLINE_NODEBUG Error assignFormat(const char* fmt, Args&&... args) noexcept { + return _opFormat(ModifyOp::kAssign, fmt, std::forward(args)...); + } + + //! Replaces the current of the string by a formatted string `fmt` (va_list version). 
+ ASMJIT_INLINE_NODEBUG Error assignVFormat(const char* fmt, va_list ap) noexcept { + return _opVFormat(ModifyOp::kAssign, fmt, ap); + } + + //! Appends `str` having the given size `size` to the string. + //! + //! Null terminated strings can set `size` to `SIZE_MAX`. + ASMJIT_INLINE_NODEBUG Error append(const char* str, size_t size = SIZE_MAX) noexcept { + return _opString(ModifyOp::kAppend, str, size); + } + + //! Appends `other` string to this string. + ASMJIT_INLINE_NODEBUG Error append(const String& other) noexcept { + return append(other.data(), other.size()); + } + + //! Appends a single `c` character. + ASMJIT_INLINE_NODEBUG Error append(char c) noexcept { + return _opChar(ModifyOp::kAppend, c); + } + + //! Appends `c` character repeated `n` times. + ASMJIT_INLINE_NODEBUG Error appendChars(char c, size_t n) noexcept { + return _opChars(ModifyOp::kAppend, c, n); + } + + //! Appends a formatted integer `i` (signed). + ASMJIT_INLINE_NODEBUG Error appendInt(int64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept { + return _opNumber(ModifyOp::kAppend, uint64_t(i), base, width, flags | StringFormatFlags::kSigned); + } + + //! Appends a formatted integer `i` (unsigned). + ASMJIT_INLINE_NODEBUG Error appendUInt(uint64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept { + return _opNumber(ModifyOp::kAppend, i, base, width, flags); + } + + //! Appends the given `data` converted to a HEX string. + ASMJIT_INLINE_NODEBUG Error appendHex(const void* data, size_t size, char separator = '\0') noexcept { + return _opHex(ModifyOp::kAppend, data, size, separator); + } + + //! Appends a formatted string `fmt` with `args`. + template + ASMJIT_INLINE_NODEBUG Error appendFormat(const char* fmt, Args&&... args) noexcept { + return _opFormat(ModifyOp::kAppend, fmt, std::forward(args)...); + } + + //! Appends a formatted string `fmt` (va_list version). 
+ ASMJIT_INLINE_NODEBUG Error appendVFormat(const char* fmt, va_list ap) noexcept { + return _opVFormat(ModifyOp::kAppend, fmt, ap); + } + + ASMJIT_API Error padEnd(size_t n, char c = ' ') noexcept; + + //! Truncate the string length into `newSize`. + ASMJIT_API Error truncate(size_t newSize) noexcept; + + ASMJIT_API bool equals(const char* other, size_t size = SIZE_MAX) const noexcept; + ASMJIT_INLINE_NODEBUG bool equals(const String& other) const noexcept { return equals(other.data(), other.size()); } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use String::equals() instead") + ASMJIT_INLINE_NODEBUG bool eq(const char* other, size_t size = SIZE_MAX) const noexcept { return equals(other, size); } + + ASMJIT_DEPRECATED("Use String::equals() instead") + ASMJIT_INLINE_NODEBUG bool eq(const String& other) const noexcept { return equals(other.data(), other.size()); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} + + //! \name Internal Functions + //! \{ + + //! Resets string to embedded and makes it empty (zero length, zero first char) + //! + //! \note This is always called internally after an external buffer was released as it zeroes all bytes + //! used by String's embedded storage. + inline void _resetInternal() noexcept { + for (size_t i = 0; i < ASMJIT_ARRAY_SIZE(_raw.uptr); i++) + _raw.uptr[i] = 0; + } + + inline void _setSize(size_t newSize) noexcept { + if (isLargeOrExternal()) + _large.size = newSize; + else + _small.type = uint8_t(newSize); + } + + //! \} +}; + +//! Temporary string builder, has statically allocated `N` bytes. +template +class StringTmp : public String { +public: + ASMJIT_NONCOPYABLE(StringTmp) + + //! Embedded data. + char _embeddedData[Support::alignUp(N + 1, sizeof(size_t))]; + + //! \name Construction & Destruction + //! 
\{ + + ASMJIT_INLINE_NODEBUG StringTmp() noexcept { + _resetToTemporary(); + } + + inline void _resetToTemporary() noexcept { + _large.type = kTypeExternal; + _large.capacity = ASMJIT_ARRAY_SIZE(_embeddedData) - 1; + _large.data = _embeddedData; + _embeddedData[0] = '\0'; + } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_STRING_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/support.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/support.h new file mode 100644 index 0000000000000000000000000000000000000000..345cf8cf4cdcbc7e737ecc65a31e13427866631c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/support.h @@ -0,0 +1,1818 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_SUPPORT_H_INCLUDED +#define ASMJIT_CORE_SUPPORT_H_INCLUDED + +#include "../core/globals.h" + +#if defined(_MSC_VER) + #include +#endif + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_utilities +//! \{ + +//! Contains support classes and functions that may be used by AsmJit source and header files. Anything defined +//! here is considered internal and should not be used outside of AsmJit and related projects like AsmTK. +namespace Support { + +// Support - Basic Traits +// ====================== + +#if ASMJIT_ARCH_X86 +typedef uint8_t FastUInt8; +#else +typedef uint32_t FastUInt8; +#endif + +//! 
\cond INTERNAL +namespace Internal { + template + struct AliasedUInt {}; + + template<> struct AliasedUInt { typedef uint16_t ASMJIT_MAY_ALIAS T; }; + template<> struct AliasedUInt { typedef uint32_t ASMJIT_MAY_ALIAS T; }; + template<> struct AliasedUInt { typedef uint64_t ASMJIT_MAY_ALIAS T; }; + + template<> struct AliasedUInt { typedef uint16_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 1); }; + template<> struct AliasedUInt { typedef uint32_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 1); }; + template<> struct AliasedUInt { typedef uint32_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 2); }; + template<> struct AliasedUInt { typedef uint64_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 1); }; + template<> struct AliasedUInt { typedef uint64_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 2); }; + template<> struct AliasedUInt { typedef uint64_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 4); }; + + // StdInt - Make an int-type by size (signed or unsigned) that is the + // same as types defined by . + // Int32Or64 - Make an int-type that has at least 32 bits: [u]int[32|64]_t. + + template + struct StdInt {}; // Fail if not specialized. + + template<> struct StdInt<1, 0> { typedef int8_t Type; }; + template<> struct StdInt<1, 1> { typedef uint8_t Type; }; + template<> struct StdInt<2, 0> { typedef int16_t Type; }; + template<> struct StdInt<2, 1> { typedef uint16_t Type; }; + template<> struct StdInt<4, 0> { typedef int32_t Type; }; + template<> struct StdInt<4, 1> { typedef uint32_t Type; }; + template<> struct StdInt<8, 0> { typedef int64_t Type; }; + template<> struct StdInt<8, 1> { typedef uint64_t Type; }; + + template::value> + struct Int32Or64 : public StdInt {}; +} +//! \endcond + +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUnsigned() noexcept { return std::is_unsigned::value; } + +//! Casts an integer `x` to either `int32_t` or `int64_t` depending on `T`. 
+template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::Int32Or64::Type asInt(const T& x) noexcept { + return (typename Internal::Int32Or64::Type)x; +} + +//! Casts an integer `x` to either `uint32_t` or `uint64_t` depending on `T`. +template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::Int32Or64::Type asUInt(const T& x) noexcept { + return (typename Internal::Int32Or64::Type)x; +} + +//! Casts an integer `x` to either `int32_t`, uint32_t`, `int64_t`, or `uint64_t` depending on `T`. +template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::Int32Or64::Type asNormalized(const T& x) noexcept { + return (typename Internal::Int32Or64::Type)x; +} + +//! Casts an integer `x` to the same type as defined by ``. +template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::StdInt()>::Type asStdInt(const T& x) noexcept { + return (typename Internal::StdInt()>::Type)x; +} + +//! A helper class that can be used to iterate over enum values. +template +struct EnumValues { + typedef typename std::underlying_type::type ValueType; + + struct Iterator { + ValueType value; + + ASMJIT_INLINE_NODEBUG T operator*() const { return (T)value; } + ASMJIT_INLINE_NODEBUG void operator++() { ++value; } + + ASMJIT_INLINE_NODEBUG bool operator==(const Iterator& other) const noexcept { return value == other.value; } + ASMJIT_INLINE_NODEBUG bool operator!=(const Iterator& other) const noexcept { return value != other.value; } + }; + + ASMJIT_INLINE_NODEBUG Iterator begin() const noexcept { return Iterator{ValueType(from)}; } + ASMJIT_INLINE_NODEBUG Iterator end() const noexcept { return Iterator{ValueType(to) + 1}; } +}; + +// Support - BitCast +// ================= + +//! \cond +namespace Internal { + template + union BitCastUnion { + ASMJIT_INLINE_NODEBUG BitCastUnion(SrcT src) noexcept : src(src) {} + SrcT src; + DstT dst; + }; +} +//! \endcond + +//! Bit-casts from `Src` type to `Dst` type. +//! +//! 
Useful to bit-cast between integers and floating points. +template +static ASMJIT_INLINE_NODEBUG Dst bitCast(const Src& x) noexcept { return Internal::BitCastUnion(x).dst; } + +// Support - BitOps +// ================ + +//! Storage used to store a pack of bits (should by compatible with a machine word). +typedef Internal::StdInt::Type BitWord; + +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t bitSizeOf() noexcept { return uint32_t(sizeof(T) * 8u); } + +//! Number of bits stored in a single `BitWord`. +static constexpr uint32_t kBitWordSizeInBits = bitSizeOf(); + +//! Returns `0 - x` in a safe way (no undefined behavior), works for unsigned numbers as well. +template +static ASMJIT_INLINE_NODEBUG constexpr T neg(const T& x) noexcept { + typedef typename std::make_unsigned::type U; + return T(U(0) - U(x)); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr T allOnes() noexcept { return neg(T(1)); } + +//! Returns `x << y` (shift left logical) by explicitly casting `x` to an unsigned type and back. +template +static ASMJIT_INLINE_NODEBUG constexpr X shl(const X& x, const Y& y) noexcept { + typedef typename std::make_unsigned::type U; + return X(U(x) << y); +} + +//! Returns `x >> y` (shift right logical) by explicitly casting `x` to an unsigned type and back. +template +static ASMJIT_INLINE_NODEBUG constexpr X shr(const X& x, const Y& y) noexcept { + typedef typename std::make_unsigned::type U; + return X(U(x) >> y); +} + +//! Returns `x >> y` (shift right arithmetic) by explicitly casting `x` to a signed type and back. +template +static ASMJIT_INLINE_NODEBUG constexpr X sar(const X& x, const Y& y) noexcept { + typedef typename std::make_signed::type S; + return X(S(x) >> y); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr X ror(const X& x, const Y& y) noexcept { + typedef typename std::make_unsigned::type U; + return X((U(x) >> y) | (U(x) << (bitSizeOf() - U(y)))); +} + +//! Returns `x | (x >> y)` - helper used by some bit manipulation helpers. 
+template +static ASMJIT_INLINE_NODEBUG constexpr X or_shr(const X& x, const Y& y) noexcept { return X(x | shr(x, y)); } + +//! Returns `x & -x` - extracts lowest set isolated bit (like BLSI instruction). +template +static ASMJIT_INLINE_NODEBUG constexpr T blsi(T x) noexcept { + typedef typename std::make_unsigned::type U; + return T(U(x) & neg(U(x))); +} + +//! Tests whether the given value `x` has `n`th bit set. +template +static ASMJIT_INLINE_NODEBUG constexpr bool bitTest(T x, IndexT n) noexcept { + typedef typename std::make_unsigned::type U; + return (U(x) & (U(1) << asStdInt(n))) != 0; +} + +// Tests whether the given `value` is a consecutive mask of bits that starts at +// the least significant bit. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isLsbMask(const T& value) { + typedef typename std::make_unsigned::type U; + return value && ((U(value) + 1u) & U(value)) == 0; +} + +// Tests whether the given value contains at least one bit or whether it's a +// bit-mask of consecutive bits. +// +// This function is similar to \ref isLsbMask(), but the mask doesn't have to +// start at a least significant bit. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isConsecutiveMask(const T& value) { + typedef typename std::make_unsigned::type U; + return value && isLsbMask((U(value) - 1u) | U(value)); +} + +//! Generates a trailing bit-mask that has `n` least significant (trailing) bits set. +template +static ASMJIT_INLINE_NODEBUG constexpr T lsbMask(const CountT& n) noexcept { + typedef typename std::make_unsigned::type U; + return (sizeof(U) < sizeof(uintptr_t)) + // Prevent undefined behavior by using a larger type than T. + ? T(U((uintptr_t(1) << n) - uintptr_t(1))) + // Prevent undefined behavior by checking `n` before shift. + : n ? T(shr(allOnes(), bitSizeOf() - size_t(n))) : T(0); +} + +//! Generates a leading bit-mask that has `n` most significant (leading) bits set. 
+template +static ASMJIT_INLINE_NODEBUG constexpr T msbMask(const CountT& n) noexcept { + typedef typename std::make_unsigned::type U; + return (sizeof(U) < sizeof(uintptr_t)) + // Prevent undefined behavior by using a larger type than T. + ? T(allOnes() >> (bitSizeOf() - n)) + // Prevent undefined behavior by performing `n & (nBits - 1)` so it's always within the range. + : T(sar(U(n != 0) << (bitSizeOf() - 1), n ? uint32_t(n - 1) : uint32_t(0))); +} + +//! Returns a bit-mask that has `x` bit set. +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t bitMask(const Index& x) noexcept { return (1u << asUInt(x)); } + +//! Returns a bit-mask that has `x` bit set (multiple arguments). +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t bitMask(const Index& x, Args... args) noexcept { return bitMask(x) | bitMask(args...); } + +//! Converts a boolean value `b` to zero or full mask (all bits set). +template +static ASMJIT_INLINE_NODEBUG constexpr DstT bitMaskFromBool(SrcT b) noexcept { + typedef typename std::make_unsigned::type U; + return DstT(U(0) - U(b)); +} + +//! Tests whether `a & b` is non-zero. +template +static inline constexpr bool test(A a, B b) noexcept { return (asUInt(a) & asUInt(b)) != 0; } + +//! \cond +namespace Internal { + // Fills all trailing bits right from the first most significant bit set. + static ASMJIT_INLINE_NODEBUG constexpr uint8_t fillTrailingBitsImpl(uint8_t x) noexcept { return or_shr(or_shr(or_shr(x, 1), 2), 4); } + // Fills all trailing bits right from the first most significant bit set. + static ASMJIT_INLINE_NODEBUG constexpr uint16_t fillTrailingBitsImpl(uint16_t x) noexcept { return or_shr(or_shr(or_shr(or_shr(x, 1), 2), 4), 8); } + // Fills all trailing bits right from the first most significant bit set. 
+ static ASMJIT_INLINE_NODEBUG constexpr uint32_t fillTrailingBitsImpl(uint32_t x) noexcept { return or_shr(or_shr(or_shr(or_shr(or_shr(x, 1), 2), 4), 8), 16); } + // Fills all trailing bits right from the first most significant bit set. + static ASMJIT_INLINE_NODEBUG constexpr uint64_t fillTrailingBitsImpl(uint64_t x) noexcept { return or_shr(or_shr(or_shr(or_shr(or_shr(or_shr(x, 1), 2), 4), 8), 16), 32); } +} +//! \endcond + +// Fills all trailing bits right from the first most significant bit set. +template +static ASMJIT_INLINE_NODEBUG constexpr T fillTrailingBits(const T& x) noexcept { + typedef typename std::make_unsigned::type U; + return T(Internal::fillTrailingBitsImpl(U(x))); +} + +// Support - Count Leading/Trailing Zeros +// ====================================== + +//! \cond +namespace Internal { +namespace { + +template +struct BitScanData { T x; uint32_t n; }; + +template +struct BitScanCalc { + static ASMJIT_INLINE_NODEBUG constexpr BitScanData advanceLeft(const BitScanData& data, uint32_t n) noexcept { + return BitScanData { data.x << n, data.n + n }; + } + + static ASMJIT_INLINE_NODEBUG constexpr BitScanData advanceRight(const BitScanData& data, uint32_t n) noexcept { + return BitScanData { data.x >> n, data.n + n }; + } + + static ASMJIT_INLINE_NODEBUG constexpr BitScanData clz(const BitScanData& data) noexcept { + return BitScanCalc::clz(advanceLeft(data, data.x & (allOnes() << (bitSizeOf() - N)) ? uint32_t(0) : N)); + } + + static ASMJIT_INLINE_NODEBUG constexpr BitScanData ctz(const BitScanData& data) noexcept { + return BitScanCalc::ctz(advanceRight(data, data.x & (allOnes() >> (bitSizeOf() - N)) ? 
uint32_t(0) : N)); + } +}; + +template +struct BitScanCalc { + static ASMJIT_INLINE_NODEBUG constexpr BitScanData clz(const BitScanData& ctx) noexcept { + return BitScanData { 0, ctx.n - uint32_t(ctx.x >> (bitSizeOf() - 1)) }; + } + + static ASMJIT_INLINE_NODEBUG constexpr BitScanData ctz(const BitScanData& ctx) noexcept { + return BitScanData { 0, ctx.n - uint32_t(ctx.x & 0x1) }; + } +}; + +template +ASMJIT_INLINE_NODEBUG constexpr uint32_t clzFallback(const T& x) noexcept { + return BitScanCalc() / 2u>::clz(BitScanData{x, 1}).n; +} + +template +ASMJIT_INLINE_NODEBUG constexpr uint32_t ctzFallback(const T& x) noexcept { + return BitScanCalc() / 2u>::ctz(BitScanData{x, 1}).n; +} + +template ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const T& x) noexcept { return clzFallback(asUInt(x)); } +template ASMJIT_INLINE_NODEBUG uint32_t ctzImpl(const T& x) noexcept { return ctzFallback(asUInt(x)); } + +#if !defined(ASMJIT_NO_INTRINSICS) +# if defined(__GNUC__) +template<> ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const uint32_t& x) noexcept { return uint32_t(__builtin_clz(x)); } +template<> ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const uint64_t& x) noexcept { return uint32_t(__builtin_clzll(x)); } +template<> ASMJIT_INLINE_NODEBUG uint32_t ctzImpl(const uint32_t& x) noexcept { return uint32_t(__builtin_ctz(x)); } +template<> ASMJIT_INLINE_NODEBUG uint32_t ctzImpl(const uint64_t& x) noexcept { return uint32_t(__builtin_ctzll(x)); } +# elif defined(_MSC_VER) +template<> ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const uint32_t& x) noexcept { unsigned long i; _BitScanReverse(&i, x); return uint32_t(i ^ 31); } +template<> ASMJIT_INLINE_NODEBUG uint32_t ctzImpl(const uint32_t& x) noexcept { unsigned long i; _BitScanForward(&i, x); return uint32_t(i); } +# if ASMJIT_ARCH_X86 == 64 || ASMJIT_ARCH_ARM == 64 +template<> ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const uint64_t& x) noexcept { unsigned long i; _BitScanReverse64(&i, x); return uint32_t(i ^ 63); } +template<> ASMJIT_INLINE_NODEBUG 
uint32_t ctzImpl(const uint64_t& x) noexcept { unsigned long i; _BitScanForward64(&i, x); return uint32_t(i); } +# endif +# endif +#endif + +} // {anonymous} +} // {Internal} +//! \endcond + +//! Count leading zeros in `x` (returns a position of a first bit set in `x`). +//! +//! \note The input MUST NOT be zero, otherwise the result is undefined. +template +static ASMJIT_INLINE_NODEBUG uint32_t clz(T x) noexcept { return Internal::clzImpl(asUInt(x)); } + +//! Count trailing zeros in `x` (returns a position of a first bit set in `x`). +//! +//! \note The input MUST NOT be zero, otherwise the result is undefined. +template +static ASMJIT_INLINE_NODEBUG uint32_t ctz(T x) noexcept { return Internal::ctzImpl(asUInt(x)); } + +template +struct ConstCTZ { + static constexpr uint32_t value = + (kInput & (uint64_t(1) << 0)) ? 0 : + (kInput & (uint64_t(1) << 1)) ? 1 : + (kInput & (uint64_t(1) << 2)) ? 2 : + (kInput & (uint64_t(1) << 3)) ? 3 : + (kInput & (uint64_t(1) << 4)) ? 4 : + (kInput & (uint64_t(1) << 5)) ? 5 : + (kInput & (uint64_t(1) << 6)) ? 6 : + (kInput & (uint64_t(1) << 7)) ? 7 : + (kInput & (uint64_t(1) << 8)) ? 8 : + (kInput & (uint64_t(1) << 9)) ? 9 : + (kInput & (uint64_t(1) << 10)) ? 10 : + (kInput & (uint64_t(1) << 11)) ? 11 : + (kInput & (uint64_t(1) << 12)) ? 12 : + (kInput & (uint64_t(1) << 13)) ? 13 : + (kInput & (uint64_t(1) << 14)) ? 14 : + (kInput & (uint64_t(1) << 15)) ? 15 : + (kInput & (uint64_t(1) << 16)) ? 16 : + (kInput & (uint64_t(1) << 17)) ? 17 : + (kInput & (uint64_t(1) << 18)) ? 18 : + (kInput & (uint64_t(1) << 19)) ? 19 : + (kInput & (uint64_t(1) << 20)) ? 20 : + (kInput & (uint64_t(1) << 21)) ? 21 : + (kInput & (uint64_t(1) << 22)) ? 22 : + (kInput & (uint64_t(1) << 23)) ? 23 : + (kInput & (uint64_t(1) << 24)) ? 24 : + (kInput & (uint64_t(1) << 25)) ? 25 : + (kInput & (uint64_t(1) << 26)) ? 26 : + (kInput & (uint64_t(1) << 27)) ? 27 : + (kInput & (uint64_t(1) << 28)) ? 28 : + (kInput & (uint64_t(1) << 29)) ? 
29 : + (kInput & (uint64_t(1) << 30)) ? 30 : + (kInput & (uint64_t(1) << 31)) ? 31 : + (kInput & (uint64_t(1) << 32)) ? 32 : + (kInput & (uint64_t(1) << 33)) ? 33 : + (kInput & (uint64_t(1) << 34)) ? 34 : + (kInput & (uint64_t(1) << 35)) ? 35 : + (kInput & (uint64_t(1) << 36)) ? 36 : + (kInput & (uint64_t(1) << 37)) ? 37 : + (kInput & (uint64_t(1) << 38)) ? 38 : + (kInput & (uint64_t(1) << 39)) ? 39 : + (kInput & (uint64_t(1) << 40)) ? 40 : + (kInput & (uint64_t(1) << 41)) ? 41 : + (kInput & (uint64_t(1) << 42)) ? 42 : + (kInput & (uint64_t(1) << 43)) ? 43 : + (kInput & (uint64_t(1) << 44)) ? 44 : + (kInput & (uint64_t(1) << 45)) ? 45 : + (kInput & (uint64_t(1) << 46)) ? 46 : + (kInput & (uint64_t(1) << 47)) ? 47 : + (kInput & (uint64_t(1) << 48)) ? 48 : + (kInput & (uint64_t(1) << 49)) ? 49 : + (kInput & (uint64_t(1) << 50)) ? 50 : + (kInput & (uint64_t(1) << 51)) ? 51 : + (kInput & (uint64_t(1) << 52)) ? 52 : + (kInput & (uint64_t(1) << 53)) ? 53 : + (kInput & (uint64_t(1) << 54)) ? 54 : + (kInput & (uint64_t(1) << 55)) ? 55 : + (kInput & (uint64_t(1) << 56)) ? 56 : + (kInput & (uint64_t(1) << 57)) ? 57 : + (kInput & (uint64_t(1) << 58)) ? 58 : + (kInput & (uint64_t(1) << 59)) ? 59 : + (kInput & (uint64_t(1) << 60)) ? 60 : + (kInput & (uint64_t(1) << 61)) ? 61 : + (kInput & (uint64_t(1) << 62)) ? 62 : + (kInput & (uint64_t(1) << 63)) ? 63 : 64; +}; + +// Support - PopCnt +// ================ + +// Based on the following resource: +// http://graphics.stanford.edu/~seander/bithacks.html +// +// Alternatively, for a very small number of bits in `x`: +// uint32_t n = 0; +// while (x) { +// x &= x - 1; +// n++; +// } +// return n; + +//! 
\cond +namespace Internal { + static ASMJIT_INLINE_NODEBUG uint32_t constPopcntImpl(uint32_t x) noexcept { + x = x - ((x >> 1) & 0x55555555u); + x = (x & 0x33333333u) + ((x >> 2) & 0x33333333u); + return (((x + (x >> 4)) & 0x0F0F0F0Fu) * 0x01010101u) >> 24; + } + + static ASMJIT_INLINE_NODEBUG uint32_t constPopcntImpl(uint64_t x) noexcept { +#if ASMJIT_ARCH_BITS >= 64 + x = x - ((x >> 1) & 0x5555555555555555u); + x = (x & 0x3333333333333333u) + ((x >> 2) & 0x3333333333333333u); + return uint32_t((((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Fu) * 0x0101010101010101u) >> 56); +#else + return constPopcntImpl(uint32_t(x >> 32)) + + constPopcntImpl(uint32_t(x & 0xFFFFFFFFu)); +#endif + } + + static ASMJIT_INLINE_NODEBUG uint32_t popcntImpl(uint32_t x) noexcept { +#if defined(__GNUC__) + return uint32_t(__builtin_popcount(x)); +#else + return constPopcntImpl(asUInt(x)); +#endif + } + + static ASMJIT_INLINE_NODEBUG uint32_t popcntImpl(uint64_t x) noexcept { +#if defined(__GNUC__) + return uint32_t(__builtin_popcountll(x)); +#else + return constPopcntImpl(asUInt(x)); +#endif + } +} +//! \endcond + +//! Calculates count of bits in `x`. +template +static ASMJIT_INLINE_NODEBUG uint32_t popcnt(T x) noexcept { return Internal::popcntImpl(asUInt(x)); } + +//! Calculates count of bits in `x` (useful in constant expressions). +template +static ASMJIT_INLINE_NODEBUG uint32_t constPopcnt(T x) noexcept { return Internal::constPopcntImpl(asUInt(x)); } + +// Support - Min/Max +// ================= + +// NOTE: These are constexpr `min()` and `max()` implementations that are not +// exactly the same as `std::min()` and `std::max()`. The return value is not +// a reference to `a` or `b` but it's a new value instead. + +template +static ASMJIT_INLINE_NODEBUG constexpr T min(const T& a, const T& b) noexcept { return b < a ? b : a; } + +template +static ASMJIT_INLINE_NODEBUG constexpr T min(const T& a, const T& b, Args&&... 
args) noexcept { return min(min(a, b), std::forward(args)...); } + +template +static ASMJIT_INLINE_NODEBUG constexpr T max(const T& a, const T& b) noexcept { return a < b ? b : a; } + +template +static ASMJIT_INLINE_NODEBUG constexpr T max(const T& a, const T& b, Args&&... args) noexcept { return max(max(a, b), std::forward(args)...); } + +// Support - Immediate Helpers +// =========================== + +namespace Internal { + template + struct ImmConv { + static ASMJIT_INLINE_NODEBUG int64_t fromT(const T& x) noexcept { return int64_t(x); } + static ASMJIT_INLINE_NODEBUG T toT(int64_t x) noexcept { return T(uint64_t(x) & Support::allOnes::type>()); } + }; + + template + struct ImmConv { + static ASMJIT_INLINE_NODEBUG int64_t fromT(const T& x) noexcept { return int64_t(bitCast(double(x))); } + static ASMJIT_INLINE_NODEBUG T toT(int64_t x) noexcept { return T(bitCast(x)); } + }; +} + +template +static ASMJIT_INLINE_NODEBUG int64_t immediateFromT(const T& x) noexcept { return Internal::ImmConv::value>::fromT(x); } + +template +static ASMJIT_INLINE_NODEBUG T immediateToT(int64_t x) noexcept { return Internal::ImmConv::value>::toT(x); } + +// Support - Overflow Arithmetic +// ============================= + +//! \cond +namespace Internal { + template + inline T addOverflowFallback(T x, T y, FastUInt8* of) noexcept { + typedef typename std::make_unsigned::type U; + + U result = U(x) + U(y); + *of = FastUInt8(*of | FastUInt8(isUnsigned() ? result < U(x) : T((U(x) ^ ~U(y)) & (U(x) ^ result)) < 0)); + return T(result); + } + + template + inline T subOverflowFallback(T x, T y, FastUInt8* of) noexcept { + typedef typename std::make_unsigned::type U; + + U result = U(x) - U(y); + *of = FastUInt8(*of | FastUInt8(isUnsigned() ? 
result > U(x) : T((U(x) ^ U(y)) & (U(x) ^ result)) < 0)); + return T(result); + } + + template + inline T mulOverflowFallback(T x, T y, FastUInt8* of) noexcept { + typedef typename Internal::StdInt()>::Type I; + typedef typename std::make_unsigned::type U; + + U mask = allOnes(); + if (std::is_signed::value) { + U prod = U(I(x)) * U(I(y)); + *of = FastUInt8(*of | FastUInt8(I(prod) < I(std::numeric_limits::lowest()) || I(prod) > I(std::numeric_limits::max()))); + return T(I(prod & mask)); + } + else { + U prod = U(x) * U(y); + *of = FastUInt8(*of | FastUInt8((prod & ~mask) != 0)); + return T(prod & mask); + } + } + + template<> + inline int64_t mulOverflowFallback(int64_t x, int64_t y, FastUInt8* of) noexcept { + int64_t result = int64_t(uint64_t(x) * uint64_t(y)); + *of = FastUInt8(*of | FastUInt8(x && (result / x != y))); + return result; + } + + template<> + inline uint64_t mulOverflowFallback(uint64_t x, uint64_t y, FastUInt8* of) noexcept { + uint64_t result = x * y; + *of = FastUInt8(*of | FastUInt8(y != 0 && allOnes() / y < x)); + return result; + } + + // These can be specialized. 
+ template inline T addOverflowImpl(const T& x, const T& y, FastUInt8* of) noexcept { return addOverflowFallback(x, y, of); } + template inline T subOverflowImpl(const T& x, const T& y, FastUInt8* of) noexcept { return subOverflowFallback(x, y, of); } + template inline T mulOverflowImpl(const T& x, const T& y, FastUInt8* of) noexcept { return mulOverflowFallback(x, y, of); } + +#if defined(__GNUC__) && !defined(ASMJIT_NO_INTRINSICS) +#if defined(__clang__) || __GNUC__ >= 5 +#define ASMJIT_ARITH_OVERFLOW_SPECIALIZE(FUNC, T, RESULT_T, BUILTIN) \ + template<> \ + inline T FUNC(const T& x, const T& y, FastUInt8* of) noexcept { \ + RESULT_T result; \ + *of = FastUInt8(*of | (BUILTIN((RESULT_T)x, (RESULT_T)y, &result))); \ + return T(result); \ + } + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, int32_t , int , __builtin_sadd_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, uint32_t, unsigned int , __builtin_uadd_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, int64_t , long long , __builtin_saddll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, uint64_t, unsigned long long, __builtin_uaddll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, int32_t , int , __builtin_ssub_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, uint32_t, unsigned int , __builtin_usub_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, int64_t , long long , __builtin_ssubll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, uint64_t, unsigned long long, __builtin_usubll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(mulOverflowImpl, int32_t , int , __builtin_smul_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(mulOverflowImpl, uint32_t, unsigned int , __builtin_umul_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(mulOverflowImpl, int64_t , long long , __builtin_smulll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(mulOverflowImpl, uint64_t, unsigned long long, __builtin_umulll_overflow) +#undef 
ASMJIT_ARITH_OVERFLOW_SPECIALIZE +#endif +#endif + + // There is a bug in MSVC that makes these specializations unusable, maybe in the future... +#if defined(_MSC_VER) && 0 +#define ASMJIT_ARITH_OVERFLOW_SPECIALIZE(FUNC, T, ALT_T, BUILTIN) \ + template<> \ + inline T FUNC(T x, T y, FastUInt8* of) noexcept { \ + ALT_T result; \ + *of = FastUInt8(*of | BUILTIN(0, (ALT_T)x, (ALT_T)y, &result)); \ + return T(result); \ + } + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, uint32_t, unsigned int , _addcarry_u32 ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, uint32_t, unsigned int , _subborrow_u32) +#if ARCH_BITS >= 64 + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, uint64_t, unsigned __int64 , _addcarry_u64 ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, uint64_t, unsigned __int64 , _subborrow_u64) +#endif +#undef ASMJIT_ARITH_OVERFLOW_SPECIALIZE +#endif +} // {Internal} +//! \endcond + +template +static inline T addOverflow(const T& x, const T& y, FastUInt8* of) noexcept { return T(Internal::addOverflowImpl(asStdInt(x), asStdInt(y), of)); } + +template +static inline T subOverflow(const T& x, const T& y, FastUInt8* of) noexcept { return T(Internal::subOverflowImpl(asStdInt(x), asStdInt(y), of)); } + +template +static inline T mulOverflow(const T& x, const T& y, FastUInt8* of) noexcept { return T(Internal::mulOverflowImpl(asStdInt(x), asStdInt(y), of)); } + +// Support - Alignment +// =================== + +template +static ASMJIT_INLINE_NODEBUG constexpr bool isAligned(X base, Y alignment) noexcept { + typedef typename Internal::StdInt::Type U; + return ((U)base % (U)alignment) == 0; +} + +//! Tests whether the `x` is a power of two (only one bit is set). 
+template +static ASMJIT_INLINE_NODEBUG constexpr bool isPowerOf2(T x) noexcept { + typedef typename std::make_unsigned::type U; + return x && !(U(x) & (U(x) - U(1))); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr X alignUp(X x, Y alignment) noexcept { + typedef typename Internal::StdInt::Type U; + return (X)( ((U)x + ((U)(alignment) - 1u)) & ~((U)(alignment) - 1u) ); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr T alignUpPowerOf2(T x) noexcept { + typedef typename Internal::StdInt::Type U; + return (T)(fillTrailingBits(U(x) - 1u) + 1u); +} + +//! Returns either zero or a positive difference between `base` and `base` when +//! aligned to `alignment`. +template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::StdInt::Type alignUpDiff(X base, Y alignment) noexcept { + typedef typename Internal::StdInt::Type U; + return alignUp(U(base), alignment) - U(base); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr X alignDown(X x, Y alignment) noexcept { + typedef typename Internal::StdInt::Type U; + return (X)( (U)x & ~((U)(alignment) - 1u) ); +} + +// Support - NumGranularized +// ========================= + +//! Calculates the number of elements that would be required if `base` is +//! granularized by `granularity`. This function can be used to calculate +//! the number of BitWords to represent N bits, for example. +template +static ASMJIT_INLINE_NODEBUG constexpr X numGranularized(X base, Y granularity) noexcept { + typedef typename Internal::StdInt::Type U; + return X((U(base) + U(granularity) - 1) / U(granularity)); +} + +// Support - IsBetween +// =================== + +//! Checks whether `x` is greater than or equal to `a` and lesser than or equal to `b`. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isBetween(const T& x, const T& a, const T& b) noexcept { + return x >= a && x <= b; +} + +// Support - IsInt & IsUInt +// ======================== + +//! 
Checks whether the given integer `x` can be casted to a 4-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt4(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? isBetween(S(x), -8, 7) : U(x) <= U(7u); +} + +//! Checks whether the given integer `x` can be casted to a 7-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt7(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? isBetween(S(x), -64, 63) : U(x) <= U(63u); +} + +//! Checks whether the given integer `x` can be casted to an 8-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt8(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 1 || isBetween(S(x), -128, 127) : U(x) <= U(127u); +} + +//! Checks whether the given integer `x` can be casted to a 9-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt9(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 1 || isBetween(S(x), -256, 255) + : sizeof(T) <= 1 || U(x) <= U(255u); +} + +//! Checks whether the given integer `x` can be casted to a 10-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt10(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 1 || isBetween(S(x), -512, 511) + : sizeof(T) <= 1 || U(x) <= U(511u); +} + +//! Checks whether the given integer `x` can be casted to a 16-bit signed integer. 
+template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt16(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 2 || isBetween(S(x), -32768, 32767) + : sizeof(T) <= 1 || U(x) <= U(32767u); +} + +//! Checks whether the given integer `x` can be casted to a 32-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt32(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 4 || isBetween(S(x), -2147483647 - 1, 2147483647) + : sizeof(T) <= 2 || U(x) <= U(2147483647u); +} + +//! Checks whether the given integer `x` can be casted to a 4-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt4(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? x >= T(0) && x <= T(15) + : U(x) <= U(15u); +} + +//! Checks whether the given integer `x` can be casted to an 8-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt8(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? (sizeof(T) <= 1 || T(x) <= T(255)) && x >= T(0) + : (sizeof(T) <= 1 || U(x) <= U(255u)); +} + +//! Checks whether the given integer `x` can be casted to a 12-bit unsigned integer (ARM specific). +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt12(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? (sizeof(T) <= 1 || T(x) <= T(4095)) && x >= T(0) + : (sizeof(T) <= 1 || U(x) <= U(4095u)); +} + +//! Checks whether the given integer `x` can be casted to a 16-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt16(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? 
(sizeof(T) <= 2 || T(x) <= T(65535)) && x >= T(0) + : (sizeof(T) <= 2 || U(x) <= U(65535u)); +} + +//! Checks whether the given integer `x` can be casted to a 32-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt32(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? (sizeof(T) <= 4 || T(x) <= T(4294967295u)) && x >= T(0) + : (sizeof(T) <= 4 || U(x) <= U(4294967295u)); +} + +//! Checks whether the given integer `x` can be casted to a 32-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isIntOrUInt32(T x) noexcept { + return sizeof(T) <= 4 ? true : (uint32_t(uint64_t(x) >> 32) + 1u) <= 1u; +} + +static bool ASMJIT_INLINE_NODEBUG isEncodableOffset32(int32_t offset, uint32_t nBits) noexcept { + uint32_t nRev = 32 - nBits; + return Support::sar(Support::shl(offset, nRev), nRev) == offset; +} + +static bool ASMJIT_INLINE_NODEBUG isEncodableOffset64(int64_t offset, uint32_t nBits) noexcept { + uint32_t nRev = 64 - nBits; + return Support::sar(Support::shl(offset, nRev), nRev) == offset; +} + +// Support - ByteSwap +// ================== + +static ASMJIT_INLINE_NODEBUG uint16_t byteswap16(uint16_t x) noexcept { + return uint16_t(((x >> 8) & 0xFFu) | ((x & 0xFFu) << 8)); +} + +static ASMJIT_INLINE_NODEBUG uint32_t byteswap32(uint32_t x) noexcept { + return (x << 24) | (x >> 24) | ((x << 8) & 0x00FF0000u) | ((x >> 8) & 0x0000FF00); +} + +static ASMJIT_INLINE_NODEBUG uint64_t byteswap64(uint64_t x) noexcept { +#if (defined(__GNUC__) || defined(__clang__)) && !defined(ASMJIT_NO_INTRINSICS) + return uint64_t(__builtin_bswap64(uint64_t(x))); +#elif defined(_MSC_VER) && !defined(ASMJIT_NO_INTRINSICS) + return uint64_t(_byteswap_uint64(uint64_t(x))); +#else + return (uint64_t(byteswap32(uint32_t(uint64_t(x) >> 32 ))) ) | + (uint64_t(byteswap32(uint32_t(uint64_t(x) & 0xFFFFFFFFu))) << 32) ; +#endif +} + +// Support - BytePack & Unpack +// =========================== + +//! 
Pack four 8-bit integer into a 32-bit integer as it is an array of `{b0,b1,b2,b3}`. +static ASMJIT_INLINE_NODEBUG constexpr uint32_t bytepack32_4x8(uint32_t a, uint32_t b, uint32_t c, uint32_t d) noexcept { + return ASMJIT_ARCH_LE ? (a | (b << 8) | (c << 16) | (d << 24)) + : (d | (c << 8) | (b << 16) | (a << 24)); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t unpackU32At0(T x) noexcept { return ASMJIT_ARCH_LE ? uint32_t(uint64_t(x) & 0xFFFFFFFFu) : uint32_t(uint64_t(x) >> 32); } +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t unpackU32At1(T x) noexcept { return ASMJIT_ARCH_BE ? uint32_t(uint64_t(x) & 0xFFFFFFFFu) : uint32_t(uint64_t(x) >> 32); } + +// Support - Position of byte (in bit-shift) +// ========================================= + +static ASMJIT_INLINE_NODEBUG uint32_t byteShiftOfDWordStruct(uint32_t index) noexcept { + return ASMJIT_ARCH_LE ? index * 8 : (uint32_t(sizeof(uint32_t)) - 1u - index) * 8; +} + +// Support - String Utilities +// ========================== + +template +static ASMJIT_INLINE_NODEBUG constexpr T asciiToLower(T c) noexcept { return T(c ^ T(T(c >= T('A') && c <= T('Z')) << 5)); } + +template +static ASMJIT_INLINE_NODEBUG constexpr T asciiToUpper(T c) noexcept { return T(c ^ T(T(c >= T('a') && c <= T('z')) << 5)); } + +static ASMJIT_INLINE_NODEBUG size_t strLen(const char* s, size_t maxSize) noexcept { + size_t i = 0; + while (i < maxSize && s[i] != '\0') + i++; + return i; +} + +static ASMJIT_INLINE_NODEBUG constexpr uint32_t hashRound(uint32_t hash, uint32_t c) noexcept { return hash * 65599 + c; } + +// Gets a hash of the given string `data` of size `size`. Size must be valid +// as this function doesn't check for a null terminator and allows it in the +// middle of the string. 
+static ASMJIT_INLINE_NODEBUG uint32_t hashString(const char* data, size_t size) noexcept { + uint32_t hashCode = 0; + for (uint32_t i = 0; i < size; i++) + hashCode = hashRound(hashCode, uint8_t(data[i])); + return hashCode; +} + +static ASMJIT_INLINE_NODEBUG const char* findPackedString(const char* p, uint32_t id) noexcept { + uint32_t i = 0; + while (i < id) { + while (p[0]) + p++; + p++; + i++; + } + return p; +} + +//! Compares two string views. +static ASMJIT_FORCE_INLINE int compareStringViews(const char* aData, size_t aSize, const char* bData, size_t bSize) noexcept { + size_t size = Support::min(aSize, bSize); + + for (size_t i = 0; i < size; i++) { + int c = int(uint8_t(aData[i])) - int(uint8_t(bData[i])); + if (c != 0) + return c; + } + + return int(aSize) - int(bSize); +} + +// Support - Memory Read Access - 8 Bits +// ===================================== + +static ASMJIT_INLINE_NODEBUG uint8_t readU8(const void* p) noexcept { return static_cast(p)[0]; } +static ASMJIT_INLINE_NODEBUG int8_t readI8(const void* p) noexcept { return static_cast(p)[0]; } + +// Support - Memory Read Access - 16 Bits +// ====================================== + +template +static ASMJIT_INLINE_NODEBUG uint16_t readU16x(const void* p) noexcept { + typedef typename Internal::AliasedUInt::T U16AlignedToN; + uint16_t x = static_cast(p)[0]; + return BO == ByteOrder::kNative ? 
x : byteswap16(x); +} + +template +static ASMJIT_INLINE_NODEBUG uint16_t readU16u(const void* p) noexcept { return readU16x(p); } +template +static ASMJIT_INLINE_NODEBUG uint16_t readU16uLE(const void* p) noexcept { return readU16x(p); } +template +static ASMJIT_INLINE_NODEBUG uint16_t readU16uBE(const void* p) noexcept { return readU16x(p); } + +static ASMJIT_INLINE_NODEBUG uint16_t readU16a(const void* p) noexcept { return readU16x(p); } +static ASMJIT_INLINE_NODEBUG uint16_t readU16aLE(const void* p) noexcept { return readU16x(p); } +static ASMJIT_INLINE_NODEBUG uint16_t readU16aBE(const void* p) noexcept { return readU16x(p); } + +template +static ASMJIT_INLINE_NODEBUG int16_t readI16x(const void* p) noexcept { return int16_t(readU16x(p)); } + +template +static ASMJIT_INLINE_NODEBUG int16_t readI16u(const void* p) noexcept { return int16_t(readU16x(p)); } +template +static ASMJIT_INLINE_NODEBUG int16_t readI16uLE(const void* p) noexcept { return int16_t(readU16x(p)); } +template +static ASMJIT_INLINE_NODEBUG int16_t readI16uBE(const void* p) noexcept { return int16_t(readU16x(p)); } + +static ASMJIT_INLINE_NODEBUG int16_t readI16a(const void* p) noexcept { return int16_t(readU16x(p)); } +static ASMJIT_INLINE_NODEBUG int16_t readI16aLE(const void* p) noexcept { return int16_t(readU16x(p)); } +static ASMJIT_INLINE_NODEBUG int16_t readI16aBE(const void* p) noexcept { return int16_t(readU16x(p)); } + +// Support - Memory Read Access - 24 Bits +// ====================================== + +template +static inline uint32_t readU24u(const void* p) noexcept { + uint32_t b0 = readU8(static_cast(p) + (BO == ByteOrder::kLE ? 2u : 0u)); + uint32_t b1 = readU8(static_cast(p) + 1u); + uint32_t b2 = readU8(static_cast(p) + (BO == ByteOrder::kLE ? 
0u : 2u)); + return (b0 << 16) | (b1 << 8) | b2; +} + +static inline uint32_t readU24uLE(const void* p) noexcept { return readU24u(p); } +static inline uint32_t readU24uBE(const void* p) noexcept { return readU24u(p); } + +// Support - Memory Read Access - 32 Bits +// ====================================== + +template +static ASMJIT_INLINE_NODEBUG uint32_t readU32x(const void* p) noexcept { + typedef typename Internal::AliasedUInt::T U32AlignedToN; + uint32_t x = static_cast(p)[0]; + return BO == ByteOrder::kNative ? x : byteswap32(x); +} + +template +static ASMJIT_INLINE_NODEBUG uint32_t readU32u(const void* p) noexcept { return readU32x(p); } +template +static ASMJIT_INLINE_NODEBUG uint32_t readU32uLE(const void* p) noexcept { return readU32x(p); } +template +static ASMJIT_INLINE_NODEBUG uint32_t readU32uBE(const void* p) noexcept { return readU32x(p); } + +static ASMJIT_INLINE_NODEBUG uint32_t readU32a(const void* p) noexcept { return readU32x(p); } +static ASMJIT_INLINE_NODEBUG uint32_t readU32aLE(const void* p) noexcept { return readU32x(p); } +static ASMJIT_INLINE_NODEBUG uint32_t readU32aBE(const void* p) noexcept { return readU32x(p); } + +template +static ASMJIT_INLINE_NODEBUG uint32_t readI32x(const void* p) noexcept { return int32_t(readU32x(p)); } + +template +static ASMJIT_INLINE_NODEBUG int32_t readI32u(const void* p) noexcept { return int32_t(readU32x(p)); } +template +static ASMJIT_INLINE_NODEBUG int32_t readI32uLE(const void* p) noexcept { return int32_t(readU32x(p)); } +template +static ASMJIT_INLINE_NODEBUG int32_t readI32uBE(const void* p) noexcept { return int32_t(readU32x(p)); } + +static ASMJIT_INLINE_NODEBUG int32_t readI32a(const void* p) noexcept { return int32_t(readU32x(p)); } +static ASMJIT_INLINE_NODEBUG int32_t readI32aLE(const void* p) noexcept { return int32_t(readU32x(p)); } +static ASMJIT_INLINE_NODEBUG int32_t readI32aBE(const void* p) noexcept { return int32_t(readU32x(p)); } + +// Support - Memory Read Access - 64 Bits +// 
====================================== + +template +static ASMJIT_INLINE_NODEBUG uint64_t readU64x(const void* p) noexcept { + typedef typename Internal::AliasedUInt::T U64AlignedToN; + uint64_t x = static_cast(p)[0]; + return BO == ByteOrder::kNative ? x : byteswap64(x); +} + +template +static ASMJIT_INLINE_NODEBUG uint64_t readU64u(const void* p) noexcept { return readU64x(p); } +template +static ASMJIT_INLINE_NODEBUG uint64_t readU64uLE(const void* p) noexcept { return readU64x(p); } +template +static ASMJIT_INLINE_NODEBUG uint64_t readU64uBE(const void* p) noexcept { return readU64x(p); } + +static ASMJIT_INLINE_NODEBUG uint64_t readU64a(const void* p) noexcept { return readU64x(p); } +static ASMJIT_INLINE_NODEBUG uint64_t readU64aLE(const void* p) noexcept { return readU64x(p); } +static ASMJIT_INLINE_NODEBUG uint64_t readU64aBE(const void* p) noexcept { return readU64x(p); } + +template +static ASMJIT_INLINE_NODEBUG int64_t readI64x(const void* p) noexcept { return int64_t(readU64x(p)); } + +template +static ASMJIT_INLINE_NODEBUG int64_t readI64u(const void* p) noexcept { return int64_t(readU64x(p)); } +template +static ASMJIT_INLINE_NODEBUG int64_t readI64uLE(const void* p) noexcept { return int64_t(readU64x(p)); } +template +static ASMJIT_INLINE_NODEBUG int64_t readI64uBE(const void* p) noexcept { return int64_t(readU64x(p)); } + +static ASMJIT_INLINE_NODEBUG int64_t readI64a(const void* p) noexcept { return int64_t(readU64x(p)); } +static ASMJIT_INLINE_NODEBUG int64_t readI64aLE(const void* p) noexcept { return int64_t(readU64x(p)); } +static ASMJIT_INLINE_NODEBUG int64_t readI64aBE(const void* p) noexcept { return int64_t(readU64x(p)); } + +// Support - Memory Write Access - 8 Bits +// ====================================== + +static ASMJIT_INLINE_NODEBUG void writeU8(void* p, uint8_t x) noexcept { static_cast(p)[0] = x; } +static ASMJIT_INLINE_NODEBUG void writeI8(void* p, int8_t x) noexcept { static_cast(p)[0] = x; } + +// Support - Memory Write Access 
- 16 Bits +// ======================================= + +template +static ASMJIT_INLINE_NODEBUG void writeU16x(void* p, uint16_t x) noexcept { + typedef typename Internal::AliasedUInt::T U16AlignedToN; + static_cast(p)[0] = BO == ByteOrder::kNative ? x : byteswap16(x); +} + +template +static ASMJIT_INLINE_NODEBUG void writeU16uLE(void* p, uint16_t x) noexcept { writeU16x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU16uBE(void* p, uint16_t x) noexcept { writeU16x(p, x); } + +static ASMJIT_INLINE_NODEBUG void writeU16a(void* p, uint16_t x) noexcept { writeU16x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU16aLE(void* p, uint16_t x) noexcept { writeU16x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU16aBE(void* p, uint16_t x) noexcept { writeU16x(p, x); } + + +template +static ASMJIT_INLINE_NODEBUG void writeI16x(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } + +template +static ASMJIT_INLINE_NODEBUG void writeI16uLE(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI16uBE(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } + +static ASMJIT_INLINE_NODEBUG void writeI16a(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI16aLE(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI16aBE(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } + +// Support - Memory Write Access - 24 Bits +// ======================================= + +template +static inline void writeU24u(void* p, uint32_t v) noexcept { + static_cast(p)[0] = uint8_t((v >> (BO == ByteOrder::kLE ? 0 : 16)) & 0xFFu); + static_cast(p)[1] = uint8_t((v >> 8) & 0xFFu); + static_cast(p)[2] = uint8_t((v >> (BO == ByteOrder::kLE ? 
16 : 0)) & 0xFFu); +} + +static inline void writeU24uLE(void* p, uint32_t v) noexcept { writeU24u(p, v); } +static inline void writeU24uBE(void* p, uint32_t v) noexcept { writeU24u(p, v); } + +// Support - Memory Write Access - 32 Bits +// ======================================= + +template +static ASMJIT_INLINE_NODEBUG void writeU32x(void* p, uint32_t x) noexcept { + typedef typename Internal::AliasedUInt::T U32AlignedToN; + static_cast(p)[0] = (BO == ByteOrder::kNative) ? x : Support::byteswap32(x); +} + +template +static ASMJIT_INLINE_NODEBUG void writeU32u(void* p, uint32_t x) noexcept { writeU32x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU32uLE(void* p, uint32_t x) noexcept { writeU32x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU32uBE(void* p, uint32_t x) noexcept { writeU32x(p, x); } + +static ASMJIT_INLINE_NODEBUG void writeU32a(void* p, uint32_t x) noexcept { writeU32x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU32aLE(void* p, uint32_t x) noexcept { writeU32x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU32aBE(void* p, uint32_t x) noexcept { writeU32x(p, x); } + +template +static ASMJIT_INLINE_NODEBUG void writeI32x(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } + +template +static ASMJIT_INLINE_NODEBUG void writeI32u(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI32uLE(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI32uBE(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } + +static ASMJIT_INLINE_NODEBUG void writeI32a(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI32aLE(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI32aBE(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } + +// Support - Memory Write Access - 64 Bits +// 
======================================= + +template +static ASMJIT_INLINE_NODEBUG void writeU64x(void* p, uint64_t x) noexcept { + typedef typename Internal::AliasedUInt::T U64AlignedToN; + static_cast(p)[0] = BO == ByteOrder::kNative ? x : byteswap64(x); +} + +template +static ASMJIT_INLINE_NODEBUG void writeU64u(void* p, uint64_t x) noexcept { writeU64x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU64uLE(void* p, uint64_t x) noexcept { writeU64x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU64uBE(void* p, uint64_t x) noexcept { writeU64x(p, x); } + +static ASMJIT_INLINE_NODEBUG void writeU64a(void* p, uint64_t x) noexcept { writeU64x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU64aLE(void* p, uint64_t x) noexcept { writeU64x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU64aBE(void* p, uint64_t x) noexcept { writeU64x(p, x); } + +template +static ASMJIT_INLINE_NODEBUG void writeI64x(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } + +template +static ASMJIT_INLINE_NODEBUG void writeI64u(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI64uLE(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI64uBE(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } + +static ASMJIT_INLINE_NODEBUG void writeI64a(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI64aLE(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI64aBE(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } + +// Support - Operators +// =================== + +//! 
\cond INTERNAL +struct Set { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { DebugUtils::unused(x); return y; } }; +struct SetNot { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { DebugUtils::unused(x); return ~y; } }; +struct And { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x & y; } }; +struct AndNot { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x & ~y; } }; +struct NotAnd { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return ~x & y; } }; +struct Or { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x | y; } }; +struct Xor { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x ^ y; } }; +struct Add { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x + y; } }; +struct Sub { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x - y; } }; +struct Min { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return min(x, y); } }; +struct Max { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return max(x, y); } }; +//! \endcond + +// Support - BitWordIterator +// ========================= + +//! Iterates over each bit in a number which is set to 1. +//! +//! Example of use: +//! +//! ``` +//! uint32_t bitsToIterate = 0x110F; +//! Support::BitWordIterator it(bitsToIterate); +//! +//! while (it.hasNext()) { +//! uint32_t bitIndex = it.next(); +//! std::printf("Bit at %u is set\n", unsigned(bitIndex)); +//! } +//! 
``` +template +class BitWordIterator { +public: + ASMJIT_INLINE_NODEBUG explicit BitWordIterator(T bitWord) noexcept + : _bitWord(bitWord) {} + + ASMJIT_INLINE_NODEBUG void init(T bitWord) noexcept { _bitWord = bitWord; } + ASMJIT_INLINE_NODEBUG bool hasNext() const noexcept { return _bitWord != 0; } + + ASMJIT_FORCE_INLINE uint32_t next() noexcept { + ASMJIT_ASSERT(_bitWord != 0); + uint32_t index = ctz(_bitWord); + _bitWord &= T(_bitWord - 1); + return index; + } + + T _bitWord; +}; + +// Support - BitVectorOps +// ====================== + +//! \cond +namespace Internal { + template + static ASMJIT_FORCE_INLINE void bitVectorOp(T* buf, size_t index, size_t count) noexcept { + if (count == 0) + return; + + const size_t kTSizeInBits = bitSizeOf(); + size_t vecIndex = index / kTSizeInBits; // T[] + size_t bitIndex = index % kTSizeInBits; // T[][] + + buf += vecIndex; + + // The first BitWord requires special handling to preserve bits outside the fill region. + const T kFillMask = allOnes(); + size_t firstNBits = min(kTSizeInBits - bitIndex, count); + + buf[0] = OperatorT::op(buf[0], (kFillMask >> (kTSizeInBits - firstNBits)) << bitIndex); + buf++; + count -= firstNBits; + + // All bits between the first and last affected BitWords can be just filled. + while (count >= kTSizeInBits) { + buf[0] = FullWordOpT::op(buf[0], kFillMask); + buf++; + count -= kTSizeInBits; + } + + // The last BitWord requires special handling as well + if (count) + buf[0] = OperatorT::op(buf[0], kFillMask >> (kTSizeInBits - count)); + } +} +//! \endcond + +//! Sets bit in a bit-vector `buf` at `index`. +template +static ASMJIT_INLINE_NODEBUG bool bitVectorGetBit(T* buf, size_t index) noexcept { + const size_t kTSizeInBits = bitSizeOf(); + + size_t vecIndex = index / kTSizeInBits; + size_t bitIndex = index % kTSizeInBits; + + return bool((buf[vecIndex] >> bitIndex) & 0x1u); +} + +//! Sets bit in a bit-vector `buf` at `index` to `value`. 
+template +static ASMJIT_INLINE_NODEBUG void bitVectorSetBit(T* buf, size_t index, bool value) noexcept { + const size_t kTSizeInBits = bitSizeOf(); + + size_t vecIndex = index / kTSizeInBits; + size_t bitIndex = index % kTSizeInBits; + + T bitMask = T(1u) << bitIndex; + if (value) + buf[vecIndex] |= bitMask; + else + buf[vecIndex] &= ~bitMask; +} + +//! Sets bit in a bit-vector `buf` at `index` to `value`. +template +static ASMJIT_INLINE_NODEBUG void bitVectorFlipBit(T* buf, size_t index) noexcept { + const size_t kTSizeInBits = bitSizeOf(); + + size_t vecIndex = index / kTSizeInBits; + size_t bitIndex = index % kTSizeInBits; + + T bitMask = T(1u) << bitIndex; + buf[vecIndex] ^= bitMask; +} + +//! Fills `count` bits in bit-vector `buf` starting at bit-index `index`. +template +static ASMJIT_INLINE_NODEBUG void bitVectorFill(T* buf, size_t index, size_t count) noexcept { Internal::bitVectorOp(buf, index, count); } + +//! Clears `count` bits in bit-vector `buf` starting at bit-index `index`. +template +static ASMJIT_INLINE_NODEBUG void bitVectorClear(T* buf, size_t index, size_t count) noexcept { Internal::bitVectorOp(buf, index, count); } + +template +static ASMJIT_FORCE_INLINE size_t bitVectorIndexOf(T* buf, size_t start, bool value) noexcept { + const size_t kTSizeInBits = bitSizeOf(); + size_t vecIndex = start / kTSizeInBits; // T[] + size_t bitIndex = start % kTSizeInBits; // T[][] + + T* p = buf + vecIndex; + + // We always look for zeros, if value is `true` we have to flip all bits before the search. + const T kFillMask = allOnes(); + const T kFlipMask = value ? T(0) : kFillMask; + + // The first BitWord requires special handling as there are some bits we want to ignore. 
+ T bits = (*p ^ kFlipMask) & (kFillMask << bitIndex); + for (;;) { + if (bits) + return (size_t)(p - buf) * kTSizeInBits + ctz(bits); + bits = *++p ^ kFlipMask; + } +} + +// Support - BitVectorIterator +// =========================== + +template +class BitVectorIterator { +public: + const T* _ptr; + size_t _idx; + size_t _end; + T _current; + + ASMJIT_INLINE_NODEBUG BitVectorIterator(const BitVectorIterator& other) noexcept = default; + + ASMJIT_INLINE_NODEBUG BitVectorIterator(const T* data, size_t numBitWords, size_t start = 0) noexcept { + init(data, numBitWords, start); + } + + ASMJIT_FORCE_INLINE void init(const T* data, size_t numBitWords, size_t start = 0) noexcept { + const T* ptr = data + (start / bitSizeOf()); + size_t idx = alignDown(start, bitSizeOf()); + size_t end = numBitWords * bitSizeOf(); + + T bitWord = T(0); + if (idx < end) { + bitWord = *ptr++ & (allOnes() << (start % bitSizeOf())); + while (!bitWord && (idx += bitSizeOf()) < end) + bitWord = *ptr++; + } + + _ptr = ptr; + _idx = idx; + _end = end; + _current = bitWord; + } + + ASMJIT_INLINE_NODEBUG bool hasNext() const noexcept { + return _current != T(0); + } + + ASMJIT_FORCE_INLINE size_t next() noexcept { + T bitWord = _current; + ASMJIT_ASSERT(bitWord != T(0)); + + uint32_t bit = ctz(bitWord); + bitWord &= T(bitWord - 1u); + + size_t n = _idx + bit; + while (!bitWord && (_idx += bitSizeOf()) < _end) + bitWord = *_ptr++; + + _current = bitWord; + return n; + } + + ASMJIT_FORCE_INLINE size_t peekNext() const noexcept { + ASMJIT_ASSERT(_current != T(0)); + return _idx + ctz(_current); + } +}; + +// Support - BitVectorOpIterator +// ============================= + +template +class BitVectorOpIterator { +public: + enum : uint32_t { + kTSizeInBits = bitSizeOf() + }; + + const T* _aPtr; + const T* _bPtr; + size_t _idx; + size_t _end; + T _current; + + ASMJIT_INLINE_NODEBUG BitVectorOpIterator(const T* aData, const T* bData, size_t numBitWords, size_t start = 0) noexcept { + init(aData, bData, 
numBitWords, start); + } + + ASMJIT_FORCE_INLINE void init(const T* aData, const T* bData, size_t numBitWords, size_t start = 0) noexcept { + const T* aPtr = aData + (start / bitSizeOf()); + const T* bPtr = bData + (start / bitSizeOf()); + size_t idx = alignDown(start, bitSizeOf()); + size_t end = numBitWords * bitSizeOf(); + + T bitWord = T(0); + if (idx < end) { + bitWord = OperatorT::op(*aPtr++, *bPtr++) & (allOnes() << (start % bitSizeOf())); + while (!bitWord && (idx += kTSizeInBits) < end) + bitWord = OperatorT::op(*aPtr++, *bPtr++); + } + + _aPtr = aPtr; + _bPtr = bPtr; + _idx = idx; + _end = end; + _current = bitWord; + } + + ASMJIT_INLINE_NODEBUG bool hasNext() noexcept { + return _current != T(0); + } + + ASMJIT_FORCE_INLINE size_t next() noexcept { + T bitWord = _current; + ASMJIT_ASSERT(bitWord != T(0)); + + uint32_t bit = ctz(bitWord); + bitWord &= T(bitWord - 1u); + + size_t n = _idx + bit; + while (!bitWord && (_idx += kTSizeInBits) < _end) + bitWord = OperatorT::op(*_aPtr++, *_bPtr++); + + _current = bitWord; + return n; + } +}; + +// Support - Sorting +// ================= + +//! Sort order. +enum class SortOrder : uint32_t { + //!< Ascending order. + kAscending = 0, + //!< Descending order. + kDescending = 1 +}; + +//! A helper class that provides comparison of any user-defined type that +//! implements `<` and `>` operators (primitive types are supported as well). +template +struct Compare { + template + ASMJIT_INLINE_NODEBUG int operator()(const A& a, const B& b) const noexcept { + return kOrder == SortOrder::kAscending ? int(a > b) - int(a < b) : int(a < b) - int(a > b); + } +}; + +//! Insertion sort. +template> +static inline void iSort(T* base, size_t size, const CompareT& cmp = CompareT()) noexcept { + for (T* pm = base + 1; pm < base + size; pm++) + for (T* pl = pm; pl > base && cmp(pl[-1], pl[0]) > 0; pl--) + std::swap(pl[-1], pl[0]); +} + +//! \cond +namespace Internal { + //! Quick-sort implementation. 
+ template + struct QSortImpl { + enum : size_t { + kStackSize = 64 * 2, + kISortThreshold = 7 + }; + + // Based on "PDCLib - Public Domain C Library" and rewritten to C++. + static void sort(T* base, size_t size, const CompareT& cmp) noexcept { + T* end = base + size; + T* stack[kStackSize]; + T** stackptr = stack; + + for (;;) { + if ((size_t)(end - base) > kISortThreshold) { + // We work from second to last - first will be pivot element. + T* pi = base + 1; + T* pj = end - 1; + std::swap(base[(size_t)(end - base) / 2], base[0]); + + if (cmp(*pi , *pj ) > 0) std::swap(*pi , *pj ); + if (cmp(*base, *pj ) > 0) std::swap(*base, *pj ); + if (cmp(*pi , *base) > 0) std::swap(*pi , *base); + + // Now we have the median for pivot element, entering main loop. + for (;;) { + while (pi < pj && cmp(*++pi, *base) < 0) continue; // Move `i` right until `*i >= pivot`. + while (pj > base && cmp(*--pj, *base) > 0) continue; // Move `j` left until `*j <= pivot`. + + if (pi > pj) break; + std::swap(*pi, *pj); + } + + // Move pivot into correct place. + std::swap(*base, *pj); + + // Larger subfile base / end to stack, sort smaller. + if (pj - base > end - pi) { + // Left is larger. + *stackptr++ = base; + *stackptr++ = pj; + base = pi; + } + else { + // Right is larger. + *stackptr++ = pi; + *stackptr++ = end; + end = pj; + } + ASMJIT_ASSERT(stackptr <= stack + kStackSize); + } + else { + // UB sanitizer doesn't like applying offset to a nullptr base. + if (base != end) + iSort(base, (size_t)(end - base), cmp); + + if (stackptr == stack) + break; + + end = *--stackptr; + base = *--stackptr; + } + } + } + }; +} +//! \endcond + +//! Quick sort implementation. +//! +//! The main reason to provide a custom qsort implementation is that we needed something that will +//! never throw `bad_alloc` exception. This implementation doesn't use dynamic memory allocation. 
+template> +static ASMJIT_INLINE_NODEBUG void qSort(T* base, size_t size, const CompareT& cmp = CompareT()) noexcept { + Internal::QSortImpl::sort(base, size, cmp); +} + +// Support - ReverseIterator +// ========================= + +//! Reverse iterator to avoid including `` header for iteration over arrays, specialized for +//! AsmJit use (noexcept by design). +template +class ArrayReverseIterator { +public: + //! \name Members + //! \{ + + T* _ptr {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG constexpr ArrayReverseIterator() noexcept = default; + ASMJIT_INLINE_NODEBUG constexpr ArrayReverseIterator(const ArrayReverseIterator& other) noexcept = default; + ASMJIT_INLINE_NODEBUG constexpr ArrayReverseIterator(T* ptr) noexcept : _ptr(ptr) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator=(const ArrayReverseIterator& other) noexcept = default; + + ASMJIT_INLINE_NODEBUG bool operator==(const T* other) const noexcept { return _ptr == other; } + ASMJIT_INLINE_NODEBUG bool operator==(const ArrayReverseIterator& other) const noexcept { return _ptr == other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator!=(const T* other) const noexcept { return _ptr != other; } + ASMJIT_INLINE_NODEBUG bool operator!=(const ArrayReverseIterator& other) const noexcept { return _ptr != other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator<(const T* other) const noexcept { return _ptr < other; } + ASMJIT_INLINE_NODEBUG bool operator<(const ArrayReverseIterator& other) const noexcept { return _ptr < other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator<=(const T* other) const noexcept { return _ptr <= other; } + ASMJIT_INLINE_NODEBUG bool operator<=(const ArrayReverseIterator& other) const noexcept { return _ptr <= other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator>(const T* other) const noexcept { return _ptr > other; } + ASMJIT_INLINE_NODEBUG bool operator>(const 
ArrayReverseIterator& other) const noexcept { return _ptr > other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator>=(const T* other) const noexcept { return _ptr >= other; } + ASMJIT_INLINE_NODEBUG bool operator>=(const ArrayReverseIterator& other) const noexcept { return _ptr >= other._ptr; } + + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator++() noexcept { _ptr--; return *this; } + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator++(int) noexcept { ArrayReverseIterator prev(*this); _ptr--; return prev; } + + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator--() noexcept { _ptr++; return *this; } + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator--(int) noexcept { ArrayReverseIterator prev(*this); _ptr++; return prev; } + + template ASMJIT_INLINE_NODEBUG ArrayReverseIterator operator+(const Diff& n) noexcept { return ArrayReverseIterator(_ptr -= n); } + template ASMJIT_INLINE_NODEBUG ArrayReverseIterator operator-(const Diff& n) noexcept { return ArrayReverseIterator(_ptr += n); } + + template ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator+=(const Diff& n) noexcept { _ptr -= n; return *this; } + template ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator-=(const Diff& n) noexcept { _ptr += n; return *this; } + + ASMJIT_INLINE_NODEBUG constexpr T& operator*() const noexcept { return _ptr[-1]; } + ASMJIT_INLINE_NODEBUG constexpr T* operator->() const noexcept { return &_ptr[-1]; } + + template ASMJIT_INLINE_NODEBUG T& operator[](const Diff& n) noexcept { return *(_ptr - n - 1); } + + ASMJIT_INLINE_NODEBUG operator T*() const noexcept { return _ptr; } + + //! \} +}; + +// Support - Array +// =============== + +//! Array type, similar to std::array, with the possibility to use enums in operator[]. +//! +//! \note The array has C semantics - the elements in the array are not initialized. +template +struct Array { + //! \name Members + //! \{ + + //! The underlying array data, use \ref data() to access it. + T _data[N]; + + //! \} + + //! 
\cond + // std compatibility. + typedef T value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + typedef value_type& reference; + typedef const value_type& const_reference; + + typedef value_type* pointer; + typedef const value_type* const_pointer; + + typedef pointer iterator; + typedef const_pointer const_iterator; + //! \endcond + + //! \name Overloaded Operators + //! \{ + + template + inline T& operator[](const Index& index) noexcept { + typedef typename Internal::StdInt::Type U; + ASMJIT_ASSERT(U(index) < N); + return _data[U(index)]; + } + + template + inline const T& operator[](const Index& index) const noexcept { + typedef typename Internal::StdInt::Type U; + ASMJIT_ASSERT(U(index) < N); + return _data[U(index)]; + } + + inline bool operator==(const Array& other) const noexcept { + for (size_t i = 0; i < N; i++) + if (_data[i] != other._data[i]) + return false; + return true; + } + + inline bool operator!=(const Array& other) const noexcept { + return !operator==(other); + } + + //! \} + + //! \name Accessors + //! 
\{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return false; } + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return N; } + + ASMJIT_INLINE_NODEBUG T* data() noexcept { return _data; } + ASMJIT_INLINE_NODEBUG const T* data() const noexcept { return _data; } + + ASMJIT_INLINE_NODEBUG T& front() noexcept { return _data[0]; } + ASMJIT_INLINE_NODEBUG const T& front() const noexcept { return _data[0]; } + + ASMJIT_INLINE_NODEBUG T& back() noexcept { return _data[N - 1]; } + ASMJIT_INLINE_NODEBUG const T& back() const noexcept { return _data[N - 1]; } + + ASMJIT_INLINE_NODEBUG T* begin() noexcept { return _data; } + ASMJIT_INLINE_NODEBUG T* end() noexcept { return _data + N; } + + ASMJIT_INLINE_NODEBUG const T* begin() const noexcept { return _data; } + ASMJIT_INLINE_NODEBUG const T* end() const noexcept { return _data + N; } + + ASMJIT_INLINE_NODEBUG const T* cbegin() const noexcept { return _data; } + ASMJIT_INLINE_NODEBUG const T* cend() const noexcept { return _data + N; } + + //! \} + + //! \name Utilities + //! \{ + + inline void swap(Array& other) noexcept { + for (size_t i = 0; i < N; i++) + std::swap(_data[i], other._data[i]); + } + + inline void fill(const T& value) noexcept { + for (size_t i = 0; i < N; i++) + _data[i] = value; + } + + inline void copyFrom(const Array& other) noexcept { + for (size_t i = 0; i < N; i++) + _data[i] = other._data[i]; + } + + template + inline void combine(const Array& other) noexcept { + for (size_t i = 0; i < N; i++) + _data[i] = Operator::op(_data[i], other._data[i]); + } + + template + inline T aggregate(T initialValue = T()) const noexcept { + T value = initialValue; + for (size_t i = 0; i < N; i++) + value = Operator::op(value, _data[i]); + return value; + } + + template + inline void forEach(Fn&& fn) noexcept { + for (size_t i = 0; i < N; i++) + fn(_data[i]); + } + //! \} +}; + +// Support::Temporary +// ================== + +//! Used to pass a temporary buffer to: +//! +//! 
- Containers that use user-passed buffer as an initial storage (still can grow). +//! - Zone allocator that would use the temporary buffer as a first block. +struct Temporary { + //! \name Members + //! \{ + + void* _data; + size_t _size; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG constexpr Temporary(const Temporary& other) noexcept = default; + ASMJIT_INLINE_NODEBUG constexpr Temporary(void* data, size_t size) noexcept + : _data(data), + _size(size) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Temporary& operator=(const Temporary& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the data storage. + template + ASMJIT_INLINE_NODEBUG constexpr T* data() const noexcept { return static_cast(_data); } + //! Returns the data storage size in bytes. + ASMJIT_INLINE_NODEBUG constexpr size_t size() const noexcept { return _size; } + + //! \} +}; + +} // {Support} + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_SUPPORT_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/target.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/target.h new file mode 100644 index 0000000000000000000000000000000000000000..4365be1401169e563c9216b8e9595649fd5a811f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/target.h @@ -0,0 +1,59 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_TARGET_H_INCLUDED +#define ASMJIT_CORE_TARGET_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/cpuinfo.h" +#include "../core/func.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Target is an abstract class that describes a machine code target. +class ASMJIT_VIRTAPI Target { +public: + ASMJIT_BASE_CLASS(Target) + ASMJIT_NONCOPYABLE(Target) + + //! 
Target environment information. + Environment _environment; + //! Target CPU features. + CpuFeatures _cpuFeatures; + + //! \name Construction & Destruction + //! \{ + + //! Creates a `Target` instance. + ASMJIT_API Target() noexcept; + //! Destroys the `Target` instance. + ASMJIT_API virtual ~Target() noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns target's environment. + ASMJIT_INLINE_NODEBUG const Environment& environment() const noexcept { return _environment; } + //! Returns the target architecture. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _environment.arch(); } + //! Returns the target sub-architecture. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return _environment.subArch(); } + + //! Returns target CPU features. + ASMJIT_INLINE_NODEBUG const CpuFeatures& cpuFeatures() const noexcept { return _cpuFeatures; } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_TARGET_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/type.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/type.h new file mode 100644 index 0000000000000000000000000000000000000000..415fe0a421dcda55cc3216ab26b33f5080c87104 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/type.h @@ -0,0 +1,443 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_TYPE_H_INCLUDED +#define ASMJIT_CORE_TYPE_H_INCLUDED + +#include "../core/globals.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Type identifier provides a minimalist type system used across AsmJit library. +//! +//! This is an additional information that can be used to describe a value-type of physical or virtual register. It's +//! 
used mostly by BaseCompiler to describe register representation (the group of data stored in the register and the +//! width used) and it's also used by APIs that allow to describe and work with function signatures. +enum class TypeId : uint8_t { + //! Void type. + kVoid = 0, + + _kBaseStart = 32, + _kBaseEnd = 44, + + _kIntStart = 32, + _kIntEnd = 41, + + //! Abstract signed integer type that has a native size. + kIntPtr = 32, + //! Abstract unsigned integer type that has a native size. + kUIntPtr = 33, + + //! 8-bit signed integer type. + kInt8 = 34, + //! 8-bit unsigned integer type. + kUInt8 = 35, + //! 16-bit signed integer type. + kInt16 = 36, + //! 16-bit unsigned integer type. + kUInt16 = 37, + //! 32-bit signed integer type. + kInt32 = 38, + //! 32-bit unsigned integer type. + kUInt32 = 39, + //! 64-bit signed integer type. + kInt64 = 40, + //! 64-bit unsigned integer type. + kUInt64 = 41, + + _kFloatStart = 42, + _kFloatEnd = 44, + + //! 32-bit floating point type. + kFloat32 = 42, + //! 64-bit floating point type. + kFloat64 = 43, + //! 80-bit floating point type. + kFloat80 = 44, + + _kMaskStart = 45, + _kMaskEnd = 48, + + //! 8-bit opmask register (K). + kMask8 = 45, + //! 16-bit opmask register (K). + kMask16 = 46, + //! 32-bit opmask register (K). + kMask32 = 47, + //! 64-bit opmask register (K). + kMask64 = 48, + + _kMmxStart = 49, + _kMmxEnd = 50, + + //! 64-bit MMX register only used for 32 bits. + kMmx32 = 49, + //! 64-bit MMX register. 
+ kMmx64 = 50, + + _kVec32Start = 51, + _kVec32End = 60, + + kInt8x4 = 51, + kUInt8x4 = 52, + kInt16x2 = 53, + kUInt16x2 = 54, + kInt32x1 = 55, + kUInt32x1 = 56, + kFloat32x1 = 59, + + _kVec64Start = 61, + _kVec64End = 70, + + kInt8x8 = 61, + kUInt8x8 = 62, + kInt16x4 = 63, + kUInt16x4 = 64, + kInt32x2 = 65, + kUInt32x2 = 66, + kInt64x1 = 67, + kUInt64x1 = 68, + kFloat32x2 = 69, + kFloat64x1 = 70, + + _kVec128Start = 71, + _kVec128End = 80, + + kInt8x16 = 71, + kUInt8x16 = 72, + kInt16x8 = 73, + kUInt16x8 = 74, + kInt32x4 = 75, + kUInt32x4 = 76, + kInt64x2 = 77, + kUInt64x2 = 78, + kFloat32x4 = 79, + kFloat64x2 = 80, + + _kVec256Start = 81, + _kVec256End = 90, + + kInt8x32 = 81, + kUInt8x32 = 82, + kInt16x16 = 83, + kUInt16x16 = 84, + kInt32x8 = 85, + kUInt32x8 = 86, + kInt64x4 = 87, + kUInt64x4 = 88, + kFloat32x8 = 89, + kFloat64x4 = 90, + + _kVec512Start = 91, + _kVec512End = 100, + + kInt8x64 = 91, + kUInt8x64 = 92, + kInt16x32 = 93, + kUInt16x32 = 94, + kInt32x16 = 95, + kUInt32x16 = 96, + kInt64x8 = 97, + kUInt64x8 = 98, + kFloat32x16 = 99, + kFloat64x8 = 100, + + kLastAssigned = kFloat64x8, + + kMaxValue = 255 +}; +ASMJIT_DEFINE_ENUM_COMPARE(TypeId) + +//! Type identifier utilities. +namespace TypeUtils { + +struct TypeData { + TypeId scalarOf[uint32_t(TypeId::kMaxValue) + 1]; + uint8_t sizeOf[uint32_t(TypeId::kMaxValue) + 1]; +}; +ASMJIT_VARAPI const TypeData _typeData; + +//! Returns the scalar type of `typeId`. +static ASMJIT_INLINE_NODEBUG TypeId scalarOf(TypeId typeId) noexcept { return _typeData.scalarOf[uint32_t(typeId)]; } + +//! Returns the size [in bytes] of `typeId`. +static ASMJIT_INLINE_NODEBUG uint32_t sizeOf(TypeId typeId) noexcept { return _typeData.sizeOf[uint32_t(typeId)]; } + +//! Tests whether a given type `typeId` is between `a` and `b`. +static ASMJIT_INLINE_NODEBUG constexpr bool isBetween(TypeId typeId, TypeId a, TypeId b) noexcept { + return Support::isBetween(uint32_t(typeId), uint32_t(a), uint32_t(b)); +} + +//! 
Tests whether a given type `typeId` is \ref TypeId::kVoid. +static ASMJIT_INLINE_NODEBUG constexpr bool isVoid(TypeId typeId) noexcept { return typeId == TypeId::kVoid; } +//! Tests whether a given type `typeId` is a valid non-void type. +static ASMJIT_INLINE_NODEBUG constexpr bool isValid(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kIntStart, TypeId::_kVec512End); } +//! Tests whether a given type `typeId` is scalar (has no vector part). +static ASMJIT_INLINE_NODEBUG constexpr bool isScalar(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kBaseStart, TypeId::_kBaseEnd); } +//! Tests whether a given type `typeId` is abstract, which means that its size depends on register size. +static ASMJIT_INLINE_NODEBUG constexpr bool isAbstract(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kIntPtr, TypeId::kUIntPtr); } + +//! Tests whether a given type is a scalar integer (signed or unsigned) of any size. +static ASMJIT_INLINE_NODEBUG constexpr bool isInt(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kIntStart, TypeId::_kIntEnd); } +//! Tests whether a given type is a scalar 8-bit integer (signed). +static ASMJIT_INLINE_NODEBUG constexpr bool isInt8(TypeId typeId) noexcept { return typeId == TypeId::kInt8; } +//! Tests whether a given type is a scalar 8-bit integer (unsigned). +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt8(TypeId typeId) noexcept { return typeId == TypeId::kUInt8; } +//! Tests whether a given type is a scalar 16-bit integer (signed). +static ASMJIT_INLINE_NODEBUG constexpr bool isInt16(TypeId typeId) noexcept { return typeId == TypeId::kInt16; } +//! Tests whether a given type is a scalar 16-bit integer (unsigned). +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt16(TypeId typeId) noexcept { return typeId == TypeId::kUInt16; } +//! Tests whether a given type is a scalar 32-bit integer (signed). 
+static ASMJIT_INLINE_NODEBUG constexpr bool isInt32(TypeId typeId) noexcept { return typeId == TypeId::kInt32; } +//! Tests whether a given type is a scalar 32-bit integer (unsigned). +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt32(TypeId typeId) noexcept { return typeId == TypeId::kUInt32; } +//! Tests whether a given type is a scalar 64-bit integer (signed). +static ASMJIT_INLINE_NODEBUG constexpr bool isInt64(TypeId typeId) noexcept { return typeId == TypeId::kInt64; } +//! Tests whether a given type is a scalar 64-bit integer (unsigned). +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt64(TypeId typeId) noexcept { return typeId == TypeId::kUInt64; } + +//! Tests whether a given type is an 8-bit general purpose register representing either signed or unsigned 8-bit integer. +static ASMJIT_INLINE_NODEBUG constexpr bool isGp8(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kInt8, TypeId::kUInt8); } +//! Tests whether a given type is a 16-bit general purpose register representing either signed or unsigned 16-bit integer +static ASMJIT_INLINE_NODEBUG constexpr bool isGp16(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kInt16, TypeId::kUInt16); } +//! Tests whether a given type is a 32-bit general purpose register representing either signed or unsigned 32-bit integer +static ASMJIT_INLINE_NODEBUG constexpr bool isGp32(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kInt32, TypeId::kUInt32); } +//! Tests whether a given type is a 64-bit general purpose register representing either signed or unsigned 64-bit integer +static ASMJIT_INLINE_NODEBUG constexpr bool isGp64(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kInt64, TypeId::kUInt64); } + +//! Tests whether a given type is a scalar floating point of any size. +static ASMJIT_INLINE_NODEBUG constexpr bool isFloat(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kFloatStart, TypeId::_kFloatEnd); } +//! Tests whether a given type is a scalar 32-bit float. 
+static ASMJIT_INLINE_NODEBUG constexpr bool isFloat32(TypeId typeId) noexcept { return typeId == TypeId::kFloat32; } +//! Tests whether a given type is a scalar 64-bit float. +static ASMJIT_INLINE_NODEBUG constexpr bool isFloat64(TypeId typeId) noexcept { return typeId == TypeId::kFloat64; } +//! Tests whether a given type is a scalar 80-bit float. +static ASMJIT_INLINE_NODEBUG constexpr bool isFloat80(TypeId typeId) noexcept { return typeId == TypeId::kFloat80; } + +//! Tests whether a given type is a mask register of any size. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kMaskStart, TypeId::_kMaskEnd); } +//! Tests whether a given type is an 8-bit mask register. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask8(TypeId typeId) noexcept { return typeId == TypeId::kMask8; } +//! Tests whether a given type is an 16-bit mask register. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask16(TypeId typeId) noexcept { return typeId == TypeId::kMask16; } +//! Tests whether a given type is an 32-bit mask register. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask32(TypeId typeId) noexcept { return typeId == TypeId::kMask32; } +//! Tests whether a given type is an 64-bit mask register. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask64(TypeId typeId) noexcept { return typeId == TypeId::kMask64; } + +//! Tests whether a given type is an MMX register. +//! +//! \note MMX functionality is in general deprecated on X86 architecture. AsmJit provides it just for completeness. +static ASMJIT_INLINE_NODEBUG constexpr bool isMmx(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kMmxStart, TypeId::_kMmxEnd); } +//! Tests whether a given type is an MMX register, which only uses the low 32 bits of data (only specific cases). +//! +//! \note MMX functionality is in general deprecated on X86 architecture. AsmJit provides it just for completeness. 
+static ASMJIT_INLINE_NODEBUG constexpr bool isMmx32(TypeId typeId) noexcept { return typeId == TypeId::kMmx32; } +//! Tests whether a given type is an MMX register, which uses 64 bits of data (default). +//! +//! \note MMX functionality is in general deprecated on X86 architecture. AsmJit provides it just for completeness. +static ASMJIT_INLINE_NODEBUG constexpr bool isMmx64(TypeId typeId) noexcept { return typeId == TypeId::kMmx64; } + +//! Tests whether a given type is a vector register of any size. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec32Start, TypeId::_kVec512End); } +//! Tests whether a given type is a 32-bit or 32-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec32(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec32Start, TypeId::_kVec32End); } +//! Tests whether a given type is a 64-bit or 64-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec64(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec64Start, TypeId::_kVec64End); } +//! Tests whether a given type is a 128-bit or 128-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec128(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec128Start, TypeId::_kVec128End); } +//! Tests whether a given type is a 256-bit or 256-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec256(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec256Start, TypeId::_kVec256End); } +//! Tests whether a given type is a 512-bit or 512-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec512(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec512Start, TypeId::_kVec512End); } + +//! 
\cond +enum TypeCategory : uint32_t { + kTypeCategoryUnknown = 0, + kTypeCategoryEnum = 1, + kTypeCategoryIntegral = 2, + kTypeCategoryFloatingPoint = 3, + kTypeCategoryFunction = 4 +}; + +template +struct TypeIdOfT_ByCategory {}; // Fails if not specialized. + +template +struct TypeIdOfT_ByCategory { + enum : uint32_t { + kTypeId = uint32_t( + (sizeof(T) == 1 && std::is_signed::value) ? TypeId::kInt8 : + (sizeof(T) == 1 && !std::is_signed::value) ? TypeId::kUInt8 : + (sizeof(T) == 2 && std::is_signed::value) ? TypeId::kInt16 : + (sizeof(T) == 2 && !std::is_signed::value) ? TypeId::kUInt16 : + (sizeof(T) == 4 && std::is_signed::value) ? TypeId::kInt32 : + (sizeof(T) == 4 && !std::is_signed::value) ? TypeId::kUInt32 : + (sizeof(T) == 8 && std::is_signed::value) ? TypeId::kInt64 : + (sizeof(T) == 8 && !std::is_signed::value) ? TypeId::kUInt64 : TypeId::kVoid) + }; +}; + +template +struct TypeIdOfT_ByCategory { + enum : uint32_t { + kTypeId = uint32_t( + (sizeof(T) == 4 ) ? TypeId::kFloat32 : + (sizeof(T) == 8 ) ? TypeId::kFloat64 : + (sizeof(T) >= 10) ? TypeId::kFloat80 : TypeId::kVoid) + }; +}; + +template +struct TypeIdOfT_ByCategory + : public TypeIdOfT_ByCategory::type, kTypeCategoryIntegral> {}; + +template +struct TypeIdOfT_ByCategory { + enum : uint32_t { + kTypeId = uint32_t(TypeId::kUIntPtr) + }; +}; +//! \endcond + +//! TypeIdOfT<> template allows to get a TypeId from a C++ type `T`. +#ifdef _DOXYGEN +template +struct TypeIdOfT { + //! TypeId of C++ type `T`. + static constexpr TypeId kTypeId = _TypeIdDeducedAtCompileTime_; +}; +#else +template +struct TypeIdOfT + : public TypeIdOfT_ByCategory::value ? kTypeCategoryEnum : + std::is_integral::value ? kTypeCategoryIntegral : + std::is_floating_point::value ? kTypeCategoryFloatingPoint : + std::is_function::value ? kTypeCategoryFunction : kTypeCategoryUnknown> {}; +#endif + +//! 
\cond +template +struct TypeIdOfT { + enum : uint32_t { + kTypeId = uint32_t(TypeId::kUIntPtr) + }; +}; + +template +struct TypeIdOfT { + enum : uint32_t { + kTypeId = uint32_t(TypeId::kUIntPtr) + }; +}; +//! \endcond + +//! Returns a corresponding \ref TypeId of `T` type. +template +static ASMJIT_INLINE_NODEBUG constexpr TypeId typeIdOfT() noexcept { return TypeId(TypeIdOfT::kTypeId); } + +//! Returns offset needed to convert a `kIntPtr` and `kUIntPtr` TypeId into a type that matches `registerSize` +//! (general-purpose register size). If you find such TypeId it's then only about adding the offset to it. +//! +//! For example: +//! +//! ``` +//! uint32_t registerSize = /* 4 or 8 */; +//! uint32_t deabstractDelta = TypeUtils::deabstractDeltaOfSize(registerSize); +//! +//! TypeId typeId = 'some type-id'; +//! +//! // Normalize some typeId into a non-abstract typeId. +//! if (TypeUtils::isAbstract(typeId)) typeId += deabstractDelta; +//! +//! // The same, but by using TypeUtils::deabstract() function. +//! typeId = TypeUtils::deabstract(typeId, deabstractDelta); +//! ``` +static ASMJIT_INLINE_NODEBUG constexpr uint32_t deabstractDeltaOfSize(uint32_t registerSize) noexcept { + return registerSize >= 8 ? uint32_t(TypeId::kInt64) - uint32_t(TypeId::kIntPtr) + : uint32_t(TypeId::kInt32) - uint32_t(TypeId::kIntPtr); +} + +//! Deabstracts a given `typeId` into a native type by using `deabstractDelta`, which was previously +//! calculated by calling \ref deabstractDeltaOfSize() with a target native register size. +static ASMJIT_INLINE_NODEBUG constexpr TypeId deabstract(TypeId typeId, uint32_t deabstractDelta) noexcept { + return isAbstract(typeId) ? TypeId(uint32_t(typeId) + deabstractDelta) : typeId; +} + +static ASMJIT_INLINE_NODEBUG constexpr TypeId scalarToVector(TypeId scalarTypeId, TypeId vecStartId) noexcept { + return TypeId(uint32_t(vecStartId) + uint32_t(scalarTypeId) - uint32_t(TypeId::kInt8)); +} + +} // {TypeUtils} + +//! 
Provides type identifiers that can be used in templates instead of native types. +namespace Type { + +//! bool as C++ type-name. +struct Bool {}; +//! int8_t as C++ type-name. +struct Int8 {}; +//! uint8_t as C++ type-name. +struct UInt8 {}; +//! int16_t as C++ type-name. +struct Int16 {}; +//! uint16_t as C++ type-name. +struct UInt16 {}; +//! int32_t as C++ type-name. +struct Int32 {}; +//! uint32_t as C++ type-name. +struct UInt32 {}; +//! int64_t as C++ type-name. +struct Int64 {}; +//! uint64_t as C++ type-name. +struct UInt64 {}; +//! intptr_t as C++ type-name. +struct IntPtr {}; +//! uintptr_t as C++ type-name. +struct UIntPtr {}; +//! float as C++ type-name. +struct Float32 {}; +//! double as C++ type-name. +struct Float64 {}; + +} // {Type} + +//! \cond +#define ASMJIT_DEFINE_TYPE_ID(T, TYPE_ID) \ +namespace TypeUtils { \ + template<> \ + struct TypeIdOfT { \ + enum : uint32_t { \ + kTypeId = uint32_t(TYPE_ID) \ + }; \ + }; \ +} + +ASMJIT_DEFINE_TYPE_ID(void , TypeId::kVoid); +ASMJIT_DEFINE_TYPE_ID(Type::Bool , TypeId::kUInt8); +ASMJIT_DEFINE_TYPE_ID(Type::Int8 , TypeId::kInt8); +ASMJIT_DEFINE_TYPE_ID(Type::UInt8 , TypeId::kUInt8); +ASMJIT_DEFINE_TYPE_ID(Type::Int16 , TypeId::kInt16); +ASMJIT_DEFINE_TYPE_ID(Type::UInt16 , TypeId::kUInt16); +ASMJIT_DEFINE_TYPE_ID(Type::Int32 , TypeId::kInt32); +ASMJIT_DEFINE_TYPE_ID(Type::UInt32 , TypeId::kUInt32); +ASMJIT_DEFINE_TYPE_ID(Type::Int64 , TypeId::kInt64); +ASMJIT_DEFINE_TYPE_ID(Type::UInt64 , TypeId::kUInt64); +ASMJIT_DEFINE_TYPE_ID(Type::IntPtr , TypeId::kIntPtr); +ASMJIT_DEFINE_TYPE_ID(Type::UIntPtr, TypeId::kUIntPtr); +ASMJIT_DEFINE_TYPE_ID(Type::Float32, TypeId::kFloat32); +ASMJIT_DEFINE_TYPE_ID(Type::Float64, TypeId::kFloat64); +//! \endcond + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_TYPE_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/virtmem.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/virtmem.h new file mode 100644 index 0000000000000000000000000000000000000000..17996dcb07e9d8440e8ddc1de9ff9ce2d8b81a7f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/virtmem.h @@ -0,0 +1,327 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_VIRTMEM_H_INCLUDED +#define ASMJIT_CORE_VIRTMEM_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_JIT + +#include "../core/globals.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_virtual_memory +//! \{ + +//! Virtual memory management. +namespace VirtMem { + +//! Describes whether instruction cache should be flushed after a write operation. +enum class CachePolicy : uint32_t { + //! Default policy. + //! + //! In some places this would mean `kFlushAfterWrite` and in some places it would mean `kNeverFlush`. + //! For example if it's known that an address has never been used before to execute code. + kDefault = 0, + + //! Flush instruction cache after a write operation. + kFlushAfterWrite = 1, + + //! Avoid flushing instruction cache after a write operation. + kNeverFlush = 2 +}; + +//! Flushes instruction cache in the given region. +//! +//! Only useful on non-x86 architectures, however, it's a good practice to call it on any platform to make your +//! code more portable. +ASMJIT_API void flushInstructionCache(void* p, size_t size) noexcept; + +//! Virtual memory information. +struct Info { + //! Virtual memory page size. + uint32_t pageSize; + //! Virtual memory page granularity. + uint32_t pageGranularity; +}; + +//! Returns virtual memory information, see `VirtMem::Info` for more details. 
+ASMJIT_API Info info() noexcept; + +//! Returns the size of the smallest large page supported. +//! +//! AsmJit only uses the smallest large page at the moment as these are usually perfectly sized for executable +//! memory allocation (standard size is 2MB, but different sizes are possible). +//! +//! Returns either the detected large page size or 0, if large page support is either not supported by AsmJit +//! or not accessible to the process. +ASMJIT_API size_t largePageSize() noexcept; + +//! Virtual memory access and mmap-specific flags. +enum class MemoryFlags : uint32_t { + //! No flags. + kNone = 0, + + //! Memory is readable. + kAccessRead = 0x00000001u, + + //! Memory is writable. + kAccessWrite = 0x00000002u, + + //! Memory is executable. + kAccessExecute = 0x00000004u, + + //! A combination of \ref kAccessRead and \ref kAccessWrite. + kAccessReadWrite = kAccessRead | kAccessWrite, + + //! A combination of \ref kAccessRead, \ref kAccessWrite. + kAccessRW = kAccessRead | kAccessWrite, + + //! A combination of \ref kAccessRead and \ref kAccessExecute. + kAccessRX = kAccessRead | kAccessExecute, + + //! A combination of \ref kAccessRead, \ref kAccessWrite, and \ref kAccessExecute. + kAccessRWX = kAccessRead | kAccessWrite | kAccessExecute, + + //! Use a `MAP_JIT` flag available on Apple platforms (introduced by Mojave), which allows JIT code to be + //! executed in a MAC bundle. + //! + //! This flag may be turned on by the allocator if there is no other way of allocating executable memory. + //! + //! \note This flag can only be used with \ref VirtMem::alloc(), `MAP_JIT` only works on OSX and not on iOS. + //! When a process uses `fork()` the child process has no access to the pages mapped with `MAP_JIT`. + kMMapEnableMapJit = 0x00000010u, + + //! Pass `PROT_MAX(PROT_READ)` or `PROT_MPROTECT(PROT_READ)` to `mmap()` on platforms that support it. + //! + //! This flag allows to set a "maximum access" that the memory page can get during its lifetime. Use + //! 
\ref VirtMem::protect() to change the access flags. + //! + //! \note This flag can only be used with \ref VirtMem::alloc() and \ref VirtMem::allocDualMapping(). + //! However \ref VirtMem::allocDualMapping() may automatically use this if \ref kAccessRead is used. + kMMapMaxAccessRead = 0x00000020u, + + //! Pass `PROT_MAX(PROT_WRITE)` or `PROT_MPROTECT(PROT_WRITE)` to `mmap()` on platforms that support it. + //! + //! This flag allows to set a "maximum access" that the memory page can get during its lifetime. Use + //! \ref VirtMem::protect() to change the access flags. + //! + //! \note This flag can only be used with \ref VirtMem::alloc() and \ref VirtMem::allocDualMapping(). + //! However \ref VirtMem::allocDualMapping() may automatically use this if \ref kAccessWrite is used. + kMMapMaxAccessWrite = 0x00000040u, + + //! Pass `PROT_MAX(PROT_EXEC)` or `PROT_MPROTECT(PROT_EXEC)` to `mmap()` on platforms that support it. + //! + //! This flag allows to set a "maximum access" that the memory page can get during its lifetime. Use + //! \ref VirtMem::protect() to change the access flags. + //! + //! \note This flag can only be used with \ref VirtMem::alloc() and \ref VirtMem::allocDualMapping(). + //! However \ref VirtMem::allocDualMapping() may automatically use this if \ref kAccessExecute is used. + kMMapMaxAccessExecute = 0x00000080u, + + //! A combination of \ref kMMapMaxAccessRead and \ref kMMapMaxAccessWrite. + kMMapMaxAccessReadWrite = kMMapMaxAccessRead | kMMapMaxAccessWrite, + + //! A combination of \ref kMMapMaxAccessRead and \ref kMMapMaxAccessWrite. + kMMapMaxAccessRW = kMMapMaxAccessRead | kMMapMaxAccessWrite, + + //! A combination of \ref kMMapMaxAccessRead and \ref kMMapMaxAccessExecute. + kMMapMaxAccessRX = kMMapMaxAccessRead | kMMapMaxAccessExecute, + + //! A combination of \ref kMMapMaxAccessRead, \ref kMMapMaxAccessWrite, \ref kMMapMaxAccessExecute. + kMMapMaxAccessRWX = kMMapMaxAccessRead | kMMapMaxAccessWrite | kMMapMaxAccessExecute, + + //! 
Use `MAP_SHARED` when calling mmap(). + //! + //! \note In some cases `MAP_SHARED` may be set automatically. For example, some dual mapping implementations must + //! use `MAP_SHARED` instead of `MAP_PRIVATE` to ensure that the OS would not apply copy on write on RW page, which + //! would cause RX page not having the updated content. + kMapShared = 0x00000100u, + + //! Request large memory mapped pages. + //! + //! \remarks If this option is used and large page(s) cannot be mapped, the allocation will fail. Fallback to + //! regular pages must be done by the user in this case. Higher level API such as \ref JitAllocator provides an + //! additional mechanism to allocate regular page(s) when large page(s) allocation fails. + kMMapLargePages = 0x00000200u, + + //! Not an access flag, only used by `allocDualMapping()` to override the default allocation strategy to always use + //! a 'tmp' directory instead of "/dev/shm" (on POSIX platforms). Please note that this flag will be ignored if the + //! operating system allows to allocate an executable memory by a different API than `open()` or `shm_open()`. For + //! example on Linux `memfd_create()` is preferred and on BSDs `shm_open(SHM_ANON, ...)` is used if SHM_ANON is + //! defined. + //! + //! \note This flag can only be used with \ref VirtMem::alloc(). + kMappingPreferTmp = 0x80000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(MemoryFlags) + +//! Allocates virtual memory by either using `mmap()` (POSIX) or `VirtualAlloc()` (Windows). +//! +//! \note `size` should be aligned to page size, use \ref VirtMem::info() to obtain it. Invalid size will not be +//! corrected by the implementation and the allocation would not succeed in such case. +ASMJIT_API Error alloc(void** p, size_t size, MemoryFlags flags) noexcept; + +//! Releases virtual memory previously allocated by \ref VirtMem::alloc(). +//! +//! \note The size must be the same as used by \ref VirtMem::alloc(). If the size is not the same value the call +//! 
will fail on any POSIX system, but pass on Windows, because it's implemented differently. +ASMJIT_API Error release(void* p, size_t size) noexcept; + +//! A cross-platform wrapper around `mprotect()` (POSIX) and `VirtualProtect()` (Windows). +ASMJIT_API Error protect(void* p, size_t size, MemoryFlags flags) noexcept; + +//! Dual memory mapping used to map an anonymous memory into two memory regions where one region is read-only, but +//! executable, and the second region is read+write, but not executable. See \ref VirtMem::allocDualMapping() for +//! more details. +struct DualMapping { + //! Pointer to data with 'Read+Execute' access (this memory is not writable). + void* rx; + //! Pointer to data with 'Read+Write' access (this memory is not executable). + void* rw; +}; + +//! Allocates virtual memory and creates two views of it where the first view has no write access. This is an addition +//! to the API that should be used in cases in which the operating system either enforces W^X security policy or the +//! application wants to use this policy by default to improve security and prevent an accidental (or purposed) +//! self-modifying code. +//! +//! The memory returned in the `dm` are two independent mappings of the same shared memory region. You must use +//! \ref VirtMem::releaseDualMapping() to release it when it's no longer needed. Never use `VirtMem::release()` to +//! release the memory returned by `allocDualMapping()` as that would fail on Windows. +//! +//! \remarks Both pointers in `dm` would be set to `nullptr` if the function fails. +ASMJIT_API Error allocDualMapping(DualMapping* dm, size_t size, MemoryFlags flags) noexcept; + +//! Releases virtual memory mapping previously allocated by \ref VirtMem::allocDualMapping(). +//! +//! \remarks Both pointers in `dm` would be set to `nullptr` if the function succeeds. +ASMJIT_API Error releaseDualMapping(DualMapping* dm, size_t size) noexcept; + +//! Hardened runtime flags. 
+enum class HardenedRuntimeFlags : uint32_t { + //! No flags. + kNone = 0, + + //! Hardened runtime is enabled - it's not possible to have "Write & Execute" memory protection. The runtime + //! enforces W^X (either write or execute). + //! + //! \note If the runtime is hardened it means that an operating system specific protection is used. For example + //! on Apple OSX it's possible to allocate memory with MAP_JIT flag and then use `pthread_jit_write_protect_np()` + //! to temporarily swap access permissions for the current thread. Dual mapping is also a possibility on X86/X64 + //! architecture. + kEnabled = 0x00000001u, + + //! Read+Write+Execute can only be allocated with MAP_JIT flag (Apple specific, only available on Apple platforms). + kMapJit = 0x00000002u, + + //! Read+Write+Execute can be allocated with dual mapping approach (one region with RW and the other with RX). + kDualMapping = 0x00000004u +}; +ASMJIT_DEFINE_ENUM_FLAGS(HardenedRuntimeFlags) + +//! Hardened runtime information. +struct HardenedRuntimeInfo { + //! \name Members + //! \{ + + //! Hardened runtime flags. + HardenedRuntimeFlags flags; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the hardened runtime `flag` is set. + ASMJIT_INLINE_NODEBUG bool hasFlag(HardenedRuntimeFlags flag) const noexcept { return Support::test(flags, flag); } + + //! \} +}; + +//! Returns runtime features provided by the OS. +ASMJIT_API HardenedRuntimeInfo hardenedRuntimeInfo() noexcept; + +//! Values that can be used with `protectJitMemory()` function. +enum class ProtectJitAccess : uint32_t { + //! Protect JIT memory with Read+Write permissions. + kReadWrite = 0, + //! Protect JIT memory with Read+Execute permissions. + kReadExecute = 1 +}; + +//! Protects access of memory mapped with MAP_JIT flag for the current thread. +//! +//! \note This feature is only available on Apple hardware (AArch64) at the moment and uses a non-portable +//! `pthread_jit_write_protect_np()` call when available. +//! 
+//! This function must be called before and after a memory mapped with MAP_JIT flag is modified. Example: +//! +//! ``` +//! void* codePtr = ...; +//! size_t codeSize = ...; +//! +//! VirtMem::protectJitMemory(VirtMem::ProtectJitAccess::kReadWrite); +//! memcpy(codePtr, source, codeSize); +//! VirtMem::protectJitMemory(VirtMem::ProtectJitAccess::kReadExecute); +//! VirtMem::flushInstructionCache(codePtr, codeSize); +//! ``` +//! +//! See \ref ProtectJitReadWriteScope, which makes it simpler than the code above. +ASMJIT_API void protectJitMemory(ProtectJitAccess access) noexcept; + +//! JIT protection scope that prepares the given memory block to be written to in the current thread. +//! +//! It calls `VirtMem::protectJitMemory(VirtMem::ProtectJitAccess::kReadWrite)` at construction time and +//! `VirtMem::protectJitMemory(VirtMem::ProtectJitAccess::kReadExecute)` combined with `flushInstructionCache()` +//! in destructor. The purpose of this class is to make writing to JIT memory easier. +class ProtectJitReadWriteScope { +public: + ASMJIT_NONCOPYABLE(ProtectJitReadWriteScope) + + //! \name Members + //! \{ + + void* _rxPtr; + size_t _size; + CachePolicy _policy; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Makes the given memory block RW protected. + ASMJIT_FORCE_INLINE ProtectJitReadWriteScope( + void* rxPtr, + size_t size, + CachePolicy policy = CachePolicy::kDefault) noexcept + : _rxPtr(rxPtr), + _size(size), + _policy(policy) { + protectJitMemory(ProtectJitAccess::kReadWrite); + } + + //! Makes the memory block RX protected again and flushes instruction cache. + ASMJIT_FORCE_INLINE ~ProtectJitReadWriteScope() noexcept { + protectJitMemory(ProtectJitAccess::kReadExecute); + + if (_policy != CachePolicy::kNeverFlush) + flushInstructionCache(_rxPtr, _size); + } + + //! \} +}; + +} // VirtMem + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif +#endif // ASMJIT_CORE_VIRTMEM_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zone.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zone.h new file mode 100644 index 0000000000000000000000000000000000000000..b61a3d1292c25d3436f588c9dea8c6b43e74d467 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zone.h @@ -0,0 +1,611 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONE_H_INCLUDED +#define ASMJIT_CORE_ZONE_H_INCLUDED + +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Zone memory. +//! +//! Zone is an incremental memory allocator that allocates memory by simply incrementing a pointer. It allocates +//! blocks of memory by using C's `malloc()`, but divides these blocks into smaller segments requested by calling +//! `Zone::alloc()` and friends. +//! +//! Zone has no function to release the allocated memory. It has to be released all at once by calling `reset()`. +//! If you need a more friendly allocator that also supports `release()`, consider using `Zone` with `ZoneAllocator`. +class Zone { +public: + ASMJIT_NONCOPYABLE(Zone) + + //! \cond INTERNAL + + //! A single block of memory managed by `Zone`. + struct Block { + inline uint8_t* data() const noexcept { + return const_cast(reinterpret_cast(this) + sizeof(*this)); + } + + //! Link to the previous block. + Block* prev; + //! Link to the next block. + Block* next; + //! Size of the block. + size_t size; + }; + + enum Limits : size_t { + kBlockSize = sizeof(Block), + kBlockOverhead = Globals::kAllocOverhead + kBlockSize, + + kMinBlockSize = 64, // The number is ridiculously small, but still possible. + kMaxBlockSize = size_t(1) << (sizeof(size_t) * 8 - 4 - 1), + kMinAlignment = 1, + kMaxAlignment = 64 + }; + + //! 
Pointer in the current block. + uint8_t* _ptr; + //! End of the current block. + uint8_t* _end; + //! Current block. + Block* _block; + + union { + struct { + //! Default block size. + size_t _blockSize : Support::bitSizeOf() - 4; + //! First block is temporary (ZoneTmp). + size_t _isTemporary : 1; + //! Block alignment (1 << alignment). + size_t _blockAlignmentShift : 3; + }; + size_t _packedData; + }; + + static ASMJIT_API const Block _zeroBlock; + + //! \endcond + + //! \name Construction & Destruction + //! \{ + + //! Creates a new Zone. + //! + //! The `blockSize` parameter describes the default size of the block. If the `size` parameter passed to `alloc()` + //! is greater than the default size `Zone` will allocate and use a larger block, but it will not change the + //! default `blockSize`. + //! + //! It's not required, but it's good practice to set `blockSize` to a reasonable value that depends on the usage + //! of `Zone`. Greater block sizes are generally safer and perform better than unreasonably low block sizes. + ASMJIT_INLINE_NODEBUG explicit Zone(size_t blockSize, size_t blockAlignment = 1) noexcept { + _init(blockSize, blockAlignment, nullptr); + } + + //! Creates a new Zone with a first block pointing to a `temporary` memory. + ASMJIT_INLINE_NODEBUG Zone(size_t blockSize, size_t blockAlignment, const Support::Temporary& temporary) noexcept { + _init(blockSize, blockAlignment, &temporary); + } + + //! \overload + ASMJIT_INLINE_NODEBUG Zone(size_t blockSize, size_t blockAlignment, const Support::Temporary* temporary) noexcept { + _init(blockSize, blockAlignment, temporary); + } + + //! Moves an existing `Zone`. + //! + //! \note You cannot move an existing `ZoneTmp` as it uses embedded storage. Attempting to move `ZoneTmp` would + //! result in assertion failure in debug mode and undefined behavior in release mode. 
+ inline Zone(Zone&& other) noexcept + : _ptr(other._ptr), + _end(other._end), + _block(other._block), + _packedData(other._packedData) { + ASMJIT_ASSERT(!other.isTemporary()); + other._block = const_cast(&_zeroBlock); + other._ptr = other._block->data(); + other._end = other._block->data(); + } + + //! Destroys the `Zone` instance. + //! + //! This will destroy the `Zone` instance and release all blocks of memory allocated by it. It performs implicit + //! `reset(ResetPolicy::kHard)`. + ASMJIT_INLINE_NODEBUG ~Zone() noexcept { reset(ResetPolicy::kHard); } + + ASMJIT_API void _init(size_t blockSize, size_t blockAlignment, const Support::Temporary* temporary) noexcept; + + //! Resets the `Zone` invalidating all blocks allocated. + //! + //! See `Globals::ResetPolicy` for more details. + ASMJIT_API void reset(ResetPolicy resetPolicy = ResetPolicy::kSoft) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether this `Zone` is actually a `ZoneTmp` that uses temporary memory. + ASMJIT_INLINE_NODEBUG bool isTemporary() const noexcept { return _isTemporary != 0; } + + //! Returns the default block size. + ASMJIT_INLINE_NODEBUG size_t blockSize() const noexcept { return _blockSize; } + //! Returns the default block alignment. + ASMJIT_INLINE_NODEBUG size_t blockAlignment() const noexcept { return size_t(1) << _blockAlignmentShift; } + //! Returns remaining size of the current block. + ASMJIT_INLINE_NODEBUG size_t remainingSize() const noexcept { return (size_t)(_end - _ptr); } + + //! Returns the current zone cursor (dangerous). + //! + //! This is a function that can be used to get exclusive access to the current block's memory buffer. + template + ASMJIT_INLINE_NODEBUG T* ptr() noexcept { return reinterpret_cast(_ptr); } + + //! Returns the end of the current zone block, only useful if you use `ptr()`. + template + ASMJIT_INLINE_NODEBUG T* end() noexcept { return reinterpret_cast(_end); } + + //! 
Sets the current zone pointer to `ptr` (must be within the current block). + template + inline void setPtr(T* ptr) noexcept { + uint8_t* p = reinterpret_cast(ptr); + ASMJIT_ASSERT(p >= _ptr && p <= _end); + _ptr = p; + } + + //! Sets the end zone pointer to `end` (must be within the current block). + template + inline void setEnd(T* end) noexcept { + uint8_t* p = reinterpret_cast(end); + ASMJIT_ASSERT(p >= _ptr && p <= _end); + _end = p; + } + + //! \} + + //! \name Utilities + //! \{ + + inline void swap(Zone& other) noexcept { + // This could lead to a disaster. + ASMJIT_ASSERT(!this->isTemporary()); + ASMJIT_ASSERT(!other.isTemporary()); + + std::swap(_ptr, other._ptr); + std::swap(_end, other._end); + std::swap(_block, other._block); + std::swap(_packedData, other._packedData); + } + + //! Aligns the current pointer to `alignment`. + ASMJIT_INLINE_NODEBUG void align(size_t alignment) noexcept { + _ptr = Support::min(Support::alignUp(_ptr, alignment), _end); + } + + //! Ensures the remaining size is at least equal or greater than `size`. + //! + //! \note This function doesn't respect any alignment. If you need to ensure there is enough room for an aligned + //! allocation you need to call `align()` before calling `ensure()`. + ASMJIT_INLINE_NODEBUG Error ensure(size_t size) noexcept { + if (size <= remainingSize()) + return kErrorOk; + else + return _alloc(0, 1) ? kErrorOk : DebugUtils::errored(kErrorOutOfMemory); + } + + inline void _assignBlock(Block* block) noexcept { + size_t alignment = blockAlignment(); + _ptr = Support::alignUp(block->data(), alignment); + _end = Support::alignDown(block->data() + block->size, alignment); + _block = block; + } + + inline void _assignZeroBlock() noexcept { + Block* block = const_cast(&_zeroBlock); + _ptr = block->data(); + _end = block->data(); + _block = block; + } + + //! \} + + //! \name Allocation + //! \{ + + //! Allocates the requested memory specified by `size`. + //! + //! 
Pointer returned is valid until the `Zone` instance is destroyed or reset by calling `reset()`. If you plan to + //! make an instance of C++ from the given pointer use placement `new` and `delete` operators: + //! + //! ``` + //! using namespace asmjit; + //! + //! class Object { ... }; + //! + //! // Create Zone with default block size of approximately 65536 bytes. + //! Zone zone(65536 - Zone::kBlockOverhead); + //! + //! // Create your objects using zone object allocating, for example: + //! Object* obj = static_cast( zone.alloc(sizeof(Object)) ); + //! + //! if (!obj) { + //! // Handle out of memory error. + //! } + //! + //! // Placement `new` and `delete` operators can be used to instantiate it. + //! new(obj) Object(); + //! + //! // ... lifetime of your objects ... + //! + //! // To destroy the instance (if required). + //! obj->~Object(); + //! + //! // Reset or destroy `Zone`. + //! zone.reset(); + //! ``` + inline void* alloc(size_t size) noexcept { + if (ASMJIT_UNLIKELY(size > remainingSize())) + return _alloc(size, 1); + + uint8_t* ptr = _ptr; + _ptr += size; + return static_cast(ptr); + } + + //! Allocates the requested memory specified by `size` and `alignment`. + inline void* alloc(size_t size, size_t alignment) noexcept { + ASMJIT_ASSERT(Support::isPowerOf2(alignment)); + uint8_t* ptr = Support::alignUp(_ptr, alignment); + + if (ptr >= _end || size > (size_t)(_end - ptr)) + return _alloc(size, alignment); + + _ptr = ptr + size; + return static_cast(ptr); + } + + //! Allocates the requested memory specified by `size` without doing any checks. + //! + //! Can only be called if `remainingSize()` returns size at least equal to `size`. + inline void* allocNoCheck(size_t size) noexcept { + ASMJIT_ASSERT(remainingSize() >= size); + + uint8_t* ptr = _ptr; + _ptr += size; + return static_cast(ptr); + } + + //! Allocates the requested memory specified by `size` and `alignment` without doing any checks. + //! + //! 
Performs the same operation as `Zone::allocNoCheck(size)` with `alignment` applied. + inline void* allocNoCheck(size_t size, size_t alignment) noexcept { + ASMJIT_ASSERT(Support::isPowerOf2(alignment)); + + uint8_t* ptr = Support::alignUp(_ptr, alignment); + ASMJIT_ASSERT(size <= (size_t)(_end - ptr)); + + _ptr = ptr + size; + return static_cast(ptr); + } + + //! Allocates `size` bytes of zeroed memory. See `alloc()` for more details. + ASMJIT_API void* allocZeroed(size_t size, size_t alignment = 1) noexcept; + + //! Like `alloc()`, but the return pointer is casted to `T*`. + template + inline T* allocT(size_t size = sizeof(T), size_t alignment = alignof(T)) noexcept { + return static_cast(alloc(size, alignment)); + } + + //! Like `allocNoCheck()`, but the return pointer is casted to `T*`. + template + inline T* allocNoCheckT(size_t size = sizeof(T), size_t alignment = alignof(T)) noexcept { + return static_cast(allocNoCheck(size, alignment)); + } + + //! Like `allocZeroed()`, but the return pointer is casted to `T*`. + template + inline T* allocZeroedT(size_t size = sizeof(T), size_t alignment = alignof(T)) noexcept { + return static_cast(allocZeroed(size, alignment)); + } + + //! Like `new(std::nothrow) T(...)`, but allocated by `Zone`. + template + inline T* newT() noexcept { + void* p = alloc(sizeof(T), alignof(T)); + if (ASMJIT_UNLIKELY(!p)) + return nullptr; + return new(Support::PlacementNew{p}) T(); + } + + //! Like `new(std::nothrow) T(...)`, but allocated by `Zone`. + template + inline T* newT(Args&&... args) noexcept { + void* p = alloc(sizeof(T), alignof(T)); + if (ASMJIT_UNLIKELY(!p)) + return nullptr; + return new(Support::PlacementNew{p}) T(std::forward(args)...); + } + + //! \cond INTERNAL + //! + //! Internal alloc function used by other inlines. + ASMJIT_API void* _alloc(size_t size, size_t alignment) noexcept; + //! \endcond + + //! Helper to duplicate data. 
+ ASMJIT_API void* dup(const void* data, size_t size, bool nullTerminate = false) noexcept; + + //! Helper to duplicate data. + inline void* dupAligned(const void* data, size_t size, size_t alignment, bool nullTerminate = false) noexcept { + align(alignment); + return dup(data, size, nullTerminate); + } + + //! Helper to duplicate a formatted string, maximum size is 256 bytes. + ASMJIT_API char* sformat(const char* str, ...) noexcept; + + //! \} +}; + +//! \ref Zone with `N` bytes of a static storage, used for the initial block. +//! +//! Temporary zones are used in cases where it's known that some memory will be required, but in many cases it won't +//! exceed N bytes, so the whole operation can be performed without a dynamic memory allocation. +template +class ZoneTmp : public Zone { +public: + ASMJIT_NONCOPYABLE(ZoneTmp) + + //! Temporary storage, embedded after \ref Zone. + struct Storage { + char data[N]; + } _storage; + + //! Creates a temporary zone. Dynamic block size is specified by `blockSize`. + inline explicit ZoneTmp(size_t blockSize, size_t blockAlignment = 1) noexcept + : Zone(blockSize, blockAlignment, Support::Temporary(_storage.data, N)) {} +}; + +//! Zone-based memory allocator that uses an existing `Zone` and provides a `release()` functionality on top of it. +//! It uses `Zone` only for chunks that can be pooled, and uses libc `malloc()` for chunks that are large. +//! +//! The advantage of ZoneAllocator is that it can allocate small chunks of memory really fast, and these chunks, +//! when released, will be reused by consecutive calls to `alloc()`. Also, since ZoneAllocator uses `Zone`, you can +//! turn any `Zone` into a `ZoneAllocator`, and use it in your `Pass` when necessary. +//! +//! ZoneAllocator is used by AsmJit containers to make containers having only few elements fast (and lightweight) +//! and to allow them to grow and use dynamic blocks when require more storage. 
+class ZoneAllocator { +public: + ASMJIT_NONCOPYABLE(ZoneAllocator) + + //! \cond INTERNAL + + // In short, we pool chunks of these sizes: + // [32, 64, 96, 128, 192, 256, 320, 384, 448, 512] + + enum : uint32_t { + //! How many bytes per a low granularity pool (has to be at least 16). + kLoGranularity = 32, + //! Number of slots of a low granularity pool. + kLoCount = 4, + //! Maximum size of a block that can be allocated in a low granularity pool. + kLoMaxSize = kLoGranularity * kLoCount, + + //! How many bytes per a high granularity pool. + kHiGranularity = 64, + //! Number of slots of a high granularity pool. + kHiCount = 6, + //! Maximum size of a block that can be allocated in a high granularity pool. + kHiMaxSize = kLoMaxSize + kHiGranularity * kHiCount, + + //! Alignment of every pointer returned by `alloc()`. + kBlockAlignment = kLoGranularity + }; + + //! Single-linked list used to store unused chunks. + struct Slot { + //! Link to a next slot in a single-linked list. + Slot* next; + }; + + //! A block of memory that has been allocated dynamically and is not part of block-list used by the allocator. + //! This is used to keep track of all these blocks so they can be freed by `reset()` if not freed explicitly. + struct DynamicBlock { + DynamicBlock* prev; + DynamicBlock* next; + }; + + //! \endcond + + //! \name Members + //! \{ + + //! Zone used to allocate memory that fits into slots. + Zone* _zone {}; + //! Indexed slots containing released memory. + Slot* _slots[kLoCount + kHiCount] {}; + //! Dynamic blocks for larger allocations (no slots). + DynamicBlock* _dynamicBlocks {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `ZoneAllocator`. + //! + //! \note To use it, you must first `init()` it. + ASMJIT_INLINE_NODEBUG ZoneAllocator() noexcept {} + + //! Creates a new `ZoneAllocator` initialized to use `zone`. + ASMJIT_INLINE_NODEBUG explicit ZoneAllocator(Zone* zone) noexcept + : _zone(zone) {} + + //! 
Destroys the `ZoneAllocator`. + ASMJIT_INLINE_NODEBUG ~ZoneAllocator() noexcept { reset(); } + + //! Tests whether the `ZoneAllocator` is initialized (i.e. has `Zone`). + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _zone != nullptr; } + + //! Convenience function to initialize the `ZoneAllocator` with `zone`. + //! + //! It's the same as calling `reset(zone)`. + ASMJIT_INLINE_NODEBUG void init(Zone* zone) noexcept { reset(zone); } + + //! Resets this `ZoneAllocator` and also forget about the current `Zone` which is attached (if any). Reset + //! optionally attaches a new `zone` passed, or keeps the `ZoneAllocator` in an uninitialized state, if + //! `zone` is null. + ASMJIT_API void reset(Zone* zone = nullptr) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the assigned `Zone` of this allocator or null if this `ZoneAllocator` is not initialized. + ASMJIT_INLINE_NODEBUG Zone* zone() const noexcept { return _zone; } + + //! \} + + //! \cond + //! \name Internals + //! \{ + + //! Returns the slot index to be used for `size`. Returns `true` if a valid slot has been written to `slot` and + //! `allocatedSize` has been filled with slot exact size (`allocatedSize` can be equal or slightly greater than + //! `size`). + static inline bool _getSlotIndex(size_t size, uint32_t& slot) noexcept { + ASMJIT_ASSERT(size > 0); + if (size > kHiMaxSize) + return false; + + if (size <= kLoMaxSize) + slot = uint32_t((size - 1) / kLoGranularity); + else + slot = uint32_t((size - kLoMaxSize - 1) / kHiGranularity) + kLoCount; + + return true; + } + + //! 
\overload + static inline bool _getSlotIndex(size_t size, uint32_t& slot, size_t& allocatedSize) noexcept { + ASMJIT_ASSERT(size > 0); + if (size > kHiMaxSize) + return false; + + if (size <= kLoMaxSize) { + slot = uint32_t((size - 1) / kLoGranularity); + allocatedSize = Support::alignUp(size, kLoGranularity); + } + else { + slot = uint32_t((size - kLoMaxSize - 1) / kHiGranularity) + kLoCount; + allocatedSize = Support::alignUp(size, kHiGranularity); + } + + return true; + } + + //! \} + //! \endcond + + //! \name Allocation + //! \{ + + //! \cond INTERNAL + ASMJIT_API void* _alloc(size_t size, size_t& allocatedSize) noexcept; + ASMJIT_API void* _allocZeroed(size_t size, size_t& allocatedSize) noexcept; + ASMJIT_API void _releaseDynamic(void* p, size_t size) noexcept; + //! \endcond + + //! Allocates `size` bytes of memory, ideally from an available pool. + //! + //! \note `size` can't be zero, it will assert in debug mode in such case. + inline void* alloc(size_t size) noexcept { + ASMJIT_ASSERT(isInitialized()); + size_t allocatedSize; + return _alloc(size, allocatedSize); + } + + //! Like `alloc(size)`, but provides a second argument `allocatedSize` that provides a way to know how big + //! the block returned actually is. This is useful for containers to prevent growing too early. + inline void* alloc(size_t size, size_t& allocatedSize) noexcept { + ASMJIT_ASSERT(isInitialized()); + return _alloc(size, allocatedSize); + } + + //! Like `alloc()`, but the return pointer is casted to `T*`. + template + inline T* allocT(size_t size = sizeof(T)) noexcept { + return static_cast(alloc(size)); + } + + //! Like `alloc(size)`, but returns zeroed memory. + inline void* allocZeroed(size_t size) noexcept { + ASMJIT_ASSERT(isInitialized()); + size_t allocatedSize; + return _allocZeroed(size, allocatedSize); + } + + //! Like `alloc(size, allocatedSize)`, but returns zeroed memory. 
+ inline void* allocZeroed(size_t size, size_t& allocatedSize) noexcept { + ASMJIT_ASSERT(isInitialized()); + return _allocZeroed(size, allocatedSize); + } + + //! Like `allocZeroed()`, but the return pointer is casted to `T*`. + template + inline T* allocZeroedT(size_t size = sizeof(T)) noexcept { + return static_cast(allocZeroed(size)); + } + + //! Like `new(std::nothrow) T(...)`, but allocated by `Zone`. + template + inline T* newT() noexcept { + void* p = allocT(); + if (ASMJIT_UNLIKELY(!p)) + return nullptr; + return new(Support::PlacementNew{p}) T(); + } + //! Like `new(std::nothrow) T(...)`, but allocated by `Zone`. + template + inline T* newT(Args&&... args) noexcept { + void* p = allocT(); + if (ASMJIT_UNLIKELY(!p)) + return nullptr; + return new(Support::PlacementNew{p}) T(std::forward(args)...); + } + + //! Releases the memory previously allocated by `alloc()`. The `size` argument has to be the same as used to call + //! `alloc()` or `allocatedSize` returned by `alloc()`. + inline void release(void* p, size_t size) noexcept { + ASMJIT_ASSERT(isInitialized()); + ASMJIT_ASSERT(p != nullptr); + ASMJIT_ASSERT(size != 0); + + uint32_t slot; + if (_getSlotIndex(size, slot)) { + static_cast(p)->next = static_cast(_slots[slot]); + _slots[slot] = static_cast(p); + } + else { + _releaseDynamic(p, size); + } + } + + //! \} +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONE_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonehash.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonehash.h new file mode 100644 index 0000000000000000000000000000000000000000..d6cd2e31cb914f6b2b751d3346949ef127482b2d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonehash.h @@ -0,0 +1,186 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONEHASH_H_INCLUDED +#define ASMJIT_CORE_ZONEHASH_H_INCLUDED + +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Node used by \ref ZoneHash template. +//! +//! You must provide function `bool eq(const Key& key)` in order to make `ZoneHash::get()` working. +class ZoneHashNode { +public: + ASMJIT_NONCOPYABLE(ZoneHashNode) + + inline ZoneHashNode(uint32_t hashCode = 0) noexcept + : _hashNext(nullptr), + _hashCode(hashCode), + _customData(0) {} + + //! Next node in the chain, null if it terminates the chain. + ZoneHashNode* _hashNext; + //! Precalculated hash-code of key. + uint32_t _hashCode; + //! Padding, can be reused by any Node that inherits `ZoneHashNode`. + uint32_t _customData; +}; + +//! Base class used by \ref ZoneHash template +class ZoneHashBase { +public: + ASMJIT_NONCOPYABLE(ZoneHashBase) + + //! Buckets data. + ZoneHashNode** _data; + //! Count of records inserted into the hash table. + size_t _size; + //! Count of hash buckets. + uint32_t _bucketsCount; + //! When buckets array should grow (only checked after insertion). + uint32_t _bucketsGrow; + //! Reciprocal value of `_bucketsCount`. + uint32_t _rcpValue; + //! How many bits to shift right when hash is multiplied with `_rcpValue`. + uint8_t _rcpShift; + //! Prime value index in internal prime array. + uint8_t _primeIndex; + + //! 
Embedded data, used by empty hash tables. + ZoneHashNode* _embedded[1]; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneHashBase() noexcept { + reset(); + } + + inline ZoneHashBase(ZoneHashBase&& other) noexcept { + _data = other._data; + _size = other._size; + _bucketsCount = other._bucketsCount; + _bucketsGrow = other._bucketsGrow; + _rcpValue = other._rcpValue; + _rcpShift = other._rcpShift; + _primeIndex = other._primeIndex; + _embedded[0] = other._embedded[0]; + + if (_data == other._embedded) _data = _embedded; + } + + inline void reset() noexcept { + _data = _embedded; + _size = 0; + _bucketsCount = 1; + _bucketsGrow = 1; + _rcpValue = 1; + _rcpShift = 0; + _primeIndex = 0; + _embedded[0] = nullptr; + } + + inline void release(ZoneAllocator* allocator) noexcept { + ZoneHashNode** oldData = _data; + if (oldData != _embedded) + allocator->release(oldData, _bucketsCount * sizeof(ZoneHashNode*)); + reset(); + } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + + //! \} + + //! \name Utilities + //! \{ + + inline void _swap(ZoneHashBase& other) noexcept { + std::swap(_data, other._data); + std::swap(_size, other._size); + std::swap(_bucketsCount, other._bucketsCount); + std::swap(_bucketsGrow, other._bucketsGrow); + std::swap(_rcpValue, other._rcpValue); + std::swap(_rcpShift, other._rcpShift); + std::swap(_primeIndex, other._primeIndex); + std::swap(_embedded[0], other._embedded[0]); + + if (_data == other._embedded) _data = _embedded; + if (other._data == _embedded) other._data = other._embedded; + } + + //! 
\cond INTERNAL + inline uint32_t _calcMod(uint32_t hash) const noexcept { + uint32_t x = uint32_t((uint64_t(hash) * _rcpValue) >> _rcpShift); + return hash - x * _bucketsCount; + } + + ASMJIT_API void _rehash(ZoneAllocator* allocator, uint32_t newCount) noexcept; + ASMJIT_API ZoneHashNode* _insert(ZoneAllocator* allocator, ZoneHashNode* node) noexcept; + ASMJIT_API ZoneHashNode* _remove(ZoneAllocator* allocator, ZoneHashNode* node) noexcept; + //! \endcond + + //! \} +}; + +//! Low-level hash table specialized for storing string keys and POD values. +//! +//! This hash table allows duplicates to be inserted (the API is so low level that it's up to you if you allow it or +//! not, as you should first `get()` the node and then modify it or insert a new node by using `insert()`, depending +//! on the intention). +template +class ZoneHash : public ZoneHashBase { +public: + ASMJIT_NONCOPYABLE(ZoneHash) + + typedef NodeT Node; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneHash() noexcept + : ZoneHashBase() {} + + ASMJIT_INLINE_NODEBUG ZoneHash(ZoneHash&& other) noexcept + : ZoneHash(other) {} + + //! \} + + //! \name Utilities + //! \{ + + ASMJIT_INLINE_NODEBUG void swap(ZoneHash& other) noexcept { ZoneHashBase::_swap(other); } + + template + inline NodeT* get(const KeyT& key) const noexcept { + uint32_t hashMod = _calcMod(key.hashCode()); + NodeT* node = static_cast(_data[hashMod]); + + while (node && !key.matches(node)) + node = static_cast(node->_hashNext); + return node; + } + + ASMJIT_INLINE_NODEBUG NodeT* insert(ZoneAllocator* allocator, NodeT* node) noexcept { return static_cast(_insert(allocator, node)); } + ASMJIT_INLINE_NODEBUG NodeT* remove(ZoneAllocator* allocator, NodeT* node) noexcept { return static_cast(_remove(allocator, node)); } + + //! \} +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONEHASH_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonelist.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonelist.h new file mode 100644 index 0000000000000000000000000000000000000000..8980240ef431367f5b1d27f0fbd02705edfe0d12 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonelist.h @@ -0,0 +1,208 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONELIST_H_INCLUDED +#define ASMJIT_CORE_ZONELIST_H_INCLUDED + +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Node used by \ref ZoneList template. +template +class ZoneListNode { +public: + ASMJIT_NONCOPYABLE(ZoneListNode) + + //! \name Constants + //! \{ + + enum : size_t { + kNodeIndexPrev = 0, + kNodeIndexNext = 1 + }; + + //! \} + + //! \name Members + //! \{ + + NodeT* _listNodes[2]; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneListNode() noexcept + : _listNodes { nullptr, nullptr } {} + + ASMJIT_INLINE_NODEBUG ZoneListNode(ZoneListNode&& other) noexcept + : _listNodes { other._listNodes[0], other._listNodes[1] } {} + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool hasPrev() const noexcept { return _listNodes[kNodeIndexPrev] != nullptr; } + ASMJIT_INLINE_NODEBUG bool hasNext() const noexcept { return _listNodes[kNodeIndexNext] != nullptr; } + + ASMJIT_INLINE_NODEBUG NodeT* prev() const noexcept { return _listNodes[kNodeIndexPrev]; } + ASMJIT_INLINE_NODEBUG NodeT* next() const noexcept { return _listNodes[kNodeIndexNext]; } + + //! \} +}; + +//! Zone allocated list container that uses nodes of `NodeT` type. +template +class ZoneList { +public: + ASMJIT_NONCOPYABLE(ZoneList) + + //! \name Constants + //! 
\{ + + enum : size_t { + kNodeIndexFirst = 0, + kNodeIndexLast = 1 + }; + + //! \} + + //! \name Members + //! \{ + + NodeT* _nodes[2] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneList() noexcept {} + + ASMJIT_INLINE_NODEBUG ZoneList(ZoneList&& other) noexcept + : _nodes { other._nodes[0], other._nodes[1] } {} + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _nodes[0] = nullptr; + _nodes[1] = nullptr; + } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _nodes[0] == nullptr; } + ASMJIT_INLINE_NODEBUG NodeT* first() const noexcept { return _nodes[kNodeIndexFirst]; } + ASMJIT_INLINE_NODEBUG NodeT* last() const noexcept { return _nodes[kNodeIndexLast]; } + + //! \} + + //! \name Utilities + //! \{ + + ASMJIT_INLINE_NODEBUG void swap(ZoneList& other) noexcept { + std::swap(_nodes[0], other._nodes[0]); + std::swap(_nodes[1], other._nodes[1]); + } + + // Can be used to both append and prepend. + inline void _addNode(NodeT* node, size_t dir) noexcept { + NodeT* prev = _nodes[dir]; + + node->_listNodes[!dir] = prev; + _nodes[dir] = node; + if (prev) + prev->_listNodes[dir] = node; + else + _nodes[!dir] = node; + } + + // Can be used to both append and prepend. 
+ inline void _insertNode(NodeT* ref, NodeT* node, size_t dir) noexcept { + ASMJIT_ASSERT(ref != nullptr); + + NodeT* prev = ref; + NodeT* next = ref->_listNodes[dir]; + + prev->_listNodes[dir] = node; + if (next) + next->_listNodes[!dir] = node; + else + _nodes[dir] = node; + + node->_listNodes[!dir] = prev; + node->_listNodes[ dir] = next; + } + + ASMJIT_INLINE_NODEBUG void append(NodeT* node) noexcept { _addNode(node, kNodeIndexLast); } + ASMJIT_INLINE_NODEBUG void prepend(NodeT* node) noexcept { _addNode(node, kNodeIndexFirst); } + + ASMJIT_INLINE_NODEBUG void insertAfter(NodeT* ref, NodeT* node) noexcept { _insertNode(ref, node, NodeT::kNodeIndexNext); } + ASMJIT_INLINE_NODEBUG void insertBefore(NodeT* ref, NodeT* node) noexcept { _insertNode(ref, node, NodeT::kNodeIndexPrev); } + + inline NodeT* unlink(NodeT* node) noexcept { + NodeT* prev = node->prev(); + NodeT* next = node->next(); + + if (prev) { prev->_listNodes[1] = next; node->_listNodes[0] = nullptr; } else { _nodes[0] = next; } + if (next) { next->_listNodes[0] = prev; node->_listNodes[1] = nullptr; } else { _nodes[1] = prev; } + + node->_listNodes[0] = nullptr; + node->_listNodes[1] = nullptr; + + return node; + } + + inline NodeT* popFirst() noexcept { + NodeT* node = _nodes[0]; + ASMJIT_ASSERT(node != nullptr); + + NodeT* next = node->next(); + _nodes[0] = next; + + if (next) { + next->_listNodes[0] = nullptr; + node->_listNodes[1] = nullptr; + } + else { + _nodes[1] = nullptr; + } + + return node; + } + + inline NodeT* pop() noexcept { + NodeT* node = _nodes[1]; + ASMJIT_ASSERT(node != nullptr); + + NodeT* prev = node->prev(); + _nodes[1] = prev; + + if (prev) { + prev->_listNodes[1] = nullptr; + node->_listNodes[0] = nullptr; + } + else { + _nodes[0] = nullptr; + } + + return node; + } + + //! \} +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONELIST_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonestack.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonestack.h new file mode 100644 index 0000000000000000000000000000000000000000..2cf078b3d219e2ef73c1887817f348618c4ced3a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonestack.h @@ -0,0 +1,235 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONESTACK_H_INCLUDED +#define ASMJIT_CORE_ZONESTACK_H_INCLUDED + +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Base class used by \ref ZoneStack. +class ZoneStackBase { +public: + ASMJIT_NONCOPYABLE(ZoneStackBase) + + //! \name Constants + //! \{ + + enum : size_t { + kBlockIndexPrev = 0, + kBlockIndexNext = 1, + + kBlockIndexFirst = 0, + kBlockIndexLast = 1, + + kBlockSize = ZoneAllocator::kHiMaxSize + }; + + //! \} + + //! \name Types + //! \{ + + struct Block { + //! Next and previous blocks. + Block* _link[2]; + //! Pointer to the start of the array. + void* _start; + //! Pointer to the end of the array. 
+ void* _end; + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _start == _end; } + ASMJIT_INLINE_NODEBUG Block* prev() const noexcept { return _link[kBlockIndexPrev]; } + ASMJIT_INLINE_NODEBUG Block* next() const noexcept { return _link[kBlockIndexNext]; } + + ASMJIT_INLINE_NODEBUG void setPrev(Block* block) noexcept { _link[kBlockIndexPrev] = block; } + ASMJIT_INLINE_NODEBUG void setNext(Block* block) noexcept { _link[kBlockIndexNext] = block; } + + template + ASMJIT_INLINE_NODEBUG T* start() const noexcept { return static_cast(_start); } + template + ASMJIT_INLINE_NODEBUG void setStart(T* start) noexcept { _start = static_cast(start); } + + template + ASMJIT_INLINE_NODEBUG T* end() const noexcept { return (T*)_end; } + template + ASMJIT_INLINE_NODEBUG void setEnd(T* end) noexcept { _end = (void*)end; } + + template + ASMJIT_INLINE_NODEBUG T* data() const noexcept { return (T*)((uint8_t*)(this) + sizeof(Block)); } + + template + ASMJIT_INLINE_NODEBUG bool canPrepend() const noexcept { return _start > data(); } + + template + ASMJIT_INLINE_NODEBUG bool canAppend() const noexcept { + size_t kNumBlockItems = (kBlockSize - sizeof(Block)) / sizeof(T); + size_t kStartBlockIndex = sizeof(Block); + size_t kEndBlockIndex = kStartBlockIndex + kNumBlockItems * sizeof(T); + + return (uintptr_t)_end <= ((uintptr_t)this + kEndBlockIndex - sizeof(T)); + } + }; + + //! \} + + //! \name Members + //! \{ + + //! Allocator used to allocate data. + ZoneAllocator* _allocator {}; + //! First and last blocks. + Block* _block[2] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneStackBase() noexcept {} + ASMJIT_INLINE_NODEBUG ~ZoneStackBase() noexcept { reset(); } + + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _allocator != nullptr; } + ASMJIT_API Error _init(ZoneAllocator* allocator, size_t middleIndex) noexcept; + ASMJIT_INLINE_NODEBUG Error reset() noexcept { return _init(nullptr, 0); } + + //! 
\} + + //! \name Accessors + //! \{ + + //! Returns `ZoneAllocator` attached to this container. + ASMJIT_INLINE_NODEBUG ZoneAllocator* allocator() const noexcept { return _allocator; } + + inline bool empty() const noexcept { + ASMJIT_ASSERT(isInitialized()); + return _block[0]->start() == _block[1]->end(); + } + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + ASMJIT_API Error _prepareBlock(uint32_t side, size_t initialIndex) noexcept; + ASMJIT_API void _cleanupBlock(uint32_t side, size_t middleIndex) noexcept; + + //! \} + //! \endcond +}; + +//! Zone allocated stack container. +template +class ZoneStack : public ZoneStackBase { +public: + ASMJIT_NONCOPYABLE(ZoneStack) + + //! \name Constants + //! \{ + + enum : uint32_t { + kNumBlockItems = uint32_t((kBlockSize - sizeof(Block)) / sizeof(T)), + kStartBlockIndex = uint32_t(sizeof(Block)), + kMidBlockIndex = uint32_t(kStartBlockIndex + (kNumBlockItems / 2) * sizeof(T)), + kEndBlockIndex = uint32_t(kStartBlockIndex + (kNumBlockItems ) * sizeof(T)) + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + inline ZoneStack() noexcept {} + inline ~ZoneStack() noexcept {} + + inline Error init(ZoneAllocator* allocator) noexcept { return _init(allocator, kMidBlockIndex); } + + //! \} + + //! \name Utilities + //! 
\{ + + inline Error prepend(T item) noexcept { + ASMJIT_ASSERT(isInitialized()); + Block* block = _block[kBlockIndexFirst]; + + if (!block->canPrepend()) { + ASMJIT_PROPAGATE(_prepareBlock(kBlockIndexFirst, kEndBlockIndex)); + block = _block[kBlockIndexFirst]; + } + + T* ptr = block->start() - 1; + ASMJIT_ASSERT(ptr >= block->data() && ptr <= block->data() + (kNumBlockItems - 1)); + *ptr = item; + block->setStart(ptr); + return kErrorOk; + } + + inline Error append(T item) noexcept { + ASMJIT_ASSERT(isInitialized()); + Block* block = _block[kBlockIndexLast]; + + if (!block->canAppend()) { + ASMJIT_PROPAGATE(_prepareBlock(kBlockIndexLast, kStartBlockIndex)); + block = _block[kBlockIndexLast]; + } + + T* ptr = block->end(); + ASMJIT_ASSERT(ptr >= block->data() && ptr <= block->data() + (kNumBlockItems - 1)); + + *ptr++ = item; + block->setEnd(ptr); + return kErrorOk; + } + + inline T popFirst() noexcept { + ASMJIT_ASSERT(isInitialized()); + ASMJIT_ASSERT(!empty()); + + Block* block = _block[kBlockIndexFirst]; + ASMJIT_ASSERT(!block->empty()); + + T* ptr = block->start(); + T item = *ptr++; + + block->setStart(ptr); + if (block->empty()) + _cleanupBlock(kBlockIndexFirst, kMidBlockIndex); + + return item; + } + + inline T pop() noexcept { + ASMJIT_ASSERT(isInitialized()); + ASMJIT_ASSERT(!empty()); + + Block* block = _block[kBlockIndexLast]; + ASMJIT_ASSERT(!block->empty()); + + T* ptr = block->end(); + T item = *--ptr; + ASMJIT_ASSERT(ptr >= block->data()); + ASMJIT_ASSERT(ptr >= block->start()); + + block->setEnd(ptr); + if (block->empty()) + _cleanupBlock(kBlockIndexLast, kMidBlockIndex); + + return item; + } + + //! \} +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONESTACK_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonestring.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonestring.h new file mode 100644 index 0000000000000000000000000000000000000000..e62ac50f287bd5a8208a07d3fbb874f3adf32d62 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonestring.h @@ -0,0 +1,120 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONESTRING_H_INCLUDED +#define ASMJIT_CORE_ZONESTRING_H_INCLUDED + +#include "../core/globals.h" +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! A helper class used by \ref ZoneString implementation. +struct ZoneStringBase { + union { + struct { + uint32_t _size; + char _embedded[sizeof(void*) * 2 - 4]; + }; + struct { + void* _dummy; + char* _external; + }; + }; + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _dummy = nullptr; + _external = nullptr; + } + + Error setData(Zone* zone, uint32_t maxEmbeddedSize, const char* str, size_t size) noexcept { + if (size == SIZE_MAX) + size = strlen(str); + + if (size <= maxEmbeddedSize) { + memcpy(_embedded, str, size); + _embedded[size] = '\0'; + } + else { + char* external = static_cast(zone->dup(str, size, true)); + if (ASMJIT_UNLIKELY(!external)) + return DebugUtils::errored(kErrorOutOfMemory); + _external = external; + } + + _size = uint32_t(size); + return kErrorOk; + } +}; + +//! A string template that can be zone allocated. +//! +//! Helps with creating strings that can be either statically allocated if they are small, or externally allocated +//! in case their size exceeds the limit. The `N` represents the size of the whole `ZoneString` structure, based on +//! that size the maximum size of the internal buffer is determined. 
+template +class ZoneString { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + kWholeSize = (N > sizeof(ZoneStringBase)) ? uint32_t(N) : uint32_t(sizeof(ZoneStringBase)), + kMaxEmbeddedSize = kWholeSize - 5 + }; + + //! \} + + //! \name Members + //! \{ + + union { + ZoneStringBase _base; + char _wholeData[kWholeSize]; + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneString() noexcept { reset(); } + ASMJIT_INLINE_NODEBUG void reset() noexcept { _base.reset(); } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the string is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _base._size == 0; } + + //! Returns the string data. + ASMJIT_INLINE_NODEBUG const char* data() const noexcept { return _base._size <= kMaxEmbeddedSize ? _base._embedded : _base._external; } + //! Returns the string size. + ASMJIT_INLINE_NODEBUG uint32_t size() const noexcept { return _base._size; } + + //! Tests whether the string is embedded (e.g. no dynamically allocated). + ASMJIT_INLINE_NODEBUG bool isEmbedded() const noexcept { return _base._size <= kMaxEmbeddedSize; } + + //! Copies a new `data` of the given `size` to the string. + //! + //! If the `size` exceeds the internal buffer the given `zone` will be used to duplicate the data, otherwise + //! the internal buffer will be used as a storage. + ASMJIT_INLINE_NODEBUG Error setData(Zone* zone, const char* data, size_t size) noexcept { + return _base.setData(zone, kMaxEmbeddedSize, data, size); + } + + //! \} +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONESTRING_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonetree.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonetree.h new file mode 100644 index 0000000000000000000000000000000000000000..ffeb674cfe2804609929a22d14b49633cad183a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonetree.h @@ -0,0 +1,376 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONETREE_H_INCLUDED +#define ASMJIT_CORE_ZONETREE_H_INCLUDED + +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! RB-Tree node. +//! +//! The color is stored in a least significant bit of the `left` node. +//! +//! WARNING: Always use accessors to access left and right children. +class ZoneTreeNode { +public: + ASMJIT_NONCOPYABLE(ZoneTreeNode) + + //! \name Constants + //! \{ + + enum : uintptr_t { + kRedMask = 0x1, + kPtrMask = ~kRedMask + }; + + //! \} + + //! \name Members + //! \{ + + uintptr_t _rbNodeData[2] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneTreeNode() noexcept {} + + //! \} + + //! \name Accessors + //! 
\{ + + ASMJIT_INLINE_NODEBUG bool isRed() const noexcept { return static_cast(_rbNodeData[0] & kRedMask); } + + ASMJIT_INLINE_NODEBUG bool hasChild(size_t i) const noexcept { return _rbNodeData[i] > kRedMask; } + ASMJIT_INLINE_NODEBUG bool hasLeft() const noexcept { return _rbNodeData[0] > kRedMask; } + ASMJIT_INLINE_NODEBUG bool hasRight() const noexcept { return _rbNodeData[1] != 0; } + + template + ASMJIT_INLINE_NODEBUG T* child(size_t i) const noexcept { return static_cast(_getChild(i)); } + template + ASMJIT_INLINE_NODEBUG T* left() const noexcept { return static_cast(_getLeft()); } + template + ASMJIT_INLINE_NODEBUG T* right() const noexcept { return static_cast(_getRight()); } + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneTreeNode* _getChild(size_t i) const noexcept { return (ZoneTreeNode*)(_rbNodeData[i] & kPtrMask); } + ASMJIT_INLINE_NODEBUG ZoneTreeNode* _getLeft() const noexcept { return (ZoneTreeNode*)(_rbNodeData[0] & kPtrMask); } + ASMJIT_INLINE_NODEBUG ZoneTreeNode* _getRight() const noexcept { return (ZoneTreeNode*)(_rbNodeData[1]); } + + ASMJIT_INLINE_NODEBUG void _setChild(size_t i, ZoneTreeNode* node) noexcept { _rbNodeData[i] = (_rbNodeData[i] & kRedMask) | (uintptr_t)node; } + ASMJIT_INLINE_NODEBUG void _setLeft(ZoneTreeNode* node) noexcept { _rbNodeData[0] = (_rbNodeData[0] & kRedMask) | (uintptr_t)node; } + ASMJIT_INLINE_NODEBUG void _setRight(ZoneTreeNode* node) noexcept { _rbNodeData[1] = (uintptr_t)node; } + + ASMJIT_INLINE_NODEBUG void _makeRed() noexcept { _rbNodeData[0] |= kRedMask; } + ASMJIT_INLINE_NODEBUG void _makeBlack() noexcept { _rbNodeData[0] &= kPtrMask; } + + //! Tests whether the node is RED (RED node must be non-null and must have RED flag set). + static ASMJIT_INLINE_NODEBUG bool _isValidRed(ZoneTreeNode* node) noexcept { return node && node->isRed(); } + + //! \} + //! \endcond +}; + +//! RB-Tree node casted to `NodeT`. 
+template +class ZoneTreeNodeT : public ZoneTreeNode { +public: + ASMJIT_NONCOPYABLE(ZoneTreeNodeT) + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneTreeNodeT() noexcept + : ZoneTreeNode() {} + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG NodeT* child(size_t i) const noexcept { return static_cast(_getChild(i)); } + ASMJIT_INLINE_NODEBUG NodeT* left() const noexcept { return static_cast(_getLeft()); } + ASMJIT_INLINE_NODEBUG NodeT* right() const noexcept { return static_cast(_getRight()); } + + //! \} +}; + +//! RB-Tree. +template +class ZoneTree { +public: + ASMJIT_NONCOPYABLE(ZoneTree) + + typedef NodeT Node; + NodeT* _root {}; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneTree() noexcept {} + ASMJIT_INLINE_NODEBUG ZoneTree(ZoneTree&& other) noexcept + : _root(other._root) {} + ASMJIT_INLINE_NODEBUG void reset() noexcept { _root = nullptr; } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _root == nullptr; } + ASMJIT_INLINE_NODEBUG NodeT* root() const noexcept { return static_cast(_root); } + + //! \} + + //! \name Utilities + //! \{ + + ASMJIT_INLINE_NODEBUG void swap(ZoneTree& other) noexcept { + std::swap(_root, other._root); + } + + template> + void insert(NodeT* ASMJIT_NONNULL(node), const CompareT& cmp = CompareT()) noexcept { + // Node to insert must not contain garbage. + ASMJIT_ASSERT(!node->hasLeft()); + ASMJIT_ASSERT(!node->hasRight()); + ASMJIT_ASSERT(!node->isRed()); + + if (!_root) { + _root = node; + return; + } + + ZoneTreeNode head; // False root node, + head._setRight(_root); // having root on the right. + + ZoneTreeNode* g = nullptr; // Grandparent. + ZoneTreeNode* p = nullptr; // Parent. + ZoneTreeNode* t = &head; // Iterator. + ZoneTreeNode* q = _root; // Query. + + size_t dir = 0; // Direction for accessing child nodes. + size_t last = 0; // Not needed to initialize, but makes some tools happy. 
+ + node->_makeRed(); // New nodes are always red and violations fixed appropriately. + + // Search down the tree. + for (;;) { + if (!q) { + // Insert new node at the bottom. + q = node; + p->_setChild(dir, node); + } + else if (_isValidRed(q->_getLeft()) && _isValidRed(q->_getRight())) { + // Color flip. + q->_makeRed(); + q->_getLeft()->_makeBlack(); + q->_getRight()->_makeBlack(); + } + + // Fix red violation. + if (_isValidRed(q) && _isValidRed(p)) { + ASMJIT_ASSUME(g != nullptr); + ASMJIT_ASSUME(p != nullptr); + t->_setChild(t->_getRight() == g, + q == p->_getChild(last) ? _singleRotate(g, !last) : _doubleRotate(g, !last)); + } + + // Stop if found. + if (q == node) + break; + + last = dir; + dir = cmp(*static_cast(q), *static_cast(node)) < 0; + + // Update helpers. + if (g) t = g; + + g = p; + p = q; + q = q->_getChild(dir); + } + + // Update root and make it black. + _root = static_cast(head._getRight()); + _root->_makeBlack(); + } + + //! Remove node from RBTree. + template> + void remove(ZoneTreeNode* ASMJIT_NONNULL(node), const CompareT& cmp = CompareT()) noexcept { + ZoneTreeNode head; // False root node, + head._setRight(_root); // having root on the right. + + ZoneTreeNode* g = nullptr; // Grandparent. + ZoneTreeNode* p = nullptr; // Parent. + ZoneTreeNode* q = &head; // Query. + + ZoneTreeNode* f = nullptr; // Found item. + ZoneTreeNode* gf = nullptr; // Found grandparent. + size_t dir = 1; // Direction (0 or 1). + + // Search and push a red down. + while (q->hasChild(dir)) { + size_t last = dir; + + // Update helpers. + g = p; + p = q; + q = q->_getChild(dir); + dir = cmp(*static_cast(q), *static_cast(node)) < 0; + + // Save found node. + if (q == node) { + f = q; + gf = g; + } + + // Push the red node down. 
+ if (!_isValidRed(q) && !_isValidRed(q->_getChild(dir))) { + if (_isValidRed(q->_getChild(!dir))) { + ZoneTreeNode* child = _singleRotate(q, dir); + p->_setChild(last, child); + p = child; + } + else if (!_isValidRed(q->_getChild(!dir)) && p->_getChild(!last)) { + ZoneTreeNode* s = p->_getChild(!last); + if (!_isValidRed(s->_getChild(!last)) && !_isValidRed(s->_getChild(last))) { + // Color flip. + p->_makeBlack(); + s->_makeRed(); + q->_makeRed(); + } + else { + ASMJIT_ASSUME(g != nullptr); + ASMJIT_ASSUME(s != nullptr); + + size_t dir2 = g->_getRight() == p; + ZoneTreeNode* child = g->_getChild(dir2); + + if (_isValidRed(s->_getChild(last))) { + child = _doubleRotate(p, last); + g->_setChild(dir2, child); + } + else if (_isValidRed(s->_getChild(!last))) { + child = _singleRotate(p, last); + g->_setChild(dir2, child); + } + + // Ensure correct coloring. + q->_makeRed(); + child->_makeRed(); + child->_getLeft()->_makeBlack(); + child->_getRight()->_makeBlack(); + } + } + } + } + + // Replace and remove. + ASMJIT_ASSERT(f != nullptr); + ASMJIT_ASSERT(f != &head); + ASMJIT_ASSERT(q != &head); + + p->_setChild(p->_getRight() == q, + q->_getChild(q->_getLeft() == nullptr)); + + // NOTE: The original algorithm used a trick to just copy 'key/value' to `f` and mark `q` for deletion. But this + // is unacceptable here as we really want to destroy the passed `node`. So, we have to make sure that we have + // really removed `f` and not `q`. + if (f != q) { + ASMJIT_ASSERT(f != &head); + ASMJIT_ASSERT(f != gf); + + ZoneTreeNode* n = gf ? gf : &head; + dir = (n == &head) ? 1 : cmp(*static_cast(n), *static_cast(node)) < 0; + + for (;;) { + if (n->_getChild(dir) == f) { + n->_setChild(dir, q); + // RAW copy, including the color. + q->_rbNodeData[0] = f->_rbNodeData[0]; + q->_rbNodeData[1] = f->_rbNodeData[1]; + break; + } + + n = n->_getChild(dir); + + // Cannot be true as we know that it must reach `f` in few iterations. 
+ ASMJIT_ASSERT(n != nullptr); + dir = cmp(*static_cast(n), *static_cast(node)) < 0; + } + } + + // Update root and make it black. + _root = static_cast(head._getRight()); + if (_root) _root->_makeBlack(); + } + + template> + inline NodeT* get(const KeyT& key, const CompareT& cmp = CompareT()) const noexcept { + ZoneTreeNode* node = _root; + while (node) { + auto result = cmp(*static_cast(node), key); + if (result == 0) break; + + // Go left or right depending on the `result`. + node = node->_getChild(result < 0); + } + return static_cast(node); + } + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + static inline bool _isValidRed(ZoneTreeNode* node) noexcept { return ZoneTreeNode::_isValidRed(node); } + + //! Single rotation. + static inline ZoneTreeNode* _singleRotate(ZoneTreeNode* ASMJIT_NONNULL(root), size_t dir) noexcept { + ZoneTreeNode* save = root->_getChild(!dir); + ASMJIT_ASSUME(save != nullptr); + + ZoneTreeNode* saveChild = save->_getChild(dir); + root->_setChild(!dir, saveChild); + save->_setChild( dir, root); + root->_makeRed(); + save->_makeBlack(); + return save; + } + + //! Double rotation. + static inline ZoneTreeNode* _doubleRotate(ZoneTreeNode* ASMJIT_NONNULL(root), size_t dir) noexcept { + ZoneTreeNode* child = root->_getChild(!dir); + ASMJIT_ASSUME(child != nullptr); + + root->_setChild(!dir, _singleRotate(child, !dir)); + return _singleRotate(root, dir); + } + + //! \} + //! \endcond +}; + +//! 
\} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONETREE_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonevector.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonevector.h new file mode 100644 index 0000000000000000000000000000000000000000..13d28bbefa2299ce276fb6f5b7f9d89f164bd7f1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/core/zonevector.h @@ -0,0 +1,729 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONEVECTOR_H_INCLUDED +#define ASMJIT_CORE_ZONEVECTOR_H_INCLUDED + +#include "../core/support.h" +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Base class used by \ref ZoneVector template. +class ZoneVectorBase { +public: + ASMJIT_NONCOPYABLE(ZoneVectorBase) + + // STL compatibility; + typedef uint32_t size_type; + typedef ptrdiff_t difference_type; + + //! Vector data (untyped). + void* _data = nullptr; + //! Size of the vector. + size_type _size = 0; + //! Capacity of the vector. + size_type _capacity = 0; + +protected: + //! \name Construction & Destruction + //! \{ + + //! Creates a new instance of `ZoneVectorBase`. + inline ZoneVectorBase() noexcept {} + + inline ZoneVectorBase(ZoneVectorBase&& other) noexcept + : _data(other._data), + _size(other._size), + _capacity(other._capacity) {} + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! 
\{ + + inline void _release(ZoneAllocator* allocator, uint32_t sizeOfT) noexcept { + if (_data != nullptr) { + allocator->release(_data, _capacity * sizeOfT); + reset(); + } + } + + ASMJIT_API Error _grow(ZoneAllocator* allocator, uint32_t sizeOfT, uint32_t n) noexcept; + ASMJIT_API Error _resize(ZoneAllocator* allocator, uint32_t sizeOfT, uint32_t n) noexcept; + ASMJIT_API Error _reserve(ZoneAllocator* allocator, uint32_t sizeOfT, uint32_t n) noexcept; + + inline void _swap(ZoneVectorBase& other) noexcept { + std::swap(_data, other._data); + std::swap(_size, other._size); + std::swap(_capacity, other._capacity); + } + + //! \} + //! \endcond + +public: + //! \name Accessors + //! \{ + + //! Tests whether the vector is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + //! Returns the vector size. + ASMJIT_INLINE_NODEBUG size_type size() const noexcept { return _size; } + //! Returns the vector capacity. + ASMJIT_INLINE_NODEBUG size_type capacity() const noexcept { return _capacity; } + + //! \} + + //! \name Utilities + //! \{ + + //! Makes the vector empty (won't change the capacity or data pointer). + ASMJIT_INLINE_NODEBUG void clear() noexcept { _size = 0; } + //! Resets the vector data and set its `size` to zero. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _data = nullptr; + _size = 0; + _capacity = 0; + } + + //! Truncates the vector to at most `n` items. + ASMJIT_INLINE_NODEBUG void truncate(size_type n) noexcept { + _size = Support::min(_size, n); + } + + //! Sets size of the vector to `n`. Used internally by some algorithms. + inline void _setSize(size_type n) noexcept { + ASMJIT_ASSERT(n <= _capacity); + _size = n; + } + + //! \} +}; + +//! Template used to store and manage array of Zone allocated data. +//! +//! This template has these advantages over other std::vector<>: +//! - Always non-copyable (designed to be non-copyable, we want it). +//! - Optimized for working only with POD types. +//! 
- Uses ZoneAllocator, thus small vectors are almost for free. +//! - Explicit allocation, ZoneAllocator is not part of the data. +template +class ZoneVector : public ZoneVectorBase { +public: + ASMJIT_NONCOPYABLE(ZoneVector) + + // STL compatibility; + typedef T value_type; + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + + typedef T* iterator; + typedef const T* const_iterator; + typedef Support::ArrayReverseIterator reverse_iterator; + typedef Support::ArrayReverseIterator const_reverse_iterator; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneVector() noexcept : ZoneVectorBase() {} + ASMJIT_INLINE_NODEBUG ZoneVector(ZoneVector&& other) noexcept : ZoneVector(other) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns vector data. + ASMJIT_INLINE_NODEBUG T* data() noexcept { return static_cast(_data); } + //! Returns vector data (const) + ASMJIT_INLINE_NODEBUG const T* data() const noexcept { return static_cast(_data); } + + //! Returns item at the given index `i` (const). + inline const T& at(size_t i) const noexcept { + ASMJIT_ASSERT(i < _size); + return data()[i]; + } + + inline void _setEndPtr(T* p) noexcept { + ASMJIT_ASSERT(p >= data() && p <= data() + _capacity); + _setSize(uint32_t((uintptr_t)(p - data()))); + } + + //! \} + + //! \name STL Compatibility (Iterators) + //! 
\{ + + ASMJIT_INLINE_NODEBUG iterator begin() noexcept { return iterator(data()); }; + ASMJIT_INLINE_NODEBUG const_iterator begin() const noexcept { return const_iterator(data()); }; + + ASMJIT_INLINE_NODEBUG iterator end() noexcept { return iterator(data() + _size); }; + ASMJIT_INLINE_NODEBUG const_iterator end() const noexcept { return const_iterator(data() + _size); }; + + ASMJIT_INLINE_NODEBUG reverse_iterator rbegin() noexcept { return reverse_iterator(end()); }; + ASMJIT_INLINE_NODEBUG const_reverse_iterator rbegin() const noexcept { return const_reverse_iterator(end()); }; + + ASMJIT_INLINE_NODEBUG reverse_iterator rend() noexcept { return reverse_iterator(begin()); }; + ASMJIT_INLINE_NODEBUG const_reverse_iterator rend() const noexcept { return const_reverse_iterator(begin()); }; + + ASMJIT_INLINE_NODEBUG const_iterator cbegin() const noexcept { return const_iterator(data()); }; + ASMJIT_INLINE_NODEBUG const_iterator cend() const noexcept { return const_iterator(data() + _size); }; + + ASMJIT_INLINE_NODEBUG const_reverse_iterator crbegin() const noexcept { return const_reverse_iterator(cend()); }; + ASMJIT_INLINE_NODEBUG const_reverse_iterator crend() const noexcept { return const_reverse_iterator(cbegin()); }; + + //! \} + + //! \name Utilities + //! \{ + + //! Swaps this vector with `other`. + ASMJIT_FORCE_INLINE void swap(ZoneVector& other) noexcept { _swap(other); } + + //! Prepends `item` to the vector. + ASMJIT_FORCE_INLINE Error prepend(ZoneAllocator* allocator, const T& item) noexcept { + if (ASMJIT_UNLIKELY(_size == _capacity)) + ASMJIT_PROPAGATE(grow(allocator, 1)); + + memmove(static_cast(static_cast(_data) + 1), + static_cast(_data), + size_t(_size) * sizeof(T)); + + memcpy(static_cast(_data), + static_cast(&item), + sizeof(T)); + + _size++; + return kErrorOk; + } + + //! Inserts an `item` at the specified `index`. 
+ ASMJIT_FORCE_INLINE Error insert(ZoneAllocator* allocator, size_t index, const T& item) noexcept { + ASMJIT_ASSERT(index <= _size); + + if (ASMJIT_UNLIKELY(_size == _capacity)) + ASMJIT_PROPAGATE(grow(allocator, 1)); + + T* dst = static_cast(_data) + index; + memmove(static_cast(dst + 1), + static_cast(dst), + size_t(_size - index) * sizeof(T)); + + memcpy(static_cast(dst), + static_cast(&item), + sizeof(T)); + + _size++; + return kErrorOk; + } + + //! Appends `item` to the vector. + ASMJIT_FORCE_INLINE Error append(ZoneAllocator* allocator, const T& item) noexcept { + if (ASMJIT_UNLIKELY(_size == _capacity)) + ASMJIT_PROPAGATE(grow(allocator, 1)); + + memcpy(static_cast(static_cast(_data) + _size), + static_cast(&item), + sizeof(T)); + + _size++; + return kErrorOk; + } + + //! Appends `other` vector at the end of this vector. + ASMJIT_FORCE_INLINE Error concat(ZoneAllocator* allocator, const ZoneVector& other) noexcept { + uint32_t size = other._size; + if (_capacity - _size < size) + ASMJIT_PROPAGATE(grow(allocator, size)); + + if (size) { + memcpy(static_cast(static_cast(_data) + _size), + static_cast(other._data), + size_t(size) * sizeof(T)); + _size += size; + } + + return kErrorOk; + } + + //! Prepends `item` to the vector (unsafe case). + //! + //! Can only be used together with `willGrow()`. If `willGrow(N)` returns `kErrorOk` then N elements + //! can be added to the vector without checking if there is a place for them. Used mostly internally. + ASMJIT_FORCE_INLINE void prependUnsafe(const T& item) noexcept { + ASMJIT_ASSERT(_size < _capacity); + T* data = static_cast(_data); + + if (_size) { + memmove(static_cast(data + 1), + static_cast(data), + size_t(_size) * sizeof(T)); + } + + memcpy(static_cast(data), + static_cast(&item), + sizeof(T)); + _size++; + } + + //! Append s`item` to the vector (unsafe case). + //! + //! Can only be used together with `willGrow()`. If `willGrow(N)` returns `kErrorOk` then N elements + //! 
can be added to the vector without checking if there is a place for them. Used mostly internally. + ASMJIT_FORCE_INLINE void appendUnsafe(const T& item) noexcept { + ASMJIT_ASSERT(_size < _capacity); + + memcpy(static_cast(static_cast(_data) + _size), + static_cast(&item), + sizeof(T)); + _size++; + } + + //! Inserts an `item` at the specified `index` (unsafe case). + ASMJIT_FORCE_INLINE void insertUnsafe(size_t index, const T& item) noexcept { + ASMJIT_ASSERT(_size < _capacity); + ASMJIT_ASSERT(index <= _size); + + T* dst = static_cast(_data) + index; + memmove(static_cast(dst + 1), + static_cast(dst), + size_t(_size - index) * sizeof(T)); + + memcpy(static_cast(dst), + static_cast(&item), + sizeof(T)); + + _size++; + } + + //! Concatenates all items of `other` at the end of the vector. + ASMJIT_FORCE_INLINE void concatUnsafe(const ZoneVector& other) noexcept { + uint32_t size = other._size; + ASMJIT_ASSERT(_capacity - _size >= size); + + if (size) { + memcpy(static_cast(static_cast(_data) + _size), + static_cast(other._data), + size_t(size) * sizeof(T)); + _size += size; + } + } + + //! Returns index of the given `val` or `Globals::kNotFound` if it doesn't exist. + ASMJIT_FORCE_INLINE uint32_t indexOf(const T& val) const noexcept { + const T* data = static_cast(_data); + uint32_t size = _size; + + for (uint32_t i = 0; i < size; i++) + if (data[i] == val) + return i; + return Globals::kNotFound; + } + + //! Tests whether the vector contains `val`. + inline bool contains(const T& val) const noexcept { + return indexOf(val) != Globals::kNotFound; + } + + //! Removes item at index `i`. + inline void removeAt(size_t i) noexcept { + ASMJIT_ASSERT(i < _size); + + T* data = static_cast(_data) + i; + size_t size = --_size - i; + + if (size) { + memmove(static_cast(data), + static_cast(data + 1), + size_t(size) * sizeof(T)); + } + } + + //! Pops the last element from the vector and returns it. 
+ inline T pop() noexcept { + ASMJIT_ASSERT(_size > 0); + + uint32_t index = --_size; + return data()[index]; + } + + template> + inline void sort(const CompareT& cmp = CompareT()) noexcept { + Support::qSort(data(), size(), cmp); + } + + //! Returns item at index `i`. + inline T& operator[](size_t i) noexcept { + ASMJIT_ASSERT(i < _size); + return data()[i]; + } + + //! Returns item at index `i`. + inline const T& operator[](size_t i) const noexcept { + ASMJIT_ASSERT(i < _size); + return data()[i]; + } + + //! Returns a reference to the first element of the vector. + //! + //! \note The vector must have at least one element. Attempting to use `first()` on empty vector will trigger + //! an assertion failure in debug builds. + ASMJIT_INLINE_NODEBUG T& first() noexcept { return operator[](0); } + //! \overload + ASMJIT_INLINE_NODEBUG const T& first() const noexcept { return operator[](0); } + + //! Returns a reference to the last element of the vector. + //! + //! \note The vector must have at least one element. Attempting to use `last()` on empty vector will trigger + //! an assertion failure in debug builds. + inline T& last() noexcept { return operator[](_size - 1); } + //! \overload + inline const T& last() const noexcept { return operator[](_size - 1); } + + //! \} + + //! \name Memory Management + //! \{ + + //! Releases the memory held by `ZoneVector` back to the `allocator`. + inline void release(ZoneAllocator* allocator) noexcept { + _release(allocator, sizeof(T)); + } + + //! Called to grow the buffer to fit at least `n` elements more. + inline Error grow(ZoneAllocator* allocator, uint32_t n) noexcept { + return ZoneVectorBase::_grow(allocator, sizeof(T), n); + } + + //! Resizes the vector to hold `n` elements. + //! + //! If `n` is greater than the current size then the additional elements' content will be initialized to zero. + //! If `n` is less than the current size then the vector will be truncated to exactly `n` elements. 
+ inline Error resize(ZoneAllocator* allocator, uint32_t n) noexcept { + return ZoneVectorBase::_resize(allocator, sizeof(T), n); + } + + //! Reallocates the internal array to fit at least `n` items. + inline Error reserve(ZoneAllocator* allocator, uint32_t n) noexcept { + return n > _capacity ? ZoneVectorBase::_reserve(allocator, sizeof(T), n) : Error(kErrorOk); + } + + inline Error willGrow(ZoneAllocator* allocator, uint32_t n = 1) noexcept { + return _capacity - _size < n ? grow(allocator, n) : Error(kErrorOk); + } + + //! \} +}; + +//! Zone-allocated bit vector. +class ZoneBitVector { +public: + typedef Support::BitWord BitWord; + + ASMJIT_NONCOPYABLE(ZoneBitVector) + + //! \name Constants + //! \{ + + enum : uint32_t { + kBitWordSizeInBits = Support::kBitWordSizeInBits + }; + + //! \} + + //! \name Members + //! \{ + + //! Bits. + BitWord* _data = nullptr; + //! Size of the bit-vector (in bits). + uint32_t _size = 0; + //! Capacity of the bit-vector (in bits). + uint32_t _capacity = 0; + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + static ASMJIT_INLINE_NODEBUG uint32_t _wordsPerBits(uint32_t nBits) noexcept { + return ((nBits + kBitWordSizeInBits - 1) / kBitWordSizeInBits); + } + + static ASMJIT_INLINE_NODEBUG void _zeroBits(BitWord* dst, uint32_t nBitWords) noexcept { + for (uint32_t i = 0; i < nBitWords; i++) + dst[i] = 0; + } + + static ASMJIT_INLINE_NODEBUG void _fillBits(BitWord* dst, uint32_t nBitWords) noexcept { + for (uint32_t i = 0; i < nBitWords; i++) + dst[i] = ~BitWord(0); + } + + static ASMJIT_INLINE_NODEBUG void _copyBits(BitWord* dst, const BitWord* src, uint32_t nBitWords) noexcept { + for (uint32_t i = 0; i < nBitWords; i++) + dst[i] = src[i]; + } + + //! \} + //! \endcond + + //! \name Construction & Destruction + //! 
\{ + + ASMJIT_INLINE_NODEBUG ZoneBitVector() noexcept {} + + ASMJIT_INLINE_NODEBUG ZoneBitVector(ZoneBitVector&& other) noexcept + : _data(other._data), + _size(other._size), + _capacity(other._capacity) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG bool operator==(const ZoneBitVector& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const ZoneBitVector& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the bit-vector is empty (has no bits). + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + //! Returns the size of this bit-vector (in bits). + ASMJIT_INLINE_NODEBUG uint32_t size() const noexcept { return _size; } + //! Returns the capacity of this bit-vector (in bits). + ASMJIT_INLINE_NODEBUG uint32_t capacity() const noexcept { return _capacity; } + + //! Returns the size of the `BitWord[]` array in `BitWord` units. + ASMJIT_INLINE_NODEBUG uint32_t sizeInBitWords() const noexcept { return _wordsPerBits(_size); } + //! Returns the capacity of the `BitWord[]` array in `BitWord` units. + ASMJIT_INLINE_NODEBUG uint32_t capacityInBitWords() const noexcept { return _wordsPerBits(_capacity); } + + //! Returns bit-vector data as `BitWord[]`. + ASMJIT_INLINE_NODEBUG BitWord* data() noexcept { return _data; } + //! \overload + ASMJIT_INLINE_NODEBUG const BitWord* data() const noexcept { return _data; } + + //! \} + + //! \name Utilities + //! 
\{ + + ASMJIT_INLINE_NODEBUG void swap(ZoneBitVector& other) noexcept { + std::swap(_data, other._data); + std::swap(_size, other._size); + std::swap(_capacity, other._capacity); + } + + ASMJIT_INLINE_NODEBUG void clear() noexcept { + _size = 0; + } + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _data = nullptr; + _size = 0; + _capacity = 0; + } + + ASMJIT_INLINE_NODEBUG void truncate(uint32_t newSize) noexcept { + _size = Support::min(_size, newSize); + _clearUnusedBits(); + } + + inline bool bitAt(uint32_t index) const noexcept { + ASMJIT_ASSERT(index < _size); + return Support::bitVectorGetBit(_data, index); + } + + inline void setBit(uint32_t index, bool value) noexcept { + ASMJIT_ASSERT(index < _size); + Support::bitVectorSetBit(_data, index, value); + } + + inline void flipBit(uint32_t index) noexcept { + ASMJIT_ASSERT(index < _size); + Support::bitVectorFlipBit(_data, index); + } + + ASMJIT_FORCE_INLINE Error append(ZoneAllocator* allocator, bool value) noexcept { + uint32_t index = _size; + if (ASMJIT_UNLIKELY(index >= _capacity)) + return _append(allocator, value); + + uint32_t idx = index / kBitWordSizeInBits; + uint32_t bit = index % kBitWordSizeInBits; + + if (bit == 0) + _data[idx] = BitWord(value) << bit; + else + _data[idx] |= BitWord(value) << bit; + + _size++; + return kErrorOk; + } + + ASMJIT_API Error copyFrom(ZoneAllocator* allocator, const ZoneBitVector& other) noexcept; + + ASMJIT_FORCE_INLINE void clearAll() noexcept { + _zeroBits(_data, _wordsPerBits(_size)); + } + + ASMJIT_FORCE_INLINE void fillAll() noexcept { + _fillBits(_data, _wordsPerBits(_size)); + _clearUnusedBits(); + } + + ASMJIT_FORCE_INLINE void clearBits(uint32_t start, uint32_t count) noexcept { + ASMJIT_ASSERT(start <= _size); + ASMJIT_ASSERT(_size - start >= count); + + Support::bitVectorClear(_data, start, count); + } + + ASMJIT_FORCE_INLINE void fillBits(uint32_t start, uint32_t count) noexcept { + ASMJIT_ASSERT(start <= _size); + ASMJIT_ASSERT(_size - start >= count); 
+ + Support::bitVectorFill(_data, start, count); + } + + //! Performs a logical bitwise AND between bits specified in this array and bits in `other`. If `other` has less + //! bits than `this` then all remaining bits are set to zero. + //! + //! \note The size of the BitVector is unaffected by this operation. + ASMJIT_FORCE_INLINE void and_(const ZoneBitVector& other) noexcept { + BitWord* dst = _data; + const BitWord* src = other._data; + + uint32_t thisBitWordCount = sizeInBitWords(); + uint32_t otherBitWordCount = other.sizeInBitWords(); + uint32_t commonBitWordCount = Support::min(thisBitWordCount, otherBitWordCount); + + uint32_t i = 0; + while (i < commonBitWordCount) { + dst[i] = dst[i] & src[i]; + i++; + } + + while (i < thisBitWordCount) { + dst[i] = 0; + i++; + } + } + + //! Performs a logical bitwise AND between bits specified in this array and negated bits in `other`. If `other` + //! has less bits than `this` then all remaining bits are kept intact. + //! + //! \note The size of the BitVector is unaffected by this operation. + ASMJIT_FORCE_INLINE void andNot(const ZoneBitVector& other) noexcept { + BitWord* dst = _data; + const BitWord* src = other._data; + + uint32_t commonBitWordCount = _wordsPerBits(Support::min(_size, other._size)); + for (uint32_t i = 0; i < commonBitWordCount; i++) + dst[i] = dst[i] & ~src[i]; + } + + //! Performs a logical bitwise OP between bits specified in this array and bits in `other`. If `other` has less + //! bits than `this` then all remaining bits are kept intact. + //! + //! \note The size of the BitVector is unaffected by this operation. 
+ ASMJIT_FORCE_INLINE void or_(const ZoneBitVector& other) noexcept { + BitWord* dst = _data; + const BitWord* src = other._data; + + uint32_t commonBitWordCount = _wordsPerBits(Support::min(_size, other._size)); + for (uint32_t i = 0; i < commonBitWordCount; i++) + dst[i] = dst[i] | src[i]; + _clearUnusedBits(); + } + + ASMJIT_FORCE_INLINE void _clearUnusedBits() noexcept { + uint32_t idx = _size / kBitWordSizeInBits; + uint32_t bit = _size % kBitWordSizeInBits; + + if (!bit) + return; + _data[idx] &= (BitWord(1) << bit) - 1u; + } + + ASMJIT_FORCE_INLINE bool equals(const ZoneBitVector& other) const noexcept { + if (_size != other._size) + return false; + + const BitWord* aData = _data; + const BitWord* bData = other._data; + uint32_t numBitWords = _wordsPerBits(_size); + + for (uint32_t i = 0; i < numBitWords; i++) + if (aData[i] != bData[i]) + return false; + return true; + } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use ZoneVector::equals() instead") + ASMJIT_FORCE_INLINE bool eq(const ZoneBitVector& other) const noexcept { return equals(other); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} + + //! \name Memory Management + //! \{ + + inline void release(ZoneAllocator* allocator) noexcept { + if (!_data) + return; + allocator->release(_data, _capacity / 8); + reset(); + } + + ASMJIT_INLINE_NODEBUG Error resize(ZoneAllocator* allocator, uint32_t newSize, bool newBitsValue = false) noexcept { + return _resize(allocator, newSize, newSize, newBitsValue); + } + + ASMJIT_API Error _resize(ZoneAllocator* allocator, uint32_t newSize, uint32_t idealCapacity, bool newBitsValue) noexcept; + ASMJIT_API Error _append(ZoneAllocator* allocator, bool value) noexcept; + + //! \} + + //! \name Iterators + //! 
\{ + + class ForEachBitSet : public Support::BitVectorIterator { + public: + inline explicit ForEachBitSet(const ZoneBitVector& bitVector) noexcept + : Support::BitVectorIterator(bitVector.data(), bitVector.sizeInBitWords()) {} + }; + + template + class ForEachBitOp : public Support::BitVectorOpIterator { + public: + inline ForEachBitOp(const ZoneBitVector& a, const ZoneBitVector& b) noexcept + : Support::BitVectorOpIterator(a.data(), b.data(), a.sizeInBitWords()) { + ASMJIT_ASSERT(a.size() == b.size()); + } + }; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONEVECTOR_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86assembler.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86assembler.h new file mode 100644 index 0000000000000000000000000000000000000000..dd980a72692b83226349f4d99f21e30937d9ef23 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86assembler.h @@ -0,0 +1,695 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86ASSEMBLER_H_INCLUDED +#define ASMJIT_X86_X86ASSEMBLER_H_INCLUDED + +#include "../core/assembler.h" +#include "../x86/x86emitter.h" +#include "../x86/x86operand.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! X86/X64 assembler implementation. +//! +//! x86::Assembler is a code emitter that emits machine code directly into the \ref CodeBuffer. The assembler is capable +//! of targeting both 32-bit and 64-bit instruction sets, the instruction set can be configured through \ref CodeHolder. +//! +//! ### Basics +//! +//! The following example shows a basic use of `x86::Assembler`, how to generate a function that works in both 32-bit +//! and 64-bit modes, and how to connect \ref JitRuntime, \ref CodeHolder, and `x86::Assembler`. +//! +//! ``` +//! #include +//! #include +//! +//! 
using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef int (*SumFunc)(const int* arr, size_t count); +//! +//! int main() { +//! JitRuntime rt; // Create a runtime specialized for JIT. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Decide between 32-bit CDECL, WIN64, and SysV64 calling conventions: +//! // 32-BIT - passed all arguments by stack. +//! // WIN64 - passes first 4 arguments by RCX, RDX, R8, and R9. +//! // UNIX64 - passes first 6 arguments by RDI, RSI, RDX, RCX, R8, and R9. +//! x86::Gp arr, cnt; +//! x86::Gp sum = x86::eax; // Use EAX as 'sum' as it's a return register. +//! +//! if (ASMJIT_ARCH_BITS == 64) { +//! #if defined(_WIN32) +//! arr = x86::rcx; // First argument (array ptr). +//! cnt = x86::rdx; // Second argument (number of elements) +//! #else +//! arr = x86::rdi; // First argument (array ptr). +//! cnt = x86::rsi; // Second argument (number of elements) +//! #endif +//! } +//! else { +//! arr = x86::edx; // Use EDX to hold the array pointer. +//! cnt = x86::ecx; // Use ECX to hold the counter. +//! // Fetch first and second arguments from [ESP + 4] and [ESP + 8]. +//! a.mov(arr, x86::ptr(x86::esp, 4)); +//! a.mov(cnt, x86::ptr(x86::esp, 8)); +//! } +//! +//! Label Loop = a.newLabel(); // To construct the loop, we need some labels. +//! Label Exit = a.newLabel(); +//! +//! a.xor_(sum, sum); // Clear 'sum' register (shorter than 'mov'). +//! a.test(cnt, cnt); // Border case: +//! a.jz(Exit); // If 'cnt' is zero jump to 'Exit' now. +//! +//! a.bind(Loop); // Start of a loop iteration. +//! a.add(sum, x86::dword_ptr(arr)); // Add int at [arr] to 'sum'. +//! a.add(arr, 4); // Increment 'arr' pointer. +//! a.dec(cnt); // Decrease 'cnt'. +//! a.jnz(Loop); // If not zero jump to 'Loop'. +//! +//! 
a.bind(Exit); // Exit to handle the border case. +//! a.ret(); // Return from function ('sum' == 'eax'). +//! // ----> x86::Assembler is no longer needed from here and can be destroyed <---- +//! +//! SumFunc fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! static const int array[6] = { 4, 8, 15, 16, 23, 42 }; +//! +//! int result = fn(array, 6); // Execute the generated code. +//! printf("%d\n", result); // Print sum of array (108). +//! +//! rt.release(fn); // Explicitly remove the function from the runtime +//! return 0; // Everything successful... +//! } +//! ``` +//! +//! The example should be self-explanatory. It shows how to work with labels, how to use operands, and how to emit +//! instructions that can use different registers based on runtime selection. It implements 32-bit CDECL, WIN64, +//! and SysV64 calling conventions and will work on most X86/X64 environments. +//! +//! Although functions prologs / epilogs can be implemented manually, AsmJit provides utilities that can be used +//! to create function prologs and epilogs automatically, see \ref asmjit_function for more details. +//! +//! ### Instruction Validation +//! +//! Assembler prefers speed over strictness by default. The implementation checks the type of operands and fails +//! if the signature of types is invalid, however, it does only basic checks regarding registers and their groups +//! used in instructions. It's possible to pass operands that don't form any valid signature to the implementation +//! and succeed. This is usually not a problem as Assembler provides typed API so operand types are normally checked +//! by C++ compiler at compile time, however, Assembler is fully dynamic and its \ref emit() function can be called +//! with any instruction id, options, and operands. 
Moreover, it's also possible to form instructions that will be +//! accepted by the typed API, for example by calling `mov(x86::eax, x86::al)` - the C++ compiler won't see a problem +//! as both EAX and AL are \ref Gp registers. +//! +//! To help with common mistakes AsmJit allows to activate instruction validation. This feature instruments +//! the Assembler to call \ref InstAPI::validate() before it attempts to encode any instruction. +//! +//! The example below illustrates how validation can be turned on: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! int main(int argc, char* argv[]) { +//! JitRuntime rt; // Create a runtime specialized for JIT. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Enable strict validation. +//! a.addDiagnosticOptions(DiagnosticOptions::kValidateAssembler); +//! +//! // Try to encode invalid or ill-formed instructions. +//! Error err; +//! +//! // Invalid instruction. +//! err = a.mov(x86::eax, x86::al); +//! printf("Status: %s\n", DebugUtils::errorAsString(err)); +//! +//! // Invalid instruction. +//! err = a.emit(x86::Inst::kIdMovss, x86::eax, x86::xmm0); +//! printf("Status: %s\n", DebugUtils::errorAsString(err)); +//! +//! // Ambiguous operand size - the pointer requires size. +//! err = a.inc(x86::ptr(x86::rax)); +//! printf("Status: %s\n", DebugUtils::errorAsString(err)); +//! +//! return 0; +//! } +//! ``` +//! +//! ### Native Registers +//! +//! All emitters provide functions to construct machine-size registers depending on the target. This feature is +//! for users that want to write code targeting both 32-bit and 64-bit architectures at the same time. In AsmJit +//! terminology such registers have prefix `z`, so for example on X86 architecture the following native registers +//! 
are provided: +//! +//! - `zax` - mapped to either `eax` or `rax` +//! - `zbx` - mapped to either `ebx` or `rbx` +//! - `zcx` - mapped to either `ecx` or `rcx` +//! - `zdx` - mapped to either `edx` or `rdx` +//! - `zsp` - mapped to either `esp` or `rsp` +//! - `zbp` - mapped to either `ebp` or `rbp` +//! - `zsi` - mapped to either `esi` or `rsi` +//! - `zdi` - mapped to either `edi` or `rdi` +//! +//! They are accessible through \ref x86::Assembler, \ref x86::Builder, and \ref x86::Compiler. The example below +//! illustrates how to use this feature: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! typedef int (*Func)(void); +//! +//! int main(int argc, char* argv[]) { +//! JitRuntime rt; // Create a runtime specialized for JIT. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Let's get these registers from x86::Assembler. +//! x86::Gp zbp = a.zbp(); +//! x86::Gp zsp = a.zsp(); +//! +//! int stackSize = 32; +//! +//! // Function prolog. +//! a.push(zbp); +//! a.mov(zbp, zsp); +//! a.sub(zsp, stackSize); +//! +//! // ... emit some code (this just sets return value to zero) ... +//! a.xor_(x86::eax, x86::eax); +//! +//! // Function epilog and return. +//! a.mov(zsp, zbp); +//! a.pop(zbp); +//! a.ret(); +//! +//! // To make the example complete let's call it. +//! Func fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! +//! int result = fn(); // Execute the generated code. +//! printf("%d\n", result); // Print the resulting "0". +//! +//! rt.release(fn); // Remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! 
The example just returns `0`, but the function generated contains a standard prolog and epilog sequence and the +//! function itself reserves 32 bytes of local stack. The advantage is clear - a single code-base can handle multiple +//! targets easily. If you want to create a register of native size dynamically by specifying its id it's also possible: +//! +//! ``` +//! #include +//! using namespace asmjit; +//! +//! void example(x86::Assembler& a) { +//! x86::Gp zax = a.gpz(x86::Gp::kIdAx); +//! x86::Gp zbx = a.gpz(x86::Gp::kIdBx); +//! x86::Gp zcx = a.gpz(x86::Gp::kIdCx); +//! x86::Gp zdx = a.gpz(x86::Gp::kIdDx); +//! +//! // You can also change register's id easily. +//! x86::Gp zsp = zax; +//! zsp.setId(4); // or x86::Gp::kIdSp. +//! } +//! ``` +//! +//! ### Data Embedding +//! +//! x86::Assembler extends the standard \ref BaseAssembler with X86/X64 specific conventions that are often used by +//! assemblers to embed data next to the code. The following functions can be used to embed data: +//! +//! - \ref BaseAssembler::embedInt8() - embeds int8_t (portable naming). +//! - \ref BaseAssembler::embedUInt8() - embeds uint8_t (portable naming). +//! - \ref BaseAssembler::embedInt16() - embeds int16_t (portable naming). +//! - \ref BaseAssembler::embedUInt16() - embeds uint16_t (portable naming). +//! - \ref BaseAssembler::embedInt32() - embeds int32_t (portable naming). +//! - \ref BaseAssembler::embedUInt32() - embeds uint32_t (portable naming). +//! - \ref BaseAssembler::embedInt64() - embeds int64_t (portable naming). +//! - \ref BaseAssembler::embedUInt64() - embeds uint64_t (portable naming). +//! - \ref BaseAssembler::embedFloat() - embeds float (portable naming). +//! - \ref BaseAssembler::embedDouble() - embeds double (portable naming). +//! +//! - \ref x86::Assembler::db() - embeds byte (8 bits) (x86 naming). +//! - \ref x86::Assembler::dw() - embeds word (16 bits) (x86 naming). +//! - \ref x86::Assembler::dd() - embeds dword (32 bits) (x86 naming). +//! 
- \ref x86::Assembler::dq() - embeds qword (64 bits) (x86 naming). +//! +//! The following example illustrates how embed works: +//! +//! ``` +//! #include +//! using namespace asmjit; +//! +//! void embedData(x86::Assembler& a) { +//! a.db(0xFF); // Embeds 0xFF byte. +//! a.dw(0xFF00); // Embeds 0xFF00 word (little-endian). +//! a.dd(0xFF000000); // Embeds 0xFF000000 dword (little-endian). +//! a.embedFloat(0.4f); // Embeds 0.4f (32-bit float, little-endian). +//! } +//! ``` +//! +//! Sometimes it's required to read the data that is embedded after code, for example. This can be done through +//! \ref Label as shown below: +//! +//! ``` +//! #include +//! using namespace asmjit; +//! +//! void processData(x86::Assembler& a, const Label& L_Data) { +//! x86::Gp addr = a.zax(); // EAX or RAX. +//! x86::Gp val = x86::edi; // Where to store some value... +//! +//! // Approach 1 - Load the address to register through LEA. This approach +//! // is flexible as the address can be then manipulated, for +//! // example if you have a data array, which would need index. +//! a.lea(addr, x86::ptr(L_Data)); +//! a.mov(val, x86::dword_ptr(addr)); +//! +//! // Approach 2 - Load the data directly by using L_Data in address. It's +//! // worth noting that this doesn't work with indexes in X64 +//! // mode. It will use absolute address in 32-bit mode and +//! // relative address (RIP) in 64-bit mode. +//! a.mov(val, x86::dword_ptr(L_Data)); +//! } +//! ``` +//! +//! ### Label Embedding +//! +//! It's also possible to embed labels. In general AsmJit provides the following options: +//! +//! - \ref BaseEmitter::embedLabel() - Embeds absolute address of a label. This is target dependent and would +//! embed either 32-bit or 64-bit data that embeds absolute label address. This kind of embedding cannot be +//! used in a position independent code. +//! +//! - \ref BaseEmitter::embedLabelDelta() - Embeds a difference between two labels. The size of the difference +//! 
can be specified so it's possible to embed 8-bit, 16-bit, 32-bit, and 64-bit difference, which is sufficient +//! for most purposes. +//! +//! The following example demonstrates how to embed labels and their differences: +//! +//! ``` +//! #include +//! using namespace asmjit; +//! +//! void embedLabel(x86::Assembler& a, const Label& L_Data) { +//! // [1] Embed L_Data - the size of the data will be dependent on the target. +//! a.embedLabel(L_Data); +//! +//! // [2] Embed a 32-bit difference of two labels. +//! Label L_Here = a.newLabel(); +//! a.bind(L_Here); +//! // Embeds int32_t(L_Data - L_Here). +//! a.embedLabelDelta(L_Data, L_Here, 4); +//! } +//! ``` +//! +//! ### Using FuncFrame and FuncDetail with x86::Assembler +//! +//! The example below demonstrates how \ref FuncFrame and \ref FuncDetail can be used together with \ref x86::Assembler +//! to generate a function that will use platform dependent calling conventions automatically depending on the target: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! typedef void (*SumIntsFunc)(int* dst, const int* a, const int* b); +//! +//! int main(int argc, char* argv[]) { +//! JitRuntime rt; // Create JIT Runtime. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Decide which registers will be mapped to function arguments. Try changing +//! // registers of dst, src_a, and src_b and see what happens in function's +//! // prolog and epilog. +//! x86::Gp dst = a.zax(); +//! x86::Gp src_a = a.zcx(); +//! x86::Gp src_b = a.zdx(); +//! +//! x86::Xmm vec0 = x86::xmm0; +//! x86::Xmm vec1 = x86::xmm1; +//! +//! // Create/initialize FuncDetail and FuncFrame. +//! FuncDetail func; +//! func.init(FuncSignature::build(), +//! rt.environment()); +//! +//! FuncFrame frame; +//! frame.init(func); +//! 
+//! // Make XMM0 and XMM1 dirty - RegGroup::kVec describes XMM|YMM|ZMM registers. +//! frame.setDirtyRegs(RegGroup::kVec, Support::bitMask(0, 1)); +//! +//! // Alternatively, if you don't want to use register masks you can pass BaseReg +//! // to addDirtyRegs(). The following code would add both xmm0 and xmm1. +//! frame.addDirtyRegs(x86::xmm0, x86::xmm1); +//! +//! FuncArgsAssignment args(&func); // Create arguments assignment context. +//! args.assignAll(dst, src_a, src_b);// Assign our registers to arguments. +//! args.updateFuncFrame(frame); // Reflect our args in FuncFrame. +//! frame.finalize(); // Finalize the FuncFrame (updates it). +//! +//! a.emitProlog(frame); // Emit function prolog. +//! a.emitArgsAssignment(frame, args);// Assign arguments to registers. +//! a.movdqu(vec0, x86::ptr(src_a)); // Load 4 ints from [src_a] to XMM0. +//! a.movdqu(vec1, x86::ptr(src_b)); // Load 4 ints from [src_b] to XMM1. +//! a.paddd(vec0, vec1); // Add 4 ints in XMM1 to XMM0. +//! a.movdqu(x86::ptr(dst), vec0); // Store the result to [dst]. +//! a.emitEpilog(frame); // Emit function epilog and return. +//! +//! SumIntsFunc fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error case. +//! +//! // Execute the generated function. +//! int inA[4] = { 4, 3, 2, 1 }; +//! int inB[4] = { 1, 5, 2, 8 }; +//! int out[4]; +//! fn(out, inA, inB); +//! +//! // Prints {5 8 4 9} +//! printf("{%d %d %d %d}\n", out[0], out[1], out[2], out[3]); +//! +//! rt.release(fn); +//! return 0; +//! } +//! ``` +//! +//! ### Using x86::Assembler as Code-Patcher +//! +//! This is an advanced topic that is sometimes unavoidable. AsmJit by default appends machine code it generates +//! into a \ref CodeBuffer, however, it also allows to set the offset in \ref CodeBuffer explicitly and to overwrite +//! its content. This technique is extremely dangerous as X86 instructions have variable length (see below), so you +//! 
should in general only patch code to change instruction's immediate values or some other details not known the +//! at a time the instruction was emitted. A typical scenario that requires code-patching is when you start emitting +//! function and you don't know how much stack you want to reserve for it. +//! +//! Before we go further it's important to introduce instruction options, because they can help with code-patching +//! (and not only patching, but that will be explained in AVX-512 section): +//! +//! - Many general-purpose instructions (especially arithmetic ones) on X86 have multiple encodings - in AsmJit +//! this is usually called 'short form' and 'long form'. +//! +//! - AsmJit always tries to use 'short form' as it makes the resulting machine-code smaller, which is always +//! good - this decision is used by majority of assemblers out there. +//! +//! - AsmJit allows to override the default decision by using `short_()` and `long_()` instruction options to force +//! short or long form, respectively. The most useful is `long_()` as it basically forces AsmJit to always emit +//! the longest form. The `short_()` is not that useful as it's automatic (except jumps to non-bound labels). Note +//! that the underscore after each function name avoids collision with built-in C++ types. +//! +//! To illustrate what short form and long form means in binary let's assume we want to emit "add esp, 16" instruction, +//! which has two possible binary encodings: +//! +//! - `83C410` - This is a short form aka `short add esp, 16` - You can see opcode byte (0x8C), MOD/RM byte (0xC4) +//! and an 8-bit immediate value representing `16`. +//! +//! - `81C410000000` - This is a long form aka `long add esp, 16` - You can see a different opcode byte (0x81), the +//! same Mod/RM byte (0xC4) and a 32-bit immediate in little-endian representing `16`. +//! +//! It should be obvious that patching an existing instruction into an instruction having a different size may create +//! 
various problems. So it's recommended to be careful and to only patch instructions into instructions having the +//! same size. The example below demonstrates how instruction options can be used to guarantee the size of an +//! instruction by forcing the assembler to use long-form encoding: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! typedef int (*Func)(void); +//! +//! int main(int argc, char* argv[]) { +//! JitRuntime rt; // Create a runtime specialized for JIT. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Let's get these registers from x86::Assembler. +//! x86::Gp zbp = a.zbp(); +//! x86::Gp zsp = a.zsp(); +//! +//! // Function prolog. +//! a.push(zbp); +//! a.mov(zbp, zsp); +//! +//! // This is where we are gonna patch the code later, so let's get the offset +//! // (the current location) from the beginning of the code-buffer. +//! size_t patchOffset = a.offset(); +//! // Let's just emit 'sub zsp, 0' for now, but don't forget to use LONG form. +//! a.long_().sub(zsp, 0); +//! +//! // ... emit some code (this just sets return value to zero) ... +//! a.xor_(x86::eax, x86::eax); +//! +//! // Function epilog and return. +//! a.mov(zsp, zbp); +//! a.pop(zbp); +//! a.ret(); +//! +//! // Now we know how much stack size we want to reserve. I have chosen 128 +//! // bytes on purpose as it's encodable only in long form that we have used. +//! +//! int stackSize = 128; // Number of bytes to reserve on the stack. +//! a.setOffset(patchOffset); // Move the current cursor to `patchOffset`. +//! a.long_().sub(zsp, stackSize); // Patch the code; don't forget to use LONG form. +//! +//! // Now the code is ready to be called +//! Func fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! 
if (err) return 1; // Handle a possible error returned by AsmJit. +//! +//! int result = fn(); // Execute the generated code. +//! printf("%d\n", result); // Print the resulting "0". +//! +//! rt.release(fn); // Remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! If you run the example it will just work, because both instructions have the same size. As an experiment you can +//! try removing `long_()` form to see what happens when wrong code is generated. +//! +//! ### Code Patching and REX Prefix +//! +//! In 64-bit mode there is one more thing to worry about when patching code: REX prefix. It's a single byte prefix +//! designed to address registers with ids from 9 to 15 and to override the default width of operation from 32 to 64 +//! bits. AsmJit, like other assemblers, only emits REX prefix when it's necessary. If the patched code only changes +//! the immediate value as shown in the previous example then there is nothing to worry about as it doesn't change +//! the logic behind emitting REX prefix, however, if the patched code changes register id or overrides the operation +//! width then it's important to take care of REX prefix as well. +//! +//! AsmJit contains another instruction option that controls (forces) REX prefix - `rex()`. If you use it the +//! instruction emitted will always use REX prefix even when it's encodable without it. The following list contains +//! some instructions and their binary representations to illustrate when it's emitted: +//! +//! - `__83C410` - `add esp, 16` - 32-bit operation in 64-bit mode doesn't require REX prefix. +//! - `4083C410` - `rex add esp, 16` - 32-bit operation in 64-bit mode with forced REX prefix (0x40). +//! - `4883C410` - `add rsp, 16` - 64-bit operation in 64-bit mode requires REX prefix (0x48). +//! - `4183C410` - `add r12d, 16` - 32-bit operation in 64-bit mode using R12D requires REX prefix (0x41). +//! 
- `4983C410` - `add r12, 16` - 64-bit operation in 64-bit mode using R12 requires REX prefix (0x49). +//! +//! ### More Prefixes +//! +//! X86 architecture is known for its prefixes. AsmJit supports all prefixes +//! that can affect how the instruction is encoded: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! void prefixesExample(x86::Assembler& a) { +//! // Lock prefix for implementing atomics: +//! // lock add dword ptr [rdi], 1 +//! a.lock().add(x86::dword_ptr(x86::rdi), 1); +//! +//! // Similarly, XAcquire/XRelease prefixes are also available: +//! // xacquire add dword ptr [rdi], 1 +//! a.xacquire().add(x86::dword_ptr(x86::rdi), 1); +//! +//! // Rep prefix (see also repe/repz and repne/repnz): +//! // rep movs byte ptr [rdi], byte ptr [rsi] +//! a.rep().movs(x86::byte_ptr(x86::rdi), x86::byte_ptr(x86::rsi)); +//! +//! // Forcing REX prefix in 64-bit mode. +//! // rex mov eax, 1 +//! a.rex().mov(x86::eax, 1); +//! +//! // AVX instruction without forced prefix uses the shortest encoding: +//! // vaddpd xmm0, xmm1, xmm2 -> [C5|F1|58|C2] +//! a.vaddpd(x86::xmm0, x86::xmm1, x86::xmm2); +//! +//! // Forcing VEX3 prefix (AVX): +//! // vex3 vaddpd xmm0, xmm1, xmm2 -> [C4|E1|71|58|C2] +//! a.vex3().vaddpd(x86::xmm0, x86::xmm1, x86::xmm2); +//! +//! // Forcing EVEX prefix (AVX512): +//! // evex vaddpd xmm0, xmm1, xmm2 -> [62|F1|F5|08|58|C2] +//! a.evex().vaddpd(x86::xmm0, x86::xmm1, x86::xmm2); +//! +//! // Some instructions accept prefixes not originally intended to: +//! // rep ret +//! a.rep().ret(); +//! } +//! ``` +//! +//! It's important to understand that prefixes are part of instruction options. When a member function that involves +//! adding a prefix is called the prefix is combined with existing instruction options, which will affect the next +//! instruction generated. +//! +//! ### Generating AVX512 code. +//! +//! x86::Assembler can generate AVX512+ code including the use of opmask registers. Opmask can be specified through +//! 
\ref x86::Assembler::k() function, which stores it as an extra register, which will be used by the next +//! instruction. AsmJit uses such concept for manipulating instruction options as well. +//! +//! The following AVX512 features are supported: +//! +//! - Opmask selector {k} and zeroing {z}. +//! - Rounding modes {rn|rd|ru|rz} and suppress-all-exceptions {sae} option. +//! - AVX512 broadcasts {1toN}. +//! +//! The following example demonstrates how AVX512 features can be used: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! void generateAVX512Code(x86::Assembler& a) { +//! using namespace x86; +//! +//! // Opmask Selectors +//! // ---------------- +//! // +//! // - Opmask / zeroing is part of the instruction options / extraReg. +//! // - k(reg) is like {kreg} in Intel syntax. +//! // - z() is like {z} in Intel syntax. +//! +//! // vaddpd zmm {k1} {z}, zmm1, zmm2 +//! a.k(k1).z().vaddpd(zmm0, zmm1, zmm2); +//! +//! // Memory Broadcasts +//! // ----------------- +//! // +//! // - Broadcast data is part of memory operand. +//! // - Use x86::Mem::_1to2(), x86::Mem::_1to4(), etc..., which returns a new x86::Mem operand with broadcast. +//! +//! // vaddpd zmm0 {k1} {z}, zmm1, [rcx] {1to8} +//! a.k(k1).z().vaddpd(zmm0, zmm1, x86::ptr(rcx)._1to8()); +//! +//! // Embedded Rounding & Suppress-All-Exceptions +//! // ------------------------------------------- +//! // +//! // - Rounding mode and {sae} are part of instruction options. +//! // - Use sae() to enable exception suppression. +//! // - Use rn_sae(), rd_sae(), ru_sae(), and rz_sae() - to enable rounding. +//! // - Embedded rounding implicitly sets {sae} as well, that's why the API +//! // also has sae() suffix, to make it clear. +//! +//! // vcmppd k1, zmm1, zmm2, 0x00 {sae} +//! a.sae().vcmppd(k1, zmm1, zmm2, 0); +//! +//! // vaddpd zmm0, zmm1, zmm2 {rz} +//! a.rz_sae().vaddpd(zmm0, zmm1, zmm2); +//! } +//! 
``` +class ASMJIT_VIRTAPI Assembler + : public BaseAssembler, + public EmitterImplicitT { +public: + ASMJIT_NONCOPYABLE(Assembler) + typedef BaseAssembler Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Assembler(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Assembler() noexcept override; + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + // NOTE: x86::Assembler uses _privateData to store 'address-override' bit that is used to decide whether to emit + // address-override (67H) prefix based on the memory BASE+INDEX registers. It's either `kX86MemInfo_67H_X86` or + // `kX86MemInfo_67H_X64`. + ASMJIT_INLINE_NODEBUG uint32_t _addressOverrideMask() const noexcept { return _privateData; } + ASMJIT_INLINE_NODEBUG void _setAddressOverrideMask(uint32_t m) noexcept { _privateData = m; } + + //! \} + //! \endcond + + //! \cond INTERNAL + //! \name Emit + //! \{ + + ASMJIT_API Error _emit(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_* opExt) override; + + //! \} + //! \endcond + + //! \name Align + //! \{ + + ASMJIT_API Error align(AlignMode alignMode, uint32_t alignment) override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! 
\} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_X86_X86ASSEMBLER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86builder.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86builder.h new file mode 100644 index 0000000000000000000000000000000000000000..194c1402f1101a81d5c13987c21840bdf3e1dd97 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86builder.h @@ -0,0 +1,354 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86BUILDER_H_INCLUDED +#define ASMJIT_X86_X86BUILDER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_BUILDER + +#include "../core/builder.h" +#include "../x86/x86emitter.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! X86/X64 builder implementation. +//! +//! The code representation used by \ref BaseBuilder is compatible with everything AsmJit provides. Each instruction +//! is stored as \ref InstNode, which contains instruction id, options, and operands. Each instruction emitted will +//! create a new \ref InstNode instance and add it to the current cursor in the double-linked list of nodes. Since +//! the instruction stream used by \ref BaseBuilder can be manipulated, we can rewrite the SumInts example from +//! \ref asmjit_assembler into the following: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! typedef void (*SumIntsFunc)(int* dst, const int* a, const int* b); +//! +//! // Small helper function to print the current content of `cb`. +//! static void dumpCode(BaseBuilder& builder, const char* phase) { +//! String sb; +//! formatOptions formatOptions {}; +//! +//! Formatter::formatNodeList(sb, formatOptions, &builder); +//! printf("%s:\n%s\n", phase, sb.data()); +//! } +//! +//! int main() { +//! JitRuntime rt; // Create JIT Runtime. +//! 
CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Builder cb(&code); // Create and attach x86::Builder to `code`. +//! +//! // Decide which registers will be mapped to function arguments. Try changing registers +//! // of `dst`, `srcA`, and `srcB` and see what happens in function's prolog and epilog. +//! x86::Gp dst = cb.zax(); +//! x86::Gp srcA = cb.zcx(); +//! x86::Gp srcB = cb.zdx(); +//! +//! X86::Xmm vec0 = x86::xmm0; +//! X86::Xmm vec1 = x86::xmm1; +//! +//! // Create and initialize `FuncDetail`. +//! FuncDetail func; +//! func.init(FuncSignature::build()); +//! +//! // Remember prolog insertion point. +//! BaseNode* prologInsertionPoint = cb.cursor(); +//! +//! // Emit function body: +//! cb.movdqu(vec0, x86::ptr(srcA)); // Load 4 ints from [srcA] to XMM0. +//! cb.movdqu(vec1, x86::ptr(srcB)); // Load 4 ints from [srcB] to XMM1. +//! cb.paddd(vec0, vec1); // Add 4 ints in XMM1 to XMM0. +//! cb.movdqu(x86::ptr(dst), vec0); // Store the result to [dst]. +//! +//! // Remember epilog insertion point. +//! BaseNode* epilogInsertionPoint = cb.cursor(); +//! +//! // Let's see what we have now. +//! dumpCode(cb, "Raw Function"); +//! +//! // Now, after we emitted the function body, we can insert the prolog, arguments +//! // allocation, and epilog. This is not possible with using pure x86::Assembler. +//! FuncFrame frame; +//! frame.init(func); +//! +//! // Make XMM0 and XMM1 dirty; RegGroup::kVec describes XMM|YMM|ZMM registers. +//! frame.setDirtyRegs(RegGroup::kVec, IntUtils::mask(0, 1)); +//! +//! FuncArgsAssignment args(&func); // Create arguments assignment context. +//! args.assignAll(dst, srcA, srcB); // Assign our registers to arguments. +//! args.updateFrame(frame); // Reflect our args in FuncFrame. +//! frame.finalize(); // Finalize the FuncFrame (updates it). +//! +//! // Insert function prolog and allocate arguments to registers. +//! 
cb.setCursor(prologInsertionPoint); +//! cb.emitProlog(frame); +//! cb.emitArgsAssignment(frame, args); +//! +//! // Insert function epilog. +//! cb.setCursor(epilogInsertionPoint); +//! cb.emitEpilog(frame); +//! +//! // Let's see how the function's prolog and epilog looks. +//! dumpCode(cb, "Prolog & Epilog"); +//! +//! // IMPORTANT: Builder requires finalize() to be called to serialize its +//! // code to the Assembler (it automatically creates one if not attached). +//! cb.finalize(); +//! +//! SumIntsFunc fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error case. +//! +//! // Execute the generated function. +//! int inA[4] = { 4, 3, 2, 1 }; +//! int inB[4] = { 1, 5, 2, 8 }; +//! int out[4]; +//! fn(out, inA, inB); +//! +//! // Prints {5 8 4 9} +//! printf("{%d %d %d %d}\n", out[0], out[1], out[2], out[3]); +//! +//! rt.release(fn); // Explicitly remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! When the example is executed it should output the following (this one using AMD64-SystemV ABI): +//! +//! ``` +//! Raw Function: +//! movdqu xmm0, [rcx] +//! movdqu xmm1, [rdx] +//! paddd xmm0, xmm1 +//! movdqu [rax], xmm0 +//! +//! Prolog & Epilog: +//! mov rax, rdi +//! mov rcx, rsi +//! movdqu xmm0, [rcx] +//! movdqu xmm1, [rdx] +//! paddd xmm0, xmm1 +//! movdqu [rax], xmm0 +//! ret +//! +//! {5 8 4 9} +//! ``` +//! +//! The number of use-cases of \ref BaseBuilder is not limited and highly depends on your creativity and experience. +//! The previous example can be easily improved to collect all dirty registers inside the function programmatically +//! and to pass them to \ref FuncFrame::setDirtyRegs(). +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! // NOTE: This function doesn't cover all possible constructs. It ignores instructions that write +//! // to implicit registers that are not part of the operand list. 
It also counts read-only registers. +//! // Real implementation would be a bit more complicated, but still relatively easy to implement. +//! static void collectDirtyRegs(const BaseNode* first, +//! const BaseNode* last, +//! Support::Array& regMask) { +//! const BaseNode* node = first; +//! while (node) { +//! if (node->actsAsInst()) { +//! const InstNode* inst = node->as(); +//! const Operand* opArray = inst->operands(); +//! +//! for (uint32_t i = 0, opCount = inst->opCount(); i < opCount; i++) { +//! const Operand& op = opArray[i]; +//! if (op.isReg()) { +//! const x86::Reg& reg = op.as(); +//! if (reg.group() <= RegGroup::kMaxVirt) { +//! regMask[reg.group()] |= 1u << reg.id(); +//! } +//! } +//! } +//! } +//! +//! if (node == last) +//! break; +//! node = node->next(); +//! } +//! } +//! +//! static void setDirtyRegsOfFuncFrame(const x86::Builder& builder, FuncFrame& frame) { +//! Support::Array regMask {}; +//! collectDirtyRegs(builder.firstNode(), builder.lastNode(), regMask); +//! +//! // X86/X64 ABIs only require to save GP/XMM registers: +//! frame.setDirtyRegs(RegGroup::kGp, regMask[RegGroup::kGp]); +//! frame.setDirtyRegs(RegGroup::kVec, regMask[RegGroup::kVec]); +//! } +//! ``` +//! +//! ### Casting Between Various Emitters +//! +//! Even when \ref BaseAssembler and \ref BaseBuilder provide the same interface as defined by \ref BaseEmitter their +//! platform dependent variants like \ref x86::Assembler and \ref x86::Builder cannot be interchanged or casted to each +//! other by using a C++ `static_cast<>`. The main reason is the inheritance graph of these classes is different and +//! cast-incompatible, as illustrated below: +//! +//! ``` +//! +--------------+ +=========================+ +//! +----------------------->| x86::Emitter |<--+--# x86::EmitterImplicitT<> #<--+ +//! | +--------------+ | +=========================+ | +//! | (abstract) | (mixin) | +//! | +--------------+ +~~~~~~~~~~~~~~+ | | +//! 
+-->| BaseAssembler|---->|x86::Assembler|<--+ | +//! | +--------------+ +~~~~~~~~~~~~~~+ | | +//! | (abstract) (final) | | +//! +===============+ | +--------------+ +~~~~~~~~~~~~~~+ | | +//! # BaseEmitter #--+-->| BaseBuilder |--+->| x86::Builder |<--+ | +//! +===============+ +--------------+ | +~~~~~~~~~~~~~~+ | +//! (abstract) (abstract) | (final) | +//! +---------------------+ | +//! | | +//! | +--------------+ +~~~~~~~~~~~~~~+ +=========================+ | +//! +-->| BaseCompiler |---->| x86::Compiler|<-----# x86::EmitterExplicitT<> #---+ +//! +--------------+ +~~~~~~~~~~~~~~+ +=========================+ +//! (abstract) (final) (mixin) +//! ``` +//! +//! The graph basically shows that it's not possible to cast between \ref x86::Assembler and \ref x86::Builder. +//! However, since both share the base interface (\ref BaseEmitter) it's possible to cast them to a class that +//! cannot be instantiated, but defines the same interface - the class is called \ref x86::Emitter and was +//! introduced to make it possible to write a function that can emit to both \ref x86::Assembler and \ref +//! x86::Builder. Note that \ref x86::Emitter cannot be created, it's abstract and has private constructors and +//! destructors; it was only designed to be casted to and used as an interface. +//! +//! Each architecture-specific emitter implements a member function called +//! `as()`, which casts the instance to the architecture +//! specific emitter as illustrated below: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! static void emitSomething(x86::Emitter* e) { +//! e->mov(x86::eax, x86::ebx); +//! } +//! +//! static void assemble(CodeHolder& code, bool useAsm) { +//! if (useAsm) { +//! x86::Assembler assembler(&code); +//! emitSomething(assembler.as()); +//! } +//! else { +//! x86::Builder builder(&code); +//! emitSomething(builder.as()); +//! +//! // NOTE: Builder requires `finalize()` to be called to serialize its +//! 
// content to Assembler (it automatically creates one if not attached). +//! builder.finalize(); +//! } +//! } +//! ``` +//! +//! The example above shows how to create a function that can emit code to either \ref x86::Assembler or \ref +//! x86::Builder through \ref x86::Emitter, which provides emitter-neutral functionality. \ref x86::Emitter, +//! however, doesn't provide any emitter-specific functionality like `setCursor()`. +//! +//! ### Code Injection and Manipulation +//! +//! \ref BaseBuilder emitter stores its nodes in a double-linked list, which makes it easy to manipulate that +//! list during the code generation or afterwards. Each node is always emitted next to the current cursor and +//! the cursor is advanced to that newly emitted node. The cursor can be retrieved and changed by \ref +//! BaseBuilder::cursor() and \ref BaseBuilder::setCursor(), respectively. +//! +//! The example below demonstrates how to remember a node and inject something +//! next to it. +//! +//! ``` +//! static void example(x86::Builder& builder) { +//! // Emit something, after it returns the cursor would point at the last +//! // emitted node. +//! builder.mov(x86::rax, x86::rdx); // [1] +//! +//! // We can retrieve the node. +//! BaseNode* node = builder.cursor(); +//! +//! // Change the instruction we just emitted, just for fun... +//! if (node->isInst()) { +//! InstNode* inst = node->as(); +//! // Changes the operands at index [1] to RCX. +//! inst->setOp(1, x86::rcx); +//! } +//! +//! // ------------------------- Generate Some Code ------------------------- +//! builder.add(x86::rax, x86::rdx); // [2] +//! builder.shr(x86::rax, 3); // [3] +//! // ---------------------------------------------------------------------- +//! +//! // Now, we know where our node is, and we can simply change the cursor +//! // and start emitting something after it. The setCursor() function +//! // returns the previous cursor, and it's always a good practice to remember +//! 
// it, because you never know if you are not already injecting the code +//! // somewhere else... +//! BaseNode* oldCursor = builder.setCursor(node); +//! +//! builder.mul(x86::rax, 8); // [4] +//! +//! // Restore the cursor +//! builder.setCursor(oldCursor); +//! } +//! ``` +//! +//! The function above would actually emit the following: +//! +//! ``` +//! mov rax, rcx ; [1] Patched at the beginning. +//! mul rax, 8 ; [4] Injected. +//! add rax, rdx ; [2] Followed [1] initially. +//! shr rax, 3 ; [3] Follows [2]. +//! ``` +class ASMJIT_VIRTAPI Builder + : public BaseBuilder, + public EmitterImplicitT { +public: + ASMJIT_NONCOPYABLE(Builder) + typedef BaseBuilder Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Builder(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Builder() noexcept override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} + + //! \name Finalize + //! \{ + + ASMJIT_API Error finalize() override; + + //! \} +}; + +//! 
\} + +ASMJIT_END_SUB_NAMESPACE + +#endif // !ASMJIT_NO_BUILDER +#endif // ASMJIT_X86_X86BUILDER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86compiler.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..b281e208879cddf92027fe15a52eabb907a0e593 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86compiler.h @@ -0,0 +1,726 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86COMPILER_H_INCLUDED +#define ASMJIT_X86_X86COMPILER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_COMPILER + +#include "../core/compiler.h" +#include "../core/type.h" +#include "../x86/x86emitter.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! X86/X64 compiler implementation. +//! +//! ### Compiler Basics +//! +//! The first \ref x86::Compiler example shows how to generate a function that simply returns an integer value. It's +//! an analogy to the first Assembler example: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef int (*Func)(void); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! cc.addFunc(FuncSignature::build()); // Begin a function of `int fn(void)` signature. +//! +//! x86::Gp vReg = cc.newGpd(); // Create a 32-bit general purpose register. +//! cc.mov(vReg, 1); // Move one to our virtual register `vReg`. +//! cc.ret(vReg); // Return `vReg` from the function. +//! 
+//! cc.endFunc(); // End of the function body. +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! // ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! Func fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! int result = fn(); // Execute the generated code. +//! printf("%d\n", result); // Print the resulting "1". +//! +//! rt.release(fn); // Explicitly remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! The \ref BaseCompiler::addFunc() and \ref BaseCompiler::endFunc() functions are used to define the function and +//! its end. Both must be called per function, but the body doesn't have to be generated in sequence. An example of +//! generating two functions will be shown later. The next example shows more complicated code that contain a loop +//! and generates a simple memory copy function that uses `uint32_t` items: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef void (*MemCpy32)(uint32_t* dst, const uint32_t* src, size_t count); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! FuncNode* funcNode = cc.addFunc ( // Begin the function of the following signature: +//! FuncSignature::build()); // 3rd argument - size_t (machine reg-size). +//! +//! Label L_Loop = cc.newLabel(); // Start of the loop. +//! Label L_Exit = cc.newLabel(); // Used to exit early. +//! +//! 
x86::Gp dst = cc.newIntPtr("dst"); // Create `dst` register (destination pointer). +//! x86::Gp src = cc.newIntPtr("src"); // Create `src` register (source pointer). +//! x86::Gp i = cc.newUIntPtr("i"); // Create `i` register (loop counter). +//! +//! funcNode->setArg(0, dst); // Assign `dst` argument. +//! funcNode->setArg(1, src); // Assign `src` argument. +//! funcNode->setArg(2, i); // Assign `i` argument. +//! +//! cc.test(i, i); // Early exit if length is zero. +//! cc.jz(L_Exit); +//! +//! cc.bind(L_Loop); // Bind the beginning of the loop here. +//! +//! x86::Gp tmp = cc.newInt32("tmp"); // Copy a single dword (4 bytes). +//! cc.mov(tmp, x86::dword_ptr(src)); // Load DWORD from [src] address. +//! cc.mov(x86::dword_ptr(dst), tmp); // Store DWORD to [dst] address. +//! +//! cc.add(src, 4); // Increment `src`. +//! cc.add(dst, 4); // Increment `dst`. +//! +//! cc.dec(i); // Loop until `i` is non-zero. +//! cc.jnz(L_Loop); +//! +//! cc.bind(L_Exit); // Label used by early exit. +//! cc.endFunc(); // End of the function body. +//! +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! // ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! // Add the generated code to the runtime. +//! MemCpy32 memcpy32; +//! Error err = rt.add(&memcpy32, &code); +//! +//! // Handle a possible error returned by AsmJit. +//! if (err) +//! return 1; +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! // Test the generated code. +//! uint32_t input[6] = { 1, 2, 3, 5, 8, 13 }; +//! uint32_t output[6]; +//! memcpy32(output, input, 6); +//! +//! for (uint32_t i = 0; i < 6; i++) +//! printf("%d\n", output[i]); +//! +//! rt.release(memcpy32); +//! return 0; +//! } +//! ``` +//! +//! ### AVX and AVX-512 +//! +//! AVX and AVX-512 code generation must be explicitly enabled via \ref FuncFrame to work properly. If it's not setup +//! 
correctly then Prolog & Epilog would use SSE instead of AVX instructions to work with SIMD registers. In addition, +//! Compiler requires explicitly enable AVX-512 via \ref FuncFrame in order to use all 32 SIMD registers. +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef void (*Func)(void*); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! FuncNode* funcNode = cc.addFunc(FuncSignature::build()); +//! +//! // Use the following to enable AVX and/or AVX-512. +//! funcNode->frame().setAvxEnabled(); +//! funcNode->frame().setAvx512Enabled(); +//! +//! // Do something with the input pointer. +//! x86::Gp addr = cc.newIntPtr("addr"); +//! x86::Zmm vreg = cc.newZmm("vreg"); +//! +//! funcNode->setArg(0, addr); +//! +//! cc.vmovdqu32(vreg, x86::ptr(addr)); +//! cc.vpaddq(vreg, vreg, vreg); +//! cc.vmovdqu32(x86::ptr(addr), vreg); +//! +//! cc.endFunc(); // End of the function body. +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! // ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! Func fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! // Execute the generated code and print some output. +//! uint64_t data[] = { 1, 2, 3, 4, 5, 6, 7, 8 }; +//! fn(data); +//! printf("%llu\n", (unsigned long long)data[0]); +//! +//! rt.release(fn); // Explicitly remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! 
### Recursive Functions +//! +//! It's possible to create more functions by using the same \ref x86::Compiler instance and make links between them. +//! In such case it's important to keep the pointer to \ref FuncNode. +//! +//! The example below creates a simple Fibonacci function that calls itself recursively: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef uint32_t (*Fibonacci)(uint32_t x); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! FuncNode* funcNode = cc.addFunc( // Begin of the Fibonacci function, addFunc() +//! FuncSignature::build()); // Returns a pointer to the FuncNode node. +//! +//! Label L_Exit = cc.newLabel(); // Exit label. +//! x86::Gp x = cc.newUInt32(); // Function x argument. +//! x86::Gp y = cc.newUInt32(); // Temporary. +//! +//! funcNode->setArg(0, x); +//! +//! cc.cmp(x, 3); // Return x if less than 3. +//! cc.jb(L_Exit); +//! +//! cc.mov(y, x); // Make copy of the original x. +//! cc.dec(x); // Decrease x. +//! +//! InvokeNode* invokeNode; // Function invocation: +//! cc.invoke(&invokeNode, // - InvokeNode (output). +//! funcNode->label(), // - Function address or Label. +//! FuncSignature::build()); // - Function signature. +//! +//! invokeNode->setArg(0, x); // Assign x as the first argument. +//! invokeNode->setRet(0, x); // Assign x as a return value as well. +//! +//! cc.add(x, y); // Combine the return value with y. +//! +//! cc.bind(L_Exit); +//! cc.ret(x); // Return x. +//! cc.endFunc(); // End of the function body. +//! +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! 
// ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! Fibonacci fib; +//! Error err = rt.add(&fib, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! // Test the generated code. +//! printf("Fib(%u) -> %u\n", 8, fib(8)); +//! +//! rt.release(fib); +//! return 0; +//! } +//! ``` +//! +//! ### Stack Management +//! +//! Function's stack-frame is managed automatically, which is used by the register allocator to spill virtual +//! registers. It also provides an interface to allocate user-defined block of the stack, which can be used as +//! a temporary storage by the generated function. In the following example a stack of 256 bytes size is allocated, +//! filled by bytes starting from 0 to 255 and then iterated again to sum all the values. +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef int (*Func)(void); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! cc.addFunc(FuncSignature::build()); // Create a function that returns int. +//! +//! x86::Gp p = cc.newIntPtr("p"); +//! x86::Gp i = cc.newIntPtr("i"); +//! +//! // Allocate 256 bytes on the stack aligned to 4 bytes. +//! x86::Mem stack = cc.newStack(256, 4); +//! +//! x86::Mem stackIdx(stack); // Copy of stack with i added. +//! stackIdx.setIndex(i); // stackIdx <- stack[i]. +//! stackIdx.setSize(1); // stackIdx <- byte ptr stack[i]. +//! +//! // Load a stack address to `p`. This step is purely optional and shows +//! 
// that `lea` is useful to load a memory operands address (even absolute) +//! // to a general purpose register. +//! cc.lea(p, stack); +//! +//! // Clear i (xor is a C++ keyword, hence 'xor_' is used instead). +//! cc.xor_(i, i); +//! +//! Label L1 = cc.newLabel(); +//! Label L2 = cc.newLabel(); +//! +//! cc.bind(L1); // First loop, fill the stack. +//! cc.mov(stackIdx, i.r8()); // stack[i] = uint8_t(i). +//! +//! cc.inc(i); // i++; +//! cc.cmp(i, 256); // if (i < 256) +//! cc.jb(L1); // goto L1; +//! +//! // Second loop, sum all bytes stored in `stack`. +//! x86::Gp sum = cc.newInt32("sum"); +//! x86::Gp val = cc.newInt32("val"); +//! +//! cc.xor_(i, i); +//! cc.xor_(sum, sum); +//! +//! cc.bind(L2); +//! +//! cc.movzx(val, stackIdx); // val = uint32_t(stack[i]); +//! cc.add(sum, val); // sum += val; +//! +//! cc.inc(i); // i++; +//! cc.cmp(i, 256); // if (i < 256) +//! cc.jb(L2); // goto L2; +//! +//! cc.ret(sum); // Return the `sum` of all values. +//! cc.endFunc(); // End of the function body. +//! +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! // ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! Func func; +//! Error err = rt.add(&func, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! printf("Func() -> %d\n", func()); // Test the generated code. +//! +//! rt.release(func); +//! return 0; +//! } +//! ``` +//! +//! ### Constant Pool +//! +//! Compiler provides two constant pools for a general purpose code generation: +//! +//! - Local constant pool - Part of \ref FuncNode, can be only used by a single function and added after the +//! function epilog sequence (after `ret` instruction). +//! +//! - Global constant pool - Part of \ref BaseCompiler, flushed at the end of the generated code by \ref +//! BaseEmitter::finalize(). +//! +//! 
The example below illustrates how a built-in constant pool can be used: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! static void exampleUseOfConstPool(x86::Compiler& cc) { +//! cc.addFunc(FuncSignature::build()); +//! +//! x86::Gp v0 = cc.newGpd("v0"); +//! x86::Gp v1 = cc.newGpd("v1"); +//! +//! x86::Mem c0 = cc.newInt32Const(ConstPoolScope::kLocal, 200); +//! x86::Mem c1 = cc.newInt32Const(ConstPoolScope::kLocal, 33); +//! +//! cc.mov(v0, c0); +//! cc.mov(v1, c1); +//! cc.add(v0, v1); +//! +//! cc.ret(v0); +//! cc.endFunc(); +//! } +//! ``` +//! +//! ### Jump Tables +//! +//! x86::Compiler supports `jmp` instruction with reg/mem operand, which is a commonly used pattern to implement +//! indirect jumps within a function, for example to implement `switch()` statement in a programming languages. +//! By default AsmJit assumes that every basic block can be a possible jump target as it's unable to deduce targets +//! from instruction's operands. This is a very pessimistic default that should be avoided if possible as it's costly +//! and very unfriendly to liveness analysis and register allocation. +//! +//! Instead of relying on such pessimistic default behavior, let's use \ref JumpAnnotation to annotate a jump where +//! all targets are known: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! static void exampleUseOfIndirectJump(x86::Compiler& cc) { +//! FuncNode* funcNode = cc.addFunc(FuncSignature::build()); +//! +//! // Function arguments +//! x86::Xmm a = cc.newXmmSs("a"); +//! x86::Xmm b = cc.newXmmSs("b"); +//! x86::Gp op = cc.newUInt32("op"); +//! +//! x86::Gp target = cc.newIntPtr("target"); +//! x86::Gp offset = cc.newIntPtr("offset"); +//! +//! Label L_Table = cc.newLabel(); +//! Label L_Add = cc.newLabel(); +//! Label L_Sub = cc.newLabel(); +//! Label L_Mul = cc.newLabel(); +//! Label L_Div = cc.newLabel(); +//! Label L_End = cc.newLabel(); +//! +//! funcNode->setArg(0, a); +//! 
funcNode->setArg(1, b); +//! funcNode->setArg(2, op); +//! +//! // Jump annotation is a building block that allows to annotate all possible targets where `jmp()` can +//! // jump. It then drives the CFG construction and liveness analysis, which impacts register allocation. +//! JumpAnnotation* annotation = cc.newJumpAnnotation(); +//! annotation->addLabel(L_Add); +//! annotation->addLabel(L_Sub); +//! annotation->addLabel(L_Mul); +//! annotation->addLabel(L_Div); +//! +//! // Most likely not the common indirect jump approach, but it +//! // doesn't really matter how final address is calculated. The +//! // most important path using JumpAnnotation with `jmp()`. +//! cc.lea(offset, x86::ptr(L_Table)); +//! if (cc.is64Bit()) +//! cc.movsxd(target, x86::dword_ptr(offset, op.cloneAs(offset), 2)); +//! else +//! cc.mov(target, x86::dword_ptr(offset, op.cloneAs(offset), 2)); +//! cc.add(target, offset); +//! cc.jmp(target, annotation); +//! +//! // Acts like a switch() statement in C. +//! cc.bind(L_Add); +//! cc.addss(a, b); +//! cc.jmp(L_End); +//! +//! cc.bind(L_Sub); +//! cc.subss(a, b); +//! cc.jmp(L_End); +//! +//! cc.bind(L_Mul); +//! cc.mulss(a, b); +//! cc.jmp(L_End); +//! +//! cc.bind(L_Div); +//! cc.divss(a, b); +//! +//! cc.bind(L_End); +//! cc.ret(a); +//! +//! cc.endFunc(); +//! +//! // Relative int32_t offsets of `L_XXX - L_Table`. +//! cc.bind(L_Table); +//! cc.embedLabelDelta(L_Add, L_Table, 4); +//! cc.embedLabelDelta(L_Sub, L_Table, 4); +//! cc.embedLabelDelta(L_Mul, L_Table, 4); +//! cc.embedLabelDelta(L_Div, L_Table, 4); +//! } +//! ``` +class ASMJIT_VIRTAPI Compiler + : public BaseCompiler, + public EmitterExplicitT { +public: + ASMJIT_NONCOPYABLE(Compiler) + typedef BaseCompiler Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Compiler(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Compiler() noexcept override; + + //! \} + + //! \name Virtual Registers + //! 
\{ + +#ifndef ASMJIT_NO_LOGGING +# define ASMJIT_NEW_REG_FMT(OUT, PARAM, FORMAT, ARGS) \ + _newRegFmt(&OUT, PARAM, FORMAT, ARGS) +#else +# define ASMJIT_NEW_REG_FMT(OUT, PARAM, FORMAT, ARGS) \ + DebugUtils::unused(FORMAT); \ + DebugUtils::unused(std::forward(args)...); \ + _newReg(&OUT, PARAM) +#endif + +#define ASMJIT_NEW_REG_CUSTOM(FUNC, REG) \ + ASMJIT_INLINE_NODEBUG REG FUNC(TypeId typeId) { \ + REG reg(Globals::NoInit); \ + _newReg(®, typeId); \ + return reg; \ + } \ + \ + template \ + ASMJIT_INLINE_NODEBUG REG FUNC(TypeId typeId, const char* fmt, Args&&... args) { \ + REG reg(Globals::NoInit); \ + ASMJIT_NEW_REG_FMT(reg, typeId, fmt, std::forward(args)...); \ + return reg; \ + } + +#define ASMJIT_NEW_REG_TYPED(FUNC, REG, TYPE_ID) \ + ASMJIT_INLINE_NODEBUG REG FUNC() { \ + REG reg(Globals::NoInit); \ + _newReg(®, TYPE_ID); \ + return reg; \ + } \ + \ + template \ + ASMJIT_INLINE_NODEBUG REG FUNC(const char* fmt, Args&&... args) { \ + REG reg(Globals::NoInit); \ + ASMJIT_NEW_REG_FMT(reg, TYPE_ID, fmt, std::forward(args)...); \ + return reg; \ + } + + template + ASMJIT_INLINE_NODEBUG RegT newSimilarReg(const RegT& ref) { + RegT reg(Globals::NoInit); + _newReg(®, ref); + return reg; + } + + template + ASMJIT_INLINE_NODEBUG RegT newSimilarReg(const RegT& ref, const char* fmt, Args&&... 
args) { + RegT reg(Globals::NoInit); + ASMJIT_NEW_REG_FMT(reg, ref, fmt, std::forward(args)...); + return reg; + } + + ASMJIT_NEW_REG_CUSTOM(newReg , Reg ) + ASMJIT_NEW_REG_CUSTOM(newGp , Gp ) + ASMJIT_NEW_REG_CUSTOM(newVec , Vec ) + ASMJIT_NEW_REG_CUSTOM(newK , KReg) + + ASMJIT_NEW_REG_TYPED(newInt8 , Gp , TypeId::kInt8) + ASMJIT_NEW_REG_TYPED(newUInt8 , Gp , TypeId::kUInt8) + ASMJIT_NEW_REG_TYPED(newInt16 , Gp , TypeId::kInt16) + ASMJIT_NEW_REG_TYPED(newUInt16 , Gp , TypeId::kUInt16) + ASMJIT_NEW_REG_TYPED(newInt32 , Gp , TypeId::kInt32) + ASMJIT_NEW_REG_TYPED(newUInt32 , Gp , TypeId::kUInt32) + ASMJIT_NEW_REG_TYPED(newInt64 , Gp , TypeId::kInt64) + ASMJIT_NEW_REG_TYPED(newUInt64 , Gp , TypeId::kUInt64) + ASMJIT_NEW_REG_TYPED(newIntPtr , Gp , TypeId::kIntPtr) + ASMJIT_NEW_REG_TYPED(newUIntPtr, Gp , TypeId::kUIntPtr) + + ASMJIT_NEW_REG_TYPED(newGpb , Gp , TypeId::kUInt8) + ASMJIT_NEW_REG_TYPED(newGpw , Gp , TypeId::kUInt16) + ASMJIT_NEW_REG_TYPED(newGpd , Gp , TypeId::kUInt32) + ASMJIT_NEW_REG_TYPED(newGpq , Gp , TypeId::kUInt64) + ASMJIT_NEW_REG_TYPED(newGpz , Gp , TypeId::kUIntPtr) + ASMJIT_NEW_REG_TYPED(newXmm , Xmm , TypeId::kInt32x4) + ASMJIT_NEW_REG_TYPED(newXmmSs , Xmm , TypeId::kFloat32x1) + ASMJIT_NEW_REG_TYPED(newXmmSd , Xmm , TypeId::kFloat64x1) + ASMJIT_NEW_REG_TYPED(newXmmPs , Xmm , TypeId::kFloat32x4) + ASMJIT_NEW_REG_TYPED(newXmmPd , Xmm , TypeId::kFloat64x2) + ASMJIT_NEW_REG_TYPED(newYmm , Ymm , TypeId::kInt32x8) + ASMJIT_NEW_REG_TYPED(newYmmPs , Ymm , TypeId::kFloat32x8) + ASMJIT_NEW_REG_TYPED(newYmmPd , Ymm , TypeId::kFloat64x4) + ASMJIT_NEW_REG_TYPED(newZmm , Zmm , TypeId::kInt32x16) + ASMJIT_NEW_REG_TYPED(newZmmPs , Zmm , TypeId::kFloat32x16) + ASMJIT_NEW_REG_TYPED(newZmmPd , Zmm , TypeId::kFloat64x8) + ASMJIT_NEW_REG_TYPED(newMm , Mm , TypeId::kMmx64) + ASMJIT_NEW_REG_TYPED(newKb , KReg, TypeId::kMask8) + ASMJIT_NEW_REG_TYPED(newKw , KReg, TypeId::kMask16) + ASMJIT_NEW_REG_TYPED(newKd , KReg, TypeId::kMask32) + ASMJIT_NEW_REG_TYPED(newKq , 
KReg, TypeId::kMask64) + +#undef ASMJIT_NEW_REG_TYPED +#undef ASMJIT_NEW_REG_CUSTOM +#undef ASMJIT_NEW_REG_FMT + + //! \} + + //! \name Stack + //! \{ + + //! Creates a new memory chunk allocated on the current function's stack. + ASMJIT_INLINE_NODEBUG Mem newStack(uint32_t size, uint32_t alignment, const char* name = nullptr) { + Mem m(Globals::NoInit); + _newStack(&m, size, alignment, name); + return m; + } + + //! \} + + //! \name Constants + //! \{ + + //! Put data to a constant-pool and get a memory reference to it. + ASMJIT_INLINE_NODEBUG Mem newConst(ConstPoolScope scope, const void* data, size_t size) { + Mem m(Globals::NoInit); + _newConst(&m, scope, data, size); + return m; + } + + //! Put a BYTE `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newByteConst(ConstPoolScope scope, uint8_t val) noexcept { return newConst(scope, &val, 1); } + //! Put a WORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newWordConst(ConstPoolScope scope, uint16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a DWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newDWordConst(ConstPoolScope scope, uint32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newQWordConst(ConstPoolScope scope, uint64_t val) noexcept { return newConst(scope, &val, 8); } + + //! Put a WORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt16Const(ConstPoolScope scope, int16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a WORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt16Const(ConstPoolScope scope, uint16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a DWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt32Const(ConstPoolScope scope, int32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a DWORD `val` to a constant-pool. 
+ ASMJIT_INLINE_NODEBUG Mem newUInt32Const(ConstPoolScope scope, uint32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt64Const(ConstPoolScope scope, int64_t val) noexcept { return newConst(scope, &val, 8); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt64Const(ConstPoolScope scope, uint64_t val) noexcept { return newConst(scope, &val, 8); } + + //! Put a SP-FP `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newFloatConst(ConstPoolScope scope, float val) noexcept { return newConst(scope, &val, 4); } + //! Put a DP-FP `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newDoubleConst(ConstPoolScope scope, double val) noexcept { return newConst(scope, &val, 8); } + + //! \} + + //! \name Instruction Options + //! \{ + + //! Force the compiler to not follow the conditional or unconditional jump. + ASMJIT_INLINE_NODEBUG Compiler& unfollow() noexcept { addInstOptions(InstOptions::kUnfollow); return *this; } + //! Tell the compiler that the destination variable will be overwritten. + ASMJIT_INLINE_NODEBUG Compiler& overwrite() noexcept { addInstOptions(InstOptions::kOverwrite); return *this; } + + //! \} + + //! \name Function Call & Ret Intrinsics + //! \{ + + //! Invoke a function call without `target` type enforcement. + ASMJIT_INLINE_NODEBUG Error invoke_(InvokeNode** out, const Operand_& target, const FuncSignature& signature) { + return addInvokeNode(out, Inst::kIdCall, target, signature); + } + + //! Invoke a function call of the given `target` and `signature` and store the added node to `out`. + //! + //! Creates a new \ref InvokeNode, initializes all the necessary members to match the given function `signature`, + //! adds the node to the compiler, and stores its pointer to `out`. The operation is atomic, if anything fails + //! nullptr is stored in `out` and error code is returned. 
+ ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Gp& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Mem& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Label& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Imm& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, uint64_t target, const FuncSignature& signature) { return invoke_(out, Imm(int64_t(target)), signature); } + + //! Return from function. + ASMJIT_INLINE_NODEBUG Error ret() { return addRet(Operand(), Operand()); } + //! \overload + ASMJIT_INLINE_NODEBUG Error ret(const BaseReg& o0) { return addRet(o0, Operand()); } + //! \overload + ASMJIT_INLINE_NODEBUG Error ret(const BaseReg& o0, const BaseReg& o1) { return addRet(o0, o1); } + + //! \} + + //! \name Jump Tables Support + //! \{ + + using EmitterExplicitT::jmp; + + //! Adds a jump to the given `target` with the provided jump `annotation`. + ASMJIT_INLINE_NODEBUG Error jmp(const BaseReg& target, JumpAnnotation* annotation) { return emitAnnotatedJump(Inst::kIdJmp, target, annotation); } + //! \overload + ASMJIT_INLINE_NODEBUG Error jmp(const BaseMem& target, JumpAnnotation* annotation) { return emitAnnotatedJump(Inst::kIdJmp, target, annotation); } + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} + + //! \name Finalize + //! \{ + + ASMJIT_API Error finalize() override; + + //! \} +}; + +//! 
\} + +ASMJIT_END_SUB_NAMESPACE + +#endif // !ASMJIT_NO_COMPILER +#endif // ASMJIT_X86_X86COMPILER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86emitter.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..4855e9957c39f55213708670e76baed0b72b9766 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86emitter.h @@ -0,0 +1,4493 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86EMITTER_H_INCLUDED +#define ASMJIT_X86_X86EMITTER_H_INCLUDED + +#include "../core/emitter.h" +#include "../core/support.h" +#include "../x86/x86globals.h" +#include "../x86/x86operand.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +#define ASMJIT_INST_0x(NAME, ID) \ + inline Error NAME() { return _emitter()->_emitI(Inst::kId##ID); } + +#define ASMJIT_INST_1x(NAME, ID, T0) \ + inline Error NAME(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID, o0); } + +#define ASMJIT_INST_1c(NAME, ID, CONV, T0) \ + inline Error NAME(CondCode cc, const T0& o0) { return _emitter()->_emitI(CONV(cc), o0); } \ + inline Error NAME##a(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##a, o0); } \ + inline Error NAME##ae(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ae, o0); } \ + inline Error NAME##b(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##b, o0); } \ + inline Error NAME##be(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##be, o0); } \ + inline Error NAME##c(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##c, o0); } \ + inline Error NAME##e(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##e, o0); } \ + inline Error NAME##g(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##g, o0); } \ + inline Error NAME##ge(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ge, o0); } 
\ + inline Error NAME##l(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##l, o0); } \ + inline Error NAME##le(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##le, o0); } \ + inline Error NAME##na(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##na, o0); } \ + inline Error NAME##nae(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nae, o0); } \ + inline Error NAME##nb(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nb, o0); } \ + inline Error NAME##nbe(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nbe, o0); } \ + inline Error NAME##nc(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nc, o0); } \ + inline Error NAME##ne(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ne, o0); } \ + inline Error NAME##ng(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ng, o0); } \ + inline Error NAME##nge(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nge, o0); } \ + inline Error NAME##nl(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nl, o0); } \ + inline Error NAME##nle(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nle, o0); } \ + inline Error NAME##no(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##no, o0); } \ + inline Error NAME##np(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##np, o0); } \ + inline Error NAME##ns(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ns, o0); } \ + inline Error NAME##nz(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nz, o0); } \ + inline Error NAME##o(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##o, o0); } \ + inline Error NAME##p(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##p, o0); } \ + inline Error NAME##pe(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##pe, o0); } \ + inline Error NAME##po(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##po, o0); } \ + inline Error NAME##s(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##s, o0); } \ + inline Error 
NAME##z(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##z, o0); } + +#define ASMJIT_INST_2x(NAME, ID, T0, T1) \ + inline Error NAME(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID, o0, o1); } + +#define ASMJIT_INST_2c(NAME, ID, CONV, T0, T1) \ + inline Error NAME(CondCode cc, const T0& o0, const T1& o1) { return _emitter()->_emitI(CONV(cc), o0, o1); } \ + inline Error NAME##a(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##a, o0, o1); } \ + inline Error NAME##ae(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ae, o0, o1); } \ + inline Error NAME##b(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##b, o0, o1); } \ + inline Error NAME##be(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##be, o0, o1); } \ + inline Error NAME##c(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##c, o0, o1); } \ + inline Error NAME##e(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##e, o0, o1); } \ + inline Error NAME##g(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##g, o0, o1); } \ + inline Error NAME##ge(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ge, o0, o1); } \ + inline Error NAME##l(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##l, o0, o1); } \ + inline Error NAME##le(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##le, o0, o1); } \ + inline Error NAME##na(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##na, o0, o1); } \ + inline Error NAME##nae(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nae, o0, o1); } \ + inline Error NAME##nb(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nb, o0, o1); } \ + inline Error NAME##nbe(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nbe, o0, o1); } \ + inline Error NAME##nc(const T0& o0, const T1& o1) 
{ return _emitter()->_emitI(Inst::kId##ID##nc, o0, o1); } \ + inline Error NAME##ne(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ne, o0, o1); } \ + inline Error NAME##ng(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ng, o0, o1); } \ + inline Error NAME##nge(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nge, o0, o1); } \ + inline Error NAME##nl(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nl, o0, o1); } \ + inline Error NAME##nle(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nle, o0, o1); } \ + inline Error NAME##no(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##no, o0, o1); } \ + inline Error NAME##np(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##np, o0, o1); } \ + inline Error NAME##ns(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ns, o0, o1); } \ + inline Error NAME##nz(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nz, o0, o1); } \ + inline Error NAME##o(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##o, o0, o1); } \ + inline Error NAME##p(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##p, o0, o1); } \ + inline Error NAME##pe(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##pe, o0, o1); } \ + inline Error NAME##po(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##po, o0, o1); } \ + inline Error NAME##s(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##s, o0, o1); } \ + inline Error NAME##z(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##z, o0, o1); } + +#define ASMJIT_INST_3x(NAME, ID, T0, T1, T2) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2); } + +#define ASMJIT_INST_4x(NAME, ID, T0, T1, T2, T3) \ + inline Error NAME(const T0& o0, const 
T1& o1, const T2& o2, const T3& o3) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3); } + +#define ASMJIT_INST_5x(NAME, ID, T0, T1, T2, T3, T4) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3, const T4& o4) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3, o4); } + +#define ASMJIT_INST_6x(NAME, ID, T0, T1, T2, T3, T4, T5) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3, const T4& o4, const T5& o5) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3, o4, o5); } + +//! \addtogroup asmjit_x86 +//! \{ + +//! Emitter (X86 - explicit). +template +struct EmitterExplicitT { + //! \cond + + // These typedefs are used to describe implicit operands passed explicitly. + typedef Gp Gp_AL; + typedef Gp Gp_AH; + typedef Gp Gp_CL; + typedef Gp Gp_AX; + typedef Gp Gp_DX; + + typedef Gp Gp_EAX; + typedef Gp Gp_EBX; + typedef Gp Gp_ECX; + typedef Gp Gp_EDX; + + typedef Gp Gp_RAX; + typedef Gp Gp_RBX; + typedef Gp Gp_RCX; + typedef Gp Gp_RDX; + + typedef Gp Gp_ZAX; + typedef Gp Gp_ZBX; + typedef Gp Gp_ZCX; + typedef Gp Gp_ZDX; + + typedef Mem DS_ZAX; // ds:[zax] + typedef Mem DS_ZDI; // ds:[zdi] + typedef Mem ES_ZDI; // es:[zdi] + typedef Mem DS_ZSI; // ds:[zsi] + + typedef Xmm XMM0; + + // These two are unfortunately reported by the sanitizer. We know what we do, however, the sanitizer doesn't. + // I have tried to use reinterpret_cast instead, but that would generate bad code when compiled by MSC. + ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF ASMJIT_INLINE_NODEBUG This* _emitter() noexcept { return static_cast(this); } + ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF ASMJIT_INLINE_NODEBUG const This* _emitter() const noexcept { return static_cast(this); } + + //! \endcond + + //! \name Native Registers + //! \{ + + //! Returns either GPD or GPQ register of the given `id` depending on the emitter's architecture. 
+ inline Gp gpz(uint32_t id) const noexcept { return Gp(_emitter()->_gpSignature, id); } + + inline Gp zax() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdAx); } + inline Gp zcx() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdCx); } + inline Gp zdx() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdDx); } + inline Gp zbx() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdBx); } + inline Gp zsp() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdSp); } + inline Gp zbp() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdBp); } + inline Gp zsi() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdSi); } + inline Gp zdi() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdDi); } + + //! \} + + //! \name Native Pointers + //! \{ + + //! Creates a target dependent pointer of which base register's id is `baseId`. + inline Mem ptr_base(uint32_t baseId, int32_t off = 0, uint32_t size = 0) const noexcept { + return Mem(OperandSignature::fromOpType(OperandType::kMem) | + OperandSignature::fromMemBaseType(_emitter()->_gpSignature.regType()) | + OperandSignature::fromSize(size), + baseId, 0, off); + } + + inline Mem ptr_zax(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdAx, off, size); } + inline Mem ptr_zcx(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdCx, off, size); } + inline Mem ptr_zdx(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdDx, off, size); } + inline Mem ptr_zbx(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdBx, off, size); } + inline Mem ptr_zsp(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdSp, off, size); } + inline Mem ptr_zbp(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdBp, off, size); } + inline Mem ptr_zsi(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdSi, off, size); } 
+ inline Mem ptr_zdi(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdDi, off, size); } + + //! Creates an `intptr_t` memory operand depending on the current architecture. + inline Mem intptr_ptr(const Gp& base, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Gp& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Gp& base, const Vec& index, uint32_t shift = 0, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Label& base, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Label& base, const Gp& index, uint32_t shift, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Label& base, const Vec& index, uint32_t shift, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Rip& rip, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(rip, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(uint64_t base) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, nativeGpSize); + } + //! 
\overload + inline Mem intptr_ptr(uint64_t base, const Gp& index, uint32_t shift = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr_abs(uint64_t base) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, nativeGpSize, OperandSignature::fromValue(Mem::AddrType::kAbs)); + } + //! \overload + inline Mem intptr_ptr_abs(uint64_t base, const Gp& index, uint32_t shift = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, nativeGpSize, OperandSignature::fromValue(Mem::AddrType::kRel)); + } + + //! \} + + //! \name Embed + //! \{ + + //! Embeds 8-bit integer data. + inline Error db(uint8_t x, size_t repeatCount = 1) { return _emitter()->embedUInt8(x, repeatCount); } + //! Embeds 16-bit integer data. + inline Error dw(uint16_t x, size_t repeatCount = 1) { return _emitter()->embedUInt16(x, repeatCount); } + //! Embeds 32-bit integer data. + inline Error dd(uint32_t x, size_t repeatCount = 1) { return _emitter()->embedUInt32(x, repeatCount); } + //! Embeds 64-bit integer data. + inline Error dq(uint64_t x, size_t repeatCount = 1) { return _emitter()->embedUInt64(x, repeatCount); } + + //! Adds data in a given structure instance to the CodeBuffer. + template + inline Error dstruct(const T& x) { return _emitter()->embed(&x, uint32_t(sizeof(T))); } + + //! \} + +protected: + //! \cond + inline This& _addInstOptions(InstOptions options) noexcept { + _emitter()->addInstOptions(options); + return *_emitter(); + } + //! \endcond + +public: + //! \name Short/Long Form Options + //! \{ + + //! Force short form of jmp/jcc instruction. + inline This& short_() noexcept { return _addInstOptions(InstOptions::kShortForm); } + //! Force long form of jmp/jcc instruction. + inline This& long_() noexcept { return _addInstOptions(InstOptions::kLongForm); } + + //! \} + + //! 
\name Encoding Options + //! \{ + + //! Prefer MOD/RM encoding when both MOD/RM and MOD/MR forms are applicable. + inline This& mod_rm() noexcept { return _addInstOptions(InstOptions::kX86_ModRM); } + + //! Prefer MOD/MR encoding when both MOD/RM and MOD/MR forms are applicable. + inline This& mod_mr() noexcept { return _addInstOptions(InstOptions::kX86_ModMR); } + + //! \} + + //! \name Prefix Options + //! \{ + + //! Condition is likely to be taken (has only benefit on P4). + inline This& taken() noexcept { return _addInstOptions(InstOptions::kTaken); } + //! Condition is unlikely to be taken (has only benefit on P4). + inline This& notTaken() noexcept { return _addInstOptions(InstOptions::kNotTaken); } + + //! Use LOCK prefix. + inline This& lock() noexcept { return _addInstOptions(InstOptions::kX86_Lock); } + //! Use XACQUIRE prefix. + inline This& xacquire() noexcept { return _addInstOptions(InstOptions::kX86_XAcquire); } + //! Use XRELEASE prefix. + inline This& xrelease() noexcept { return _addInstOptions(InstOptions::kX86_XRelease); } + + //! Use BND/REPNE prefix. + //! + //! \note This is the same as using `repne()` or `repnz()` prefix. + inline This& bnd() noexcept { return _addInstOptions(InstOptions::kX86_Repne); } + + //! Use REP/REPZ prefix. + //! + //! \note This is the same as using `repe()` or `repz()` prefix. + inline This& rep(const Gp& zcx) noexcept { + _emitter()->_extraReg.init(zcx); + return _addInstOptions(InstOptions::kX86_Rep); + } + + //! Use REP/REPE prefix. + //! + //! \note This is the same as using `rep()` or `repz()` prefix. + inline This& repe(const Gp& zcx) noexcept { return rep(zcx); } + + //! Use REP/REPE prefix. + //! + //! \note This is the same as using `rep()` or `repe()` prefix. + inline This& repz(const Gp& zcx) noexcept { return rep(zcx); } + + //! Use REPNE prefix. + //! + //! \note This is the same as using `bnd()` or `repnz()` prefix. 
+ inline This& repne(const Gp& zcx) noexcept { + _emitter()->_extraReg.init(zcx); + return _addInstOptions(InstOptions::kX86_Repne); + } + + //! Use REPNE prefix. + //! + //! \note This is the same as using `bnd()` or `repne()` prefix. + inline This& repnz(const Gp& zcx) noexcept { return repne(zcx); } + + //! \} + + //! \name REX Options + //! \{ + + //! Force REX prefix to be emitted even when it's not needed (X86_64). + //! + //! \note Don't use when using high 8-bit registers as REX prefix makes them inaccessible and `x86::Assembler` + //! would fail to encode such instruction. + inline This& rex() noexcept { return _addInstOptions(InstOptions::kX86_Rex); } + + //! Force REX.B prefix (X64) [It exists for special purposes only]. + inline This& rex_b() noexcept { return _addInstOptions(InstOptions::kX86_OpCodeB); } + //! Force REX.X prefix (X64) [It exists for special purposes only]. + inline This& rex_x() noexcept { return _addInstOptions(InstOptions::kX86_OpCodeX); } + //! Force REX.R prefix (X64) [It exists for special purposes only]. + inline This& rex_r() noexcept { return _addInstOptions(InstOptions::kX86_OpCodeR); } + //! Force REX.W prefix (X64) [It exists for special purposes only]. + inline This& rex_w() noexcept { return _addInstOptions(InstOptions::kX86_OpCodeW); } + + //! \} + + //! \name VEX and EVEX Options + //! \{ + + //! Use VEX prefix instead of EVEX prefix (useful to select AVX_VNNI instruction instead of AVX512_VNNI). + inline This& vex() noexcept { return _addInstOptions(InstOptions::kX86_Vex); } + //! Force 3-byte VEX prefix (AVX+). + inline This& vex3() noexcept { return _addInstOptions(InstOptions::kX86_Vex3); } + //! Force 4-byte EVEX prefix (AVX512+). + inline This& evex() noexcept { return _addInstOptions(InstOptions::kX86_Evex); } + + //! \} + + //! \name AVX-512 Options & Masking + //! \{ + + //! Use masking {k} (AVX512+). 
+ inline This& k(const KReg& kreg) noexcept { + _emitter()->_extraReg.init(kreg); + return *_emitter(); + } + + //! Use zeroing instead of merging (AVX512+). + inline This& z() noexcept { return _addInstOptions(InstOptions::kX86_ZMask); } + + //! Suppress all exceptions (AVX512+). + inline This& sae() noexcept { return _addInstOptions(InstOptions::kX86_SAE); } + //! Static rounding mode {rn} (round-to-nearest even) and {sae} (AVX512+). + inline This& rn_sae() noexcept { return _addInstOptions(InstOptions::kX86_ER | InstOptions::kX86_RN_SAE); } + //! Static rounding mode {rd} (round-down, toward -inf) and {sae} (AVX512+). + inline This& rd_sae() noexcept { return _addInstOptions(InstOptions::kX86_ER | InstOptions::kX86_RD_SAE); } + //! Static rounding mode {ru} (round-up, toward +inf) and {sae} (AVX512+). + inline This& ru_sae() noexcept { return _addInstOptions(InstOptions::kX86_ER | InstOptions::kX86_RU_SAE); } + //! Static rounding mode {rz} (round-toward-zero, truncate) and {sae} (AVX512+). + inline This& rz_sae() noexcept { return _addInstOptions(InstOptions::kX86_ER | InstOptions::kX86_RZ_SAE); } + + //! \} + + //! \name Core Instructions + //! 
\{ + + ASMJIT_INST_2x(adc, Adc, Gp, Gp) // ANY + ASMJIT_INST_2x(adc, Adc, Gp, Mem) // ANY + ASMJIT_INST_2x(adc, Adc, Gp, Imm) // ANY + ASMJIT_INST_2x(adc, Adc, Mem, Gp) // ANY + ASMJIT_INST_2x(adc, Adc, Mem, Imm) // ANY + ASMJIT_INST_2x(add, Add, Gp, Gp) // ANY + ASMJIT_INST_2x(add, Add, Gp, Mem) // ANY + ASMJIT_INST_2x(add, Add, Gp, Imm) // ANY + ASMJIT_INST_2x(add, Add, Mem, Gp) // ANY + ASMJIT_INST_2x(add, Add, Mem, Imm) // ANY + ASMJIT_INST_2x(and_, And, Gp, Gp) // ANY + ASMJIT_INST_2x(and_, And, Gp, Mem) // ANY + ASMJIT_INST_2x(and_, And, Gp, Imm) // ANY + ASMJIT_INST_2x(and_, And, Mem, Gp) // ANY + ASMJIT_INST_2x(and_, And, Mem, Imm) // ANY + ASMJIT_INST_2x(bound, Bound, Gp, Mem) // X86 + ASMJIT_INST_2x(bsf, Bsf, Gp, Gp) // ANY + ASMJIT_INST_2x(bsf, Bsf, Gp, Mem) // ANY + ASMJIT_INST_2x(bsr, Bsr, Gp, Gp) // ANY + ASMJIT_INST_2x(bsr, Bsr, Gp, Mem) // ANY + ASMJIT_INST_1x(bswap, Bswap, Gp) // ANY + ASMJIT_INST_2x(bt, Bt, Gp, Gp) // ANY + ASMJIT_INST_2x(bt, Bt, Gp, Imm) // ANY + ASMJIT_INST_2x(bt, Bt, Mem, Gp) // ANY + ASMJIT_INST_2x(bt, Bt, Mem, Imm) // ANY + ASMJIT_INST_2x(btc, Btc, Gp, Gp) // ANY + ASMJIT_INST_2x(btc, Btc, Gp, Imm) // ANY + ASMJIT_INST_2x(btc, Btc, Mem, Gp) // ANY + ASMJIT_INST_2x(btc, Btc, Mem, Imm) // ANY + ASMJIT_INST_2x(btr, Btr, Gp, Gp) // ANY + ASMJIT_INST_2x(btr, Btr, Gp, Imm) // ANY + ASMJIT_INST_2x(btr, Btr, Mem, Gp) // ANY + ASMJIT_INST_2x(btr, Btr, Mem, Imm) // ANY + ASMJIT_INST_2x(bts, Bts, Gp, Gp) // ANY + ASMJIT_INST_2x(bts, Bts, Gp, Imm) // ANY + ASMJIT_INST_2x(bts, Bts, Mem, Gp) // ANY + ASMJIT_INST_2x(bts, Bts, Mem, Imm) // ANY + ASMJIT_INST_1x(cbw, Cbw, Gp_AX) // ANY [EXPLICIT] AX <- Sign Extend AL + ASMJIT_INST_2x(cdq, Cdq, Gp_EDX, Gp_EAX) // ANY [EXPLICIT] EDX:EAX <- Sign Extend EAX + ASMJIT_INST_1x(cdqe, Cdqe, Gp_EAX) // X64 [EXPLICIT] RAX <- Sign Extend EAX + ASMJIT_INST_2x(cqo, Cqo, Gp_RDX, Gp_RAX) // X64 [EXPLICIT] RDX:RAX <- Sign Extend RAX + ASMJIT_INST_2x(cwd, Cwd, Gp_DX, Gp_AX) // ANY [EXPLICIT] DX:AX <- Sign 
Extend AX + ASMJIT_INST_1x(cwde, Cwde, Gp_EAX) // ANY [EXPLICIT] EAX <- Sign Extend AX + ASMJIT_INST_1x(call, Call, Gp) // ANY + ASMJIT_INST_1x(call, Call, Mem) // ANY + ASMJIT_INST_1x(call, Call, Label) // ANY + ASMJIT_INST_1x(call, Call, Imm) // ANY + ASMJIT_INST_2c(cmov, Cmov, Inst::cmovccFromCond, Gp, Gp) // CMOV + ASMJIT_INST_2c(cmov, Cmov, Inst::cmovccFromCond, Gp, Mem) // CMOV + ASMJIT_INST_2x(cmp, Cmp, Gp, Gp) // ANY + ASMJIT_INST_2x(cmp, Cmp, Gp, Mem) // ANY + ASMJIT_INST_2x(cmp, Cmp, Gp, Imm) // ANY + ASMJIT_INST_2x(cmp, Cmp, Mem, Gp) // ANY + ASMJIT_INST_2x(cmp, Cmp, Mem, Imm) // ANY + ASMJIT_INST_2x(cmps, Cmps, DS_ZSI, ES_ZDI) // ANY [EXPLICIT] + ASMJIT_INST_3x(cmpxchg, Cmpxchg, Gp, Gp, Gp_ZAX) // I486 [EXPLICIT] + ASMJIT_INST_3x(cmpxchg, Cmpxchg, Mem, Gp, Gp_ZAX) // I486 [EXPLICIT] + ASMJIT_INST_5x(cmpxchg16b, Cmpxchg16b, Mem, Gp_RDX, Gp_RAX, Gp_RCX, Gp_RBX); // CMPXCHG16B [EXPLICIT] m == EDX:EAX ? m <- ECX:EBX + ASMJIT_INST_5x(cmpxchg8b, Cmpxchg8b, Mem, Gp_EDX, Gp_EAX, Gp_ECX, Gp_EBX); // CMPXCHG8B [EXPLICIT] m == RDX:RAX ? 
m <- RCX:RBX + ASMJIT_INST_1x(dec, Dec, Gp) // ANY + ASMJIT_INST_1x(dec, Dec, Mem) // ANY + ASMJIT_INST_2x(div, Div, Gp, Gp) // ANY [EXPLICIT] AH[Rem]: AL[Quot] <- AX / r8 + ASMJIT_INST_2x(div, Div, Gp, Mem) // ANY [EXPLICIT] AH[Rem]: AL[Quot] <- AX / m8 + ASMJIT_INST_3x(div, Div, Gp, Gp, Gp) // ANY [EXPLICIT] xDX[Rem]:xAX[Quot] <- xDX:xAX / r16|r32|r64 + ASMJIT_INST_3x(div, Div, Gp, Gp, Mem) // ANY [EXPLICIT] xDX[Rem]:xAX[Quot] <- xDX:xAX / m16|m32|m64 + ASMJIT_INST_2x(idiv, Idiv, Gp, Gp) // ANY [EXPLICIT] AH[Rem]: AL[Quot] <- AX / r8 + ASMJIT_INST_2x(idiv, Idiv, Gp, Mem) // ANY [EXPLICIT] AH[Rem]: AL[Quot] <- AX / m8 + ASMJIT_INST_3x(idiv, Idiv, Gp, Gp, Gp) // ANY [EXPLICIT] xDX[Rem]:xAX[Quot] <- xDX:xAX / r16|r32|r64 + ASMJIT_INST_3x(idiv, Idiv, Gp, Gp, Mem) // ANY [EXPLICIT] xDX[Rem]:xAX[Quot] <- xDX:xAX / m16|m32|m64 + ASMJIT_INST_2x(imul, Imul, Gp, Gp) // ANY [EXPLICIT] AX <- AL * r8 | ra <- ra * rb + ASMJIT_INST_2x(imul, Imul, Gp, Mem) // ANY [EXPLICIT] AX <- AL * m8 | ra <- ra * m16|m32|m64 + ASMJIT_INST_3x(imul, Imul, Gp, Gp, Imm) // ANY + ASMJIT_INST_3x(imul, Imul, Gp, Mem, Imm) // ANY + ASMJIT_INST_3x(imul, Imul, Gp, Gp, Gp) // ANY [EXPLICIT] xDX:xAX <- xAX * r16|r32|r64 + ASMJIT_INST_3x(imul, Imul, Gp, Gp, Mem) // ANY [EXPLICIT] xDX:xAX <- xAX * m16|m32|m64 + ASMJIT_INST_1x(inc, Inc, Gp) // ANY + ASMJIT_INST_1x(inc, Inc, Mem) // ANY + ASMJIT_INST_1c(j, J, Inst::jccFromCond, Label) // ANY + ASMJIT_INST_1c(j, J, Inst::jccFromCond, Imm) // ANY + ASMJIT_INST_2x(jecxz, Jecxz, Gp, Label) // ANY [EXPLICIT] Short jump if CX/ECX/RCX is zero. + ASMJIT_INST_2x(jecxz, Jecxz, Gp, Imm) // ANY [EXPLICIT] Short jump if CX/ECX/RCX is zero. 
+ ASMJIT_INST_1x(jmp, Jmp, Gp) // ANY + ASMJIT_INST_1x(jmp, Jmp, Mem) // ANY + ASMJIT_INST_1x(jmp, Jmp, Label) // ANY + ASMJIT_INST_1x(jmp, Jmp, Imm) // ANY + ASMJIT_INST_2x(lcall, Lcall, Imm, Imm) // ANY + ASMJIT_INST_1x(lcall, Lcall, Mem) // ANY + ASMJIT_INST_2x(lea, Lea, Gp, Mem) // ANY + ASMJIT_INST_2x(ljmp, Ljmp, Imm, Imm) // ANY + ASMJIT_INST_1x(ljmp, Ljmp, Mem) // ANY + ASMJIT_INST_2x(lods, Lods, Gp_ZAX, DS_ZSI) // ANY [EXPLICIT] + ASMJIT_INST_2x(loop, Loop, Gp_ZCX, Label) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0. + ASMJIT_INST_2x(loop, Loop, Gp_ZCX, Imm) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0. + ASMJIT_INST_2x(loope, Loope, Gp_ZCX, Label) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 1. + ASMJIT_INST_2x(loope, Loope, Gp_ZCX, Imm) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 1. + ASMJIT_INST_2x(loopne, Loopne, Gp_ZCX, Label) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 0. + ASMJIT_INST_2x(loopne, Loopne, Gp_ZCX, Imm) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 0. 
+ ASMJIT_INST_2x(mov, Mov, Gp, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, Mem) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, Imm) // ANY + ASMJIT_INST_2x(mov, Mov, Mem, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, Mem, Imm) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, CReg) // ANY + ASMJIT_INST_2x(mov, Mov, CReg, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, DReg) // ANY + ASMJIT_INST_2x(mov, Mov, DReg, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, SReg) // ANY + ASMJIT_INST_2x(mov, Mov, Mem, SReg) // ANY + ASMJIT_INST_2x(mov, Mov, SReg, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, SReg, Mem) // ANY + ASMJIT_INST_2x(movabs, Movabs, Gp, Mem) // X64 + ASMJIT_INST_2x(movabs, Movabs, Gp, Imm) // X64 + ASMJIT_INST_2x(movabs, Movabs, Mem, Gp) // X64 + ASMJIT_INST_2x(movnti, Movnti, Mem, Gp) // SSE2 + ASMJIT_INST_2x(movs, Movs, ES_ZDI, DS_ZSI) // ANY [EXPLICIT] + ASMJIT_INST_2x(movsx, Movsx, Gp, Gp) // ANY + ASMJIT_INST_2x(movsx, Movsx, Gp, Mem) // ANY + ASMJIT_INST_2x(movsxd, Movsxd, Gp, Gp) // X64 + ASMJIT_INST_2x(movsxd, Movsxd, Gp, Mem) // X64 + ASMJIT_INST_2x(movzx, Movzx, Gp, Gp) // ANY + ASMJIT_INST_2x(movzx, Movzx, Gp, Mem) // ANY + ASMJIT_INST_2x(mul, Mul, Gp_AX, Gp) // ANY [EXPLICIT] AX <- AL * r8 + ASMJIT_INST_2x(mul, Mul, Gp_AX, Mem) // ANY [EXPLICIT] AX <- AL * m8 + ASMJIT_INST_3x(mul, Mul, Gp_ZDX, Gp_ZAX, Gp) // ANY [EXPLICIT] xDX:xAX <- xAX * r16|r32|r64 + ASMJIT_INST_3x(mul, Mul, Gp_ZDX, Gp_ZAX, Mem) // ANY [EXPLICIT] xDX:xAX <- xAX * m16|m32|m64 + ASMJIT_INST_1x(neg, Neg, Gp) // ANY + ASMJIT_INST_1x(neg, Neg, Mem) // ANY + ASMJIT_INST_0x(nop, Nop) // ANY + ASMJIT_INST_1x(nop, Nop, Gp) // ANY + ASMJIT_INST_1x(nop, Nop, Mem) // ANY + ASMJIT_INST_2x(nop, Nop, Gp, Gp) // ANY + ASMJIT_INST_2x(nop, Nop, Mem, Gp) // ANY + ASMJIT_INST_1x(not_, Not, Gp) // ANY + ASMJIT_INST_1x(not_, Not, Mem) // ANY + ASMJIT_INST_2x(or_, Or, Gp, Gp) // ANY + ASMJIT_INST_2x(or_, Or, Gp, Mem) // ANY + ASMJIT_INST_2x(or_, Or, Gp, Imm) // ANY + ASMJIT_INST_2x(or_, Or, Mem, Gp) // ANY + ASMJIT_INST_2x(or_, Or, Mem, Imm) // 
ANY + ASMJIT_INST_1x(pop, Pop, Gp) // ANY + ASMJIT_INST_1x(pop, Pop, Mem) // ANY + ASMJIT_INST_1x(pop, Pop, SReg); // ANY + ASMJIT_INST_0x(popa, Popa) // X86 + ASMJIT_INST_0x(popad, Popad) // X86 + ASMJIT_INST_0x(popf, Popf) // ANY + ASMJIT_INST_0x(popfd, Popfd) // X86 + ASMJIT_INST_0x(popfq, Popfq) // X64 + ASMJIT_INST_1x(push, Push, Gp) // ANY + ASMJIT_INST_1x(push, Push, Mem) // ANY + ASMJIT_INST_1x(push, Push, SReg) // ANY + ASMJIT_INST_1x(push, Push, Imm) // ANY + ASMJIT_INST_0x(pusha, Pusha) // X86 + ASMJIT_INST_0x(pushad, Pushad) // X86 + ASMJIT_INST_0x(pushf, Pushf) // ANY + ASMJIT_INST_0x(pushfd, Pushfd) // X86 + ASMJIT_INST_0x(pushfq, Pushfq) // X64 + ASMJIT_INST_2x(rcl, Rcl, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(rcl, Rcl, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(rcl, Rcl, Gp, Imm) // ANY + ASMJIT_INST_2x(rcl, Rcl, Mem, Imm) // ANY + ASMJIT_INST_2x(rcr, Rcr, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(rcr, Rcr, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(rcr, Rcr, Gp, Imm) // ANY + ASMJIT_INST_2x(rcr, Rcr, Mem, Imm) // ANY + ASMJIT_INST_2x(rol, Rol, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(rol, Rol, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(rol, Rol, Gp, Imm) // ANY + ASMJIT_INST_2x(rol, Rol, Mem, Imm) // ANY + ASMJIT_INST_2x(ror, Ror, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(ror, Ror, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(ror, Ror, Gp, Imm) // ANY + ASMJIT_INST_2x(ror, Ror, Mem, Imm) // ANY + ASMJIT_INST_2x(sbb, Sbb, Gp, Gp) // ANY + ASMJIT_INST_2x(sbb, Sbb, Gp, Mem) // ANY + ASMJIT_INST_2x(sbb, Sbb, Gp, Imm) // ANY + ASMJIT_INST_2x(sbb, Sbb, Mem, Gp) // ANY + ASMJIT_INST_2x(sbb, Sbb, Mem, Imm) // ANY + ASMJIT_INST_2x(sal, Sal, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(sal, Sal, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(sal, Sal, Gp, Imm) // ANY + ASMJIT_INST_2x(sal, Sal, Mem, Imm) // ANY + ASMJIT_INST_2x(sar, Sar, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(sar, Sar, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(sar, Sar, Gp, Imm) // ANY + ASMJIT_INST_2x(sar, Sar, Mem, Imm) // ANY + ASMJIT_INST_2x(scas, Scas, Gp_ZAX, ES_ZDI) // 
ANY [EXPLICIT] + ASMJIT_INST_1c(set, Set, Inst::setccFromCond, Gp) // ANY + ASMJIT_INST_1c(set, Set, Inst::setccFromCond, Mem) // ANY + ASMJIT_INST_2x(shl, Shl, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(shl, Shl, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(shl, Shl, Gp, Imm) // ANY + ASMJIT_INST_2x(shl, Shl, Mem, Imm) // ANY + ASMJIT_INST_2x(shr, Shr, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(shr, Shr, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(shr, Shr, Gp, Imm) // ANY + ASMJIT_INST_2x(shr, Shr, Mem, Imm) // ANY + ASMJIT_INST_3x(shld, Shld, Gp, Gp, Gp_CL) // ANY + ASMJIT_INST_3x(shld, Shld, Mem, Gp, Gp_CL) // ANY + ASMJIT_INST_3x(shld, Shld, Gp, Gp, Imm) // ANY + ASMJIT_INST_3x(shld, Shld, Mem, Gp, Imm) // ANY + ASMJIT_INST_3x(shrd, Shrd, Gp, Gp, Gp_CL) // ANY + ASMJIT_INST_3x(shrd, Shrd, Mem, Gp, Gp_CL) // ANY + ASMJIT_INST_3x(shrd, Shrd, Gp, Gp, Imm) // ANY + ASMJIT_INST_3x(shrd, Shrd, Mem, Gp, Imm) // ANY + ASMJIT_INST_2x(stos, Stos, ES_ZDI, Gp_ZAX) // ANY [EXPLICIT] + ASMJIT_INST_2x(sub, Sub, Gp, Gp) // ANY + ASMJIT_INST_2x(sub, Sub, Gp, Mem) // ANY + ASMJIT_INST_2x(sub, Sub, Gp, Imm) // ANY + ASMJIT_INST_2x(sub, Sub, Mem, Gp) // ANY + ASMJIT_INST_2x(sub, Sub, Mem, Imm) // ANY + ASMJIT_INST_2x(test, Test, Gp, Gp) // ANY + ASMJIT_INST_2x(test, Test, Gp, Imm) // ANY + ASMJIT_INST_2x(test, Test, Mem, Gp) // ANY + ASMJIT_INST_2x(test, Test, Mem, Imm) // ANY + ASMJIT_INST_2x(ud0, Ud0, Gp, Gp) // ANY + ASMJIT_INST_2x(ud0, Ud0, Gp, Mem) // ANY + ASMJIT_INST_2x(ud1, Ud1, Gp, Gp) // ANY + ASMJIT_INST_2x(ud1, Ud1, Gp, Mem) // ANY + ASMJIT_INST_0x(ud2, Ud2) // ANY + ASMJIT_INST_2x(xadd, Xadd, Gp, Gp) // ANY + ASMJIT_INST_2x(xadd, Xadd, Mem, Gp) // ANY + ASMJIT_INST_2x(xchg, Xchg, Gp, Gp) // ANY + ASMJIT_INST_2x(xchg, Xchg, Mem, Gp) // ANY + ASMJIT_INST_2x(xchg, Xchg, Gp, Mem) // ANY + ASMJIT_INST_2x(xor_, Xor, Gp, Gp) // ANY + ASMJIT_INST_2x(xor_, Xor, Gp, Mem) // ANY + ASMJIT_INST_2x(xor_, Xor, Gp, Imm) // ANY + ASMJIT_INST_2x(xor_, Xor, Mem, Gp) // ANY + ASMJIT_INST_2x(xor_, Xor, Mem, Imm) // ANY 
+ + //! \} + + //! \name Core Instructions (Aliases) + //! \{ + + //! The `imul(Gp, Imm)` instruction is an alias of `imul(Gp, Gp, Imm)` instruction. + inline Error imul(const Gp& o0, const Imm& o1) { return _emitter()->_emitI(Inst::kIdImul, o0, o0, o1); } + + //! \} + + //! \name Deprecated 32-bit Instructions + //! \{ + + ASMJIT_INST_1x(aaa, Aaa, Gp) // X86 [EXPLICIT] + ASMJIT_INST_2x(aad, Aad, Gp, Imm) // X86 [EXPLICIT] + ASMJIT_INST_2x(aam, Aam, Gp, Imm) // X86 [EXPLICIT] + ASMJIT_INST_1x(aas, Aas, Gp) // X86 [EXPLICIT] + ASMJIT_INST_1x(daa, Daa, Gp) // X86 [EXPLICIT] + ASMJIT_INST_1x(das, Das, Gp) // X86 [EXPLICIT] + + //! \} + + //! \name ENTER/LEAVE Instructions + //! \{ + + ASMJIT_INST_2x(enter, Enter, Imm, Imm) // ANY + ASMJIT_INST_0x(leave, Leave) // ANY + + //! \} + + //! \name IN/OUT Instructions + //! \{ + + // NOTE: For some reason Doxygen is messed up here and thinks we are in cond. + + ASMJIT_INST_2x(in, In, Gp_ZAX, Imm) // ANY + ASMJIT_INST_2x(in, In, Gp_ZAX, Gp_DX) // ANY + ASMJIT_INST_2x(ins, Ins, ES_ZDI, Gp_DX) // ANY + ASMJIT_INST_2x(out, Out, Imm, Gp_ZAX) // ANY + ASMJIT_INST_2x(out, Out, Gp_DX, Gp_ZAX) // ANY + ASMJIT_INST_2x(outs, Outs, Gp_DX, DS_ZSI) // ANY + + //! \} + + //! \name Clear/Set CF/DF Instructions + //! \{ + + ASMJIT_INST_0x(clc, Clc) // ANY + ASMJIT_INST_0x(cld, Cld) // ANY + ASMJIT_INST_0x(cmc, Cmc) // ANY + ASMJIT_INST_0x(stc, Stc) // ANY + ASMJIT_INST_0x(std, Std) // ANY + + //! \} + + //! \name ADX Instructions + //! \{ + + ASMJIT_INST_2x(adcx, Adcx, Gp, Gp) // ADX + ASMJIT_INST_2x(adcx, Adcx, Gp, Mem) // ADX + ASMJIT_INST_2x(adox, Adox, Gp, Gp) // ADX + ASMJIT_INST_2x(adox, Adox, Gp, Mem) // ADX + + //! \} + + //! \name CPUID Instruction + //! \{ + + ASMJIT_INST_4x(cpuid, Cpuid, Gp_EAX, Gp_EBX, Gp_ECX, Gp_EDX) // I486 [EXPLICIT] EAX:EBX:ECX:EDX <- CPUID[EAX:ECX] + + //! \} + + //! \name LAHF/SAHF Instructions + //! 
\{ + + ASMJIT_INST_1x(lahf, Lahf, Gp_AH) // LAHFSAHF [EXPLICIT] AH <- EFL + ASMJIT_INST_1x(sahf, Sahf, Gp_AH) // LAHFSAHF [EXPLICIT] EFL <- AH + + //! \} + + //! \name BMI Instructions + //! \{ + + ASMJIT_INST_3x(andn, Andn, Gp, Gp, Gp) // BMI + ASMJIT_INST_3x(andn, Andn, Gp, Gp, Mem) // BMI + ASMJIT_INST_3x(bextr, Bextr, Gp, Gp, Gp) // BMI + ASMJIT_INST_3x(bextr, Bextr, Gp, Mem, Gp) // BMI + ASMJIT_INST_2x(blsi, Blsi, Gp, Gp) // BMI + ASMJIT_INST_2x(blsi, Blsi, Gp, Mem) // BMI + ASMJIT_INST_2x(blsmsk, Blsmsk, Gp, Gp) // BMI + ASMJIT_INST_2x(blsmsk, Blsmsk, Gp, Mem) // BMI + ASMJIT_INST_2x(blsr, Blsr, Gp, Gp) // BMI + ASMJIT_INST_2x(blsr, Blsr, Gp, Mem) // BMI + ASMJIT_INST_2x(tzcnt, Tzcnt, Gp, Gp) // BMI + ASMJIT_INST_2x(tzcnt, Tzcnt, Gp, Mem) // BMI + + //! \} + + //! \name BMI2 Instructions + //! \{ + + ASMJIT_INST_3x(bzhi, Bzhi, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(bzhi, Bzhi, Gp, Mem, Gp) // BMI2 + ASMJIT_INST_4x(mulx, Mulx, Gp, Gp, Gp, Gp_ZDX) // BMI2 [EXPLICIT] + ASMJIT_INST_4x(mulx, Mulx, Gp, Gp, Mem, Gp_ZDX) // BMI2 [EXPLICIT] + ASMJIT_INST_3x(pdep, Pdep, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(pdep, Pdep, Gp, Gp, Mem) // BMI2 + ASMJIT_INST_3x(pext, Pext, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(pext, Pext, Gp, Gp, Mem) // BMI2 + ASMJIT_INST_3x(rorx, Rorx, Gp, Gp, Imm) // BMI2 + ASMJIT_INST_3x(rorx, Rorx, Gp, Mem, Imm) // BMI2 + ASMJIT_INST_3x(sarx, Sarx, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(sarx, Sarx, Gp, Mem, Gp) // BMI2 + ASMJIT_INST_3x(shlx, Shlx, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(shlx, Shlx, Gp, Mem, Gp) // BMI2 + ASMJIT_INST_3x(shrx, Shrx, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(shrx, Shrx, Gp, Mem, Gp) // BMI2 + + //! \} + + //! \name CMPCCXADD Instructions + //! 
\{ + + ASMJIT_INST_3x(cmpbexadd, Cmpbexadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpbxadd, Cmpbxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmplexadd, Cmplexadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmplxadd, Cmplxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnbexadd, Cmpnbexadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnbxadd, Cmpnbxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnlexadd, Cmpnlexadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnlxadd, Cmpnlxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnoxadd, Cmpnoxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnpxadd, Cmpnpxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnsxadd, Cmpnsxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnzxadd, Cmpnzxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpoxadd, Cmpoxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmppxadd, Cmppxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpsxadd, Cmpsxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpzxadd, Cmpzxadd, Mem, Gp, Gp) + + //! \} + + //! \name CacheLine Instructions + //! \{ + + ASMJIT_INST_1x(cldemote, Cldemote, Mem) // CLDEMOTE + ASMJIT_INST_1x(clflush, Clflush, Mem) // CLFLUSH + ASMJIT_INST_1x(clflushopt, Clflushopt, Mem) // CLFLUSH_OPT + ASMJIT_INST_1x(clwb, Clwb, Mem) // CLWB + ASMJIT_INST_1x(clzero, Clzero, DS_ZAX) // CLZERO [EXPLICIT] + + //! \} + + //! \name CRC32 Instructions (SSE4.2) + //! \{ + + ASMJIT_INST_2x(crc32, Crc32, Gp, Gp) // SSE4_2 + ASMJIT_INST_2x(crc32, Crc32, Gp, Mem) // SSE4_2 + + //! \} + + //! \name FENCE Instructions (SSE and SSE2) + //! \{ + + ASMJIT_INST_0x(lfence, Lfence) // SSE2 + ASMJIT_INST_0x(mfence, Mfence) // SSE2 + ASMJIT_INST_0x(sfence, Sfence) // SSE + + //! \} + + //! \name LZCNT Instructions + //! \{ + + ASMJIT_INST_2x(lzcnt, Lzcnt, Gp, Gp) // LZCNT + ASMJIT_INST_2x(lzcnt, Lzcnt, Gp, Mem) // LZCNT + + //! \} + + //! \name MOVBE Instructions + //! \{ + + ASMJIT_INST_2x(movbe, Movbe, Gp, Mem) // MOVBE + ASMJIT_INST_2x(movbe, Movbe, Mem, Gp) // MOVBE + + //! \} + + //! \name MOVDIRI & MOVDIR64B Instructions + //! 
\{ + + ASMJIT_INST_2x(movdiri, Movdiri, Mem, Gp) // MOVDIRI + ASMJIT_INST_2x(movdir64b, Movdir64b, Mem, Mem) // MOVDIR64B + + //! \} + + //! \name MXCSR Instructions (SSE) + //! \{ + + ASMJIT_INST_1x(ldmxcsr, Ldmxcsr, Mem) // SSE + ASMJIT_INST_1x(stmxcsr, Stmxcsr, Mem) // SSE + + //! \} + + //! \name POPCNT Instructions + //! \{ + + ASMJIT_INST_2x(popcnt, Popcnt, Gp, Gp) // POPCNT + ASMJIT_INST_2x(popcnt, Popcnt, Gp, Mem) // POPCNT + + //! \} + + //! \name PREFETCH Instructions + //! \{ + + ASMJIT_INST_1x(prefetch, Prefetch, Mem) // 3DNOW + ASMJIT_INST_1x(prefetchnta, Prefetchnta, Mem) // SSE + ASMJIT_INST_1x(prefetcht0, Prefetcht0, Mem) // SSE + ASMJIT_INST_1x(prefetcht1, Prefetcht1, Mem) // SSE + ASMJIT_INST_1x(prefetcht2, Prefetcht2, Mem) // SSE + ASMJIT_INST_1x(prefetchw, Prefetchw, Mem) // PREFETCHW + ASMJIT_INST_1x(prefetchwt1, Prefetchwt1, Mem) // PREFETCHW1 + + //! \} + + //! \name PREFETCHI Instructions + //! \{ + + ASMJIT_INST_1x(prefetchit0, Prefetchit0, Mem) + ASMJIT_INST_1x(prefetchit1, Prefetchit1, Mem) + + //! \} + + //! \name RAO_INT Instructions + //! \{ + + ASMJIT_INST_2x(aadd, Aadd, Mem, Gp) + ASMJIT_INST_2x(aand, Aand, Mem, Gp) + ASMJIT_INST_2x(aor, Aor, Mem, Gp) + ASMJIT_INST_2x(axor, Axor, Mem, Gp) + + //! \} + + //! \name RDPID Instruction + //! \{ + + ASMJIT_INST_1x(rdpid, Rdpid, Gp) // RDPID + + //! \} + + //! \name RDPRU/RDPKRU Instructions + //! \{ + + ASMJIT_INST_3x(rdpru, Rdpru, Gp_EDX, Gp_EAX, Gp_ECX) // RDPRU [EXPLICIT] EDX:EAX <- PRU[ECX] + ASMJIT_INST_3x(rdpkru, Rdpkru, Gp_EDX, Gp_EAX, Gp_ECX) // RDPKRU [EXPLICIT] EDX:EAX <- PKRU[ECX] + + //! \} + + //! \name RDTSC/RDTSCP Instructions + //! \{ + + ASMJIT_INST_2x(rdtsc, Rdtsc, Gp_EDX, Gp_EAX) // RDTSC [EXPLICIT] EDX:EAX <- Counter + ASMJIT_INST_3x(rdtscp, Rdtscp, Gp_EDX, Gp_EAX, Gp_ECX) // RDTSCP [EXPLICIT] EDX:EAX:EXC <- Counter + + //! \} + + //! \name SERIALIZE Instruction + //! \{ + + ASMJIT_INST_0x(serialize, Serialize) // SERIALIZE + + //! \} + + //! 
\name TBM Instructions + //! \{ + + ASMJIT_INST_2x(blcfill, Blcfill, Gp, Gp) // TBM + ASMJIT_INST_2x(blcfill, Blcfill, Gp, Mem) // TBM + ASMJIT_INST_2x(blci, Blci, Gp, Gp) // TBM + ASMJIT_INST_2x(blci, Blci, Gp, Mem) // TBM + ASMJIT_INST_2x(blcic, Blcic, Gp, Gp) // TBM + ASMJIT_INST_2x(blcic, Blcic, Gp, Mem) // TBM + ASMJIT_INST_2x(blcmsk, Blcmsk, Gp, Gp) // TBM + ASMJIT_INST_2x(blcmsk, Blcmsk, Gp, Mem) // TBM + ASMJIT_INST_2x(blcs, Blcs, Gp, Gp) // TBM + ASMJIT_INST_2x(blcs, Blcs, Gp, Mem) // TBM + ASMJIT_INST_2x(blsfill, Blsfill, Gp, Gp) // TBM + ASMJIT_INST_2x(blsfill, Blsfill, Gp, Mem) // TBM + ASMJIT_INST_2x(blsic, Blsic, Gp, Gp) // TBM + ASMJIT_INST_2x(blsic, Blsic, Gp, Mem) // TBM + ASMJIT_INST_2x(t1mskc, T1mskc, Gp, Gp) // TBM + ASMJIT_INST_2x(t1mskc, T1mskc, Gp, Mem) // TBM + ASMJIT_INST_2x(tzmsk, Tzmsk, Gp, Gp) // TBM + ASMJIT_INST_2x(tzmsk, Tzmsk, Gp, Mem) // TBM + + //! \} + + //! \name Other User-Mode Instructions + //! \{ + + ASMJIT_INST_2x(arpl, Arpl, Gp, Gp) // X86 + ASMJIT_INST_2x(arpl, Arpl, Mem, Gp) // X86 + ASMJIT_INST_0x(cli, Cli) // ANY + ASMJIT_INST_0x(getsec, Getsec) // SMX + ASMJIT_INST_1x(int_, Int, Imm) // ANY + ASMJIT_INST_0x(int3, Int3) // ANY + ASMJIT_INST_0x(into, Into) // ANY + ASMJIT_INST_2x(lar, Lar, Gp, Gp) // ANY + ASMJIT_INST_2x(lar, Lar, Gp, Mem) // ANY + ASMJIT_INST_2x(lds, Lds, Gp, Mem) // X86 + ASMJIT_INST_2x(les, Les, Gp, Mem) // X86 + ASMJIT_INST_2x(lfs, Lfs, Gp, Mem) // ANY + ASMJIT_INST_2x(lgs, Lgs, Gp, Mem) // ANY + ASMJIT_INST_2x(lsl, Lsl, Gp, Gp) // ANY + ASMJIT_INST_2x(lsl, Lsl, Gp, Mem) // ANY + ASMJIT_INST_2x(lss, Lss, Gp, Mem) // ANY + ASMJIT_INST_0x(pause, Pause) // SSE2 + ASMJIT_INST_0x(rsm, Rsm) // X86 + ASMJIT_INST_1x(sgdt, Sgdt, Mem) // ANY + ASMJIT_INST_1x(sidt, Sidt, Mem) // ANY + ASMJIT_INST_1x(sldt, Sldt, Gp) // ANY + ASMJIT_INST_1x(sldt, Sldt, Mem) // ANY + ASMJIT_INST_1x(smsw, Smsw, Gp) // ANY + ASMJIT_INST_1x(smsw, Smsw, Mem) // ANY + ASMJIT_INST_0x(sti, Sti) // ANY + ASMJIT_INST_1x(str, Str, Gp) // 
ANY + ASMJIT_INST_1x(str, Str, Mem) // ANY + ASMJIT_INST_1x(verr, Verr, Gp) // ANY + ASMJIT_INST_1x(verr, Verr, Mem) // ANY + ASMJIT_INST_1x(verw, Verw, Gp) // ANY + ASMJIT_INST_1x(verw, Verw, Mem) // ANY + + //! \} + + //! \name FSGSBASE Instructions + //! \{ + + ASMJIT_INST_1x(rdfsbase, Rdfsbase, Gp) // FSGSBASE + ASMJIT_INST_1x(rdgsbase, Rdgsbase, Gp) // FSGSBASE + ASMJIT_INST_1x(wrfsbase, Wrfsbase, Gp) // FSGSBASE + ASMJIT_INST_1x(wrgsbase, Wrgsbase, Gp) // FSGSBASE + + //! \} + + //! \name FXSR Instructions + //! \{ + + ASMJIT_INST_1x(fxrstor, Fxrstor, Mem) // FXSR + ASMJIT_INST_1x(fxrstor64, Fxrstor64, Mem) // FXSR + ASMJIT_INST_1x(fxsave, Fxsave, Mem) // FXSR + ASMJIT_INST_1x(fxsave64, Fxsave64, Mem) // FXSR + + //! \} + + //! \name XSAVE Instructions + //! \{ + + ASMJIT_INST_3x(xgetbv, Xgetbv, Gp_EDX, Gp_EAX, Gp_ECX) // XSAVE [EXPLICIT] EDX:EAX <- XCR[ECX] + ASMJIT_INST_3x(xrstor, Xrstor, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xrstor64, Xrstor64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xrstors, Xrstors, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xrstors64, Xrstors64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xsave, Xsave, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xsave64, Xsave64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xsavec, Xsavec, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xsavec64, Xsavec64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xsaveopt, Xsaveopt, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xsaveopt64, Xsaveopt64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xsaves, Xsaves, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xsaves64, Xsaves64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + + //! \} + + //! \name MPX Extensions + //! 
\{ + + ASMJIT_INST_2x(bndcl, Bndcl, Bnd, Gp) // MPX + ASMJIT_INST_2x(bndcl, Bndcl, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndcn, Bndcn, Bnd, Gp) // MPX + ASMJIT_INST_2x(bndcn, Bndcn, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndcu, Bndcu, Bnd, Gp) // MPX + ASMJIT_INST_2x(bndcu, Bndcu, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndldx, Bndldx, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndmk, Bndmk, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndmov, Bndmov, Bnd, Bnd) // MPX + ASMJIT_INST_2x(bndmov, Bndmov, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndmov, Bndmov, Mem, Bnd) // MPX + ASMJIT_INST_2x(bndstx, Bndstx, Mem, Bnd) // MPX + + //! \} + + //! \name MONITORX Instructions + //! \{ + + ASMJIT_INST_3x(monitorx, Monitorx, Mem, Gp, Gp) // MONITORX + ASMJIT_INST_3x(mwaitx, Mwaitx, Gp, Gp, Gp) // MONITORX + + //! \} + + //! \name MCOMMIT Instruction + //! \{ + + ASMJIT_INST_0x(mcommit, Mcommit) // MCOMMIT + + //! \} + + //! \name PTWRITE Instruction + //! \{ + + ASMJIT_INST_1x(ptwrite, Ptwrite, Gp) // PTWRITE + ASMJIT_INST_1x(ptwrite, Ptwrite, Mem) // PTWRITE + + //! \} + + //! \name ENQCMD Instructions + //! \{ + + ASMJIT_INST_2x(enqcmd, Enqcmd, Mem, Mem) // ENQCMD + ASMJIT_INST_2x(enqcmds, Enqcmds, Mem, Mem) // ENQCMD + + //! \} + + //! \name WAITPKG Instructions + //! \{ + + ASMJIT_INST_3x(tpause, Tpause, Gp, Gp, Gp) // WAITPKG + ASMJIT_INST_1x(umonitor, Umonitor, Mem) // WAITPKG + ASMJIT_INST_3x(umwait, Umwait, Gp, Gp, Gp) // WAITPKG + + //! \} + + //! \name RDRAND & RDSEED Instructions + //! \{ + + ASMJIT_INST_1x(rdrand, Rdrand, Gp) // RDRAND + ASMJIT_INST_1x(rdseed, Rdseed, Gp) // RDSEED + + //! \} + + //! \name LWP Instructions + //! \{ + + ASMJIT_INST_1x(llwpcb, Llwpcb, Gp) // LWP + ASMJIT_INST_3x(lwpins, Lwpins, Gp, Gp, Imm) // LWP + ASMJIT_INST_3x(lwpins, Lwpins, Gp, Mem, Imm) // LWP + ASMJIT_INST_3x(lwpval, Lwpval, Gp, Gp, Imm) // LWP + ASMJIT_INST_3x(lwpval, Lwpval, Gp, Mem, Imm) // LWP + ASMJIT_INST_1x(slwpcb, Slwpcb, Gp) // LWP + + //! \} + + //! \name RTM & TSX Instructions + //! 
\{ + + ASMJIT_INST_1x(xabort, Xabort, Imm) // RTM + ASMJIT_INST_1x(xbegin, Xbegin, Label) // RTM + ASMJIT_INST_1x(xbegin, Xbegin, Imm) // RTM + ASMJIT_INST_0x(xend, Xend) // RTM + ASMJIT_INST_0x(xtest, Xtest) // TSX + + //! \} + + //! \name TSXLDTRK Instructions + //! \{ + + ASMJIT_INST_0x(xresldtrk, Xresldtrk) // TSXLDTRK + ASMJIT_INST_0x(xsusldtrk, Xsusldtrk) // TSXLDTRK + + //! \} + + //! \name CET-IBT Instructions + //! \{ + + ASMJIT_INST_0x(endbr32, Endbr32) // CET_IBT + ASMJIT_INST_0x(endbr64, Endbr64) // CET_IBT + + //! \} + + //! \name CET-SS Instructions + //! \{ + + ASMJIT_INST_1x(clrssbsy, Clrssbsy, Mem) // CET_SS + ASMJIT_INST_0x(setssbsy, Setssbsy) // CET_SS + + ASMJIT_INST_1x(rstorssp, Rstorssp, Mem) // CET_SS + ASMJIT_INST_0x(saveprevssp, Saveprevssp) // CET_SS + + ASMJIT_INST_1x(incsspd, Incsspd, Gp) // CET_SS + ASMJIT_INST_1x(incsspq, Incsspq, Gp) // CET_SS + ASMJIT_INST_1x(rdsspd, Rdsspd, Gp) // CET_SS + ASMJIT_INST_1x(rdsspq, Rdsspq, Gp) // CET_SS + ASMJIT_INST_2x(wrssd, Wrssd, Gp, Gp) // CET_SS + ASMJIT_INST_2x(wrssd, Wrssd, Mem, Gp) // CET_SS + ASMJIT_INST_2x(wrssq, Wrssq, Gp, Gp) // CET_SS + ASMJIT_INST_2x(wrssq, Wrssq, Mem, Gp) // CET_SS + ASMJIT_INST_2x(wrussd, Wrussd, Gp, Gp) // CET_SS + ASMJIT_INST_2x(wrussd, Wrussd, Mem, Gp) // CET_SS + ASMJIT_INST_2x(wrussq, Wrussq, Gp, Gp) // CET_SS + ASMJIT_INST_2x(wrussq, Wrussq, Mem, Gp) // CET_SS + + //! \} + + //! \name HRESET Instructions + //! \{ + + ASMJIT_INST_2x(hreset, Hreset, Imm, Gp) // HRESET + + //! \} + + //! \name UINTR Instructions + //! \{ + + ASMJIT_INST_0x(clui, Clui) // UINTR + ASMJIT_INST_1x(senduipi, Senduipi, Gp) // UINTR + ASMJIT_INST_0x(testui, Testui) // UINTR + ASMJIT_INST_0x(stui, Stui) // UINTR + ASMJIT_INST_0x(uiret, Uiret) // UINTR + + //! \} + + //! \name Core Privileged Instructions + //! 
\{ + + ASMJIT_INST_0x(clts, Clts) // ANY + ASMJIT_INST_0x(hlt, Hlt) // ANY + ASMJIT_INST_0x(invd, Invd) // ANY + ASMJIT_INST_1x(invlpg, Invlpg, Mem) // ANY + ASMJIT_INST_2x(invpcid, Invpcid, Gp, Mem) // ANY + ASMJIT_INST_1x(lgdt, Lgdt, Mem) // ANY + ASMJIT_INST_1x(lidt, Lidt, Mem) // ANY + ASMJIT_INST_1x(lldt, Lldt, Gp) // ANY + ASMJIT_INST_1x(lldt, Lldt, Mem) // ANY + ASMJIT_INST_1x(lmsw, Lmsw, Gp) // ANY + ASMJIT_INST_1x(lmsw, Lmsw, Mem) // ANY + ASMJIT_INST_1x(ltr, Ltr, Gp) // ANY + ASMJIT_INST_1x(ltr, Ltr, Mem) // ANY + ASMJIT_INST_3x(rdmsr, Rdmsr, Gp_EDX, Gp_EAX, Gp_ECX) // MSR [EXPLICIT] RDX:EAX <- MSR[ECX] + ASMJIT_INST_3x(rdpmc, Rdpmc, Gp_EDX, Gp_EAX, Gp_ECX) // ANY [EXPLICIT] RDX:EAX <- PMC[ECX] + ASMJIT_INST_0x(swapgs, Swapgs) // X64 + ASMJIT_INST_0x(wbinvd, Wbinvd) // ANY + ASMJIT_INST_0x(wbnoinvd, Wbnoinvd) // WBNOINVD + ASMJIT_INST_3x(wrmsr, Wrmsr, Gp_EDX, Gp_EAX, Gp_ECX) // MSR [EXPLICIT] RDX:EAX -> MSR[ECX] + ASMJIT_INST_3x(xsetbv, Xsetbv, Gp_EDX, Gp_EAX, Gp_ECX) // XSAVE [EXPLICIT] XCR[ECX] <- EDX:EAX + + //! \} + + //! \name INVLPGB Instructions + //! \{ + + ASMJIT_INST_3x(invlpgb, Invlpgb, Gp_EAX, Gp_EDX, Gp_ECX) + ASMJIT_INST_0x(tlbsync, Tlbsync) + + //! \} + + //! \name MONITOR Instructions (Privileged) + //! \{ + + ASMJIT_INST_3x(monitor, Monitor, Mem, Gp, Gp) // MONITOR + ASMJIT_INST_2x(mwait, Mwait, Gp, Gp) // MONITOR + + //! \} + + //! \name SMAP Instructions (Privileged) + //! \{ + + ASMJIT_INST_0x(clac, Clac) // SMAP + ASMJIT_INST_0x(stac, Stac) // SMAP + + //! \} + + //! \name SKINIT Instructions (Privileged) + //! \{ + + ASMJIT_INST_1x(skinit, Skinit, Gp) // SKINIT [EXPLICIT] + ASMJIT_INST_0x(stgi, Stgi) // SKINIT + + //! \} + + //! \name SNP Instructions (Privileged) + //! \{ + + ASMJIT_INST_0x(psmash, Psmash) // SNP + ASMJIT_INST_0x(pvalidate, Pvalidate) // SNP + ASMJIT_INST_0x(rmpadjust, Rmpadjust) // SNP + ASMJIT_INST_0x(rmpupdate, Rmpupdate) // SNP + + //! \} + + //! \name VMX Instructions (All privileged except vmfunc) + //! 
\{ + + ASMJIT_INST_2x(invept, Invept, Gp, Mem) // VMX + ASMJIT_INST_2x(invvpid, Invvpid, Gp, Mem) // VMX + ASMJIT_INST_0x(vmcall, Vmcall) // VMX + ASMJIT_INST_1x(vmclear, Vmclear, Mem) // VMX + ASMJIT_INST_0x(vmfunc, Vmfunc) // VMX + ASMJIT_INST_0x(vmlaunch, Vmlaunch) // VMX + ASMJIT_INST_1x(vmptrld, Vmptrld, Mem) // VMX + ASMJIT_INST_1x(vmptrst, Vmptrst, Mem) // VMX + ASMJIT_INST_2x(vmread, Vmread, Gp, Gp) // VMX + ASMJIT_INST_2x(vmread, Vmread, Mem, Gp) // VMX + ASMJIT_INST_0x(vmresume, Vmresume) // VMX + ASMJIT_INST_2x(vmwrite, Vmwrite, Gp, Mem) // VMX + ASMJIT_INST_2x(vmwrite, Vmwrite, Gp, Gp) // VMX + ASMJIT_INST_0x(vmxoff, Vmxoff) // VMX + ASMJIT_INST_1x(vmxon, Vmxon, Mem) // VMX + + //! \} + + //! \name SVM Instructions (All privileged except vmmcall) + //! \{ + + ASMJIT_INST_0x(clgi, Clgi) // SVM + ASMJIT_INST_2x(invlpga, Invlpga, Gp, Gp) // SVM [EXPLICIT] + ASMJIT_INST_1x(vmload, Vmload, Gp) // SVM [EXPLICIT] + ASMJIT_INST_0x(vmmcall, Vmmcall) // SVM + ASMJIT_INST_1x(vmrun, Vmrun, Gp) // SVM [EXPLICIT] + ASMJIT_INST_1x(vmsave, Vmsave, Gp) // SVM [EXPLICIT] + + //! \} + + //! \name SEV_ES Instructions + //! \{ + + ASMJIT_INST_0x(vmgexit, Vmgexit) + + //! \} + + //! \name FPU Instructions + //! 
\{ + + ASMJIT_INST_0x(f2xm1, F2xm1) // FPU + ASMJIT_INST_0x(fabs, Fabs) // FPU + ASMJIT_INST_2x(fadd, Fadd, St, St) // FPU + ASMJIT_INST_1x(fadd, Fadd, Mem) // FPU + ASMJIT_INST_1x(faddp, Faddp, St) // FPU + ASMJIT_INST_0x(faddp, Faddp) // FPU + ASMJIT_INST_1x(fbld, Fbld, Mem) // FPU + ASMJIT_INST_1x(fbstp, Fbstp, Mem) // FPU + ASMJIT_INST_0x(fchs, Fchs) // FPU + ASMJIT_INST_0x(fclex, Fclex) // FPU + ASMJIT_INST_1x(fcmovb, Fcmovb, St) // FPU + ASMJIT_INST_1x(fcmovbe, Fcmovbe, St) // FPU + ASMJIT_INST_1x(fcmove, Fcmove, St) // FPU + ASMJIT_INST_1x(fcmovnb, Fcmovnb, St) // FPU + ASMJIT_INST_1x(fcmovnbe, Fcmovnbe, St) // FPU + ASMJIT_INST_1x(fcmovne, Fcmovne, St) // FPU + ASMJIT_INST_1x(fcmovnu, Fcmovnu, St) // FPU + ASMJIT_INST_1x(fcmovu, Fcmovu, St) // FPU + ASMJIT_INST_1x(fcom, Fcom, St) // FPU + ASMJIT_INST_0x(fcom, Fcom) // FPU + ASMJIT_INST_1x(fcom, Fcom, Mem) // FPU + ASMJIT_INST_1x(fcomp, Fcomp, St) // FPU + ASMJIT_INST_0x(fcomp, Fcomp) // FPU + ASMJIT_INST_1x(fcomp, Fcomp, Mem) // FPU + ASMJIT_INST_0x(fcompp, Fcompp) // FPU + ASMJIT_INST_1x(fcomi, Fcomi, St) // FPU + ASMJIT_INST_1x(fcomip, Fcomip, St) // FPU + ASMJIT_INST_0x(fcos, Fcos) // FPU + ASMJIT_INST_0x(fdecstp, Fdecstp) // FPU + ASMJIT_INST_2x(fdiv, Fdiv, St, St) // FPU + ASMJIT_INST_1x(fdiv, Fdiv, Mem) // FPU + ASMJIT_INST_1x(fdivp, Fdivp, St) // FPU + ASMJIT_INST_0x(fdivp, Fdivp) // FPU + ASMJIT_INST_2x(fdivr, Fdivr, St, St) // FPU + ASMJIT_INST_1x(fdivr, Fdivr, Mem) // FPU + ASMJIT_INST_1x(fdivrp, Fdivrp, St) // FPU + ASMJIT_INST_0x(fdivrp, Fdivrp) // FPU + ASMJIT_INST_1x(ffree, Ffree, St) // FPU + ASMJIT_INST_1x(fiadd, Fiadd, Mem) // FPU + ASMJIT_INST_1x(ficom, Ficom, Mem) // FPU + ASMJIT_INST_1x(ficomp, Ficomp, Mem) // FPU + ASMJIT_INST_1x(fidiv, Fidiv, Mem) // FPU + ASMJIT_INST_1x(fidivr, Fidivr, Mem) // FPU + ASMJIT_INST_1x(fild, Fild, Mem) // FPU + ASMJIT_INST_1x(fimul, Fimul, Mem) // FPU + ASMJIT_INST_0x(fincstp, Fincstp) // FPU + ASMJIT_INST_0x(finit, Finit) // FPU + ASMJIT_INST_1x(fisub, 
Fisub, Mem) // FPU + ASMJIT_INST_1x(fisubr, Fisubr, Mem) // FPU + ASMJIT_INST_0x(fninit, Fninit) // FPU + ASMJIT_INST_1x(fist, Fist, Mem) // FPU + ASMJIT_INST_1x(fistp, Fistp, Mem) // FPU + ASMJIT_INST_1x(fisttp, Fisttp, Mem) // FPU+SSE3 + ASMJIT_INST_1x(fld, Fld, Mem) // FPU + ASMJIT_INST_1x(fld, Fld, St) // FPU + ASMJIT_INST_0x(fld1, Fld1) // FPU + ASMJIT_INST_0x(fldl2t, Fldl2t) // FPU + ASMJIT_INST_0x(fldl2e, Fldl2e) // FPU + ASMJIT_INST_0x(fldpi, Fldpi) // FPU + ASMJIT_INST_0x(fldlg2, Fldlg2) // FPU + ASMJIT_INST_0x(fldln2, Fldln2) // FPU + ASMJIT_INST_0x(fldz, Fldz) // FPU + ASMJIT_INST_1x(fldcw, Fldcw, Mem) // FPU + ASMJIT_INST_1x(fldenv, Fldenv, Mem) // FPU + ASMJIT_INST_2x(fmul, Fmul, St, St) // FPU + ASMJIT_INST_1x(fmul, Fmul, Mem) // FPU + ASMJIT_INST_1x(fmulp, Fmulp, St) // FPU + ASMJIT_INST_0x(fmulp, Fmulp) // FPU + ASMJIT_INST_0x(fnclex, Fnclex) // FPU + ASMJIT_INST_0x(fnop, Fnop) // FPU + ASMJIT_INST_1x(fnsave, Fnsave, Mem) // FPU + ASMJIT_INST_1x(fnstenv, Fnstenv, Mem) // FPU + ASMJIT_INST_1x(fnstcw, Fnstcw, Mem) // FPU + ASMJIT_INST_0x(fpatan, Fpatan) // FPU + ASMJIT_INST_0x(fprem, Fprem) // FPU + ASMJIT_INST_0x(fprem1, Fprem1) // FPU + ASMJIT_INST_0x(fptan, Fptan) // FPU + ASMJIT_INST_0x(frndint, Frndint) // FPU + ASMJIT_INST_1x(frstor, Frstor, Mem) // FPU + ASMJIT_INST_1x(fsave, Fsave, Mem) // FPU + ASMJIT_INST_0x(fscale, Fscale) // FPU + ASMJIT_INST_0x(fsin, Fsin) // FPU + ASMJIT_INST_0x(fsincos, Fsincos) // FPU + ASMJIT_INST_0x(fsqrt, Fsqrt) // FPU + ASMJIT_INST_1x(fst, Fst, Mem) // FPU + ASMJIT_INST_1x(fst, Fst, St) // FPU + ASMJIT_INST_1x(fstp, Fstp, Mem) // FPU + ASMJIT_INST_1x(fstp, Fstp, St) // FPU + ASMJIT_INST_1x(fstcw, Fstcw, Mem) // FPU + ASMJIT_INST_1x(fstenv, Fstenv, Mem) // FPU + ASMJIT_INST_2x(fsub, Fsub, St, St) // FPU + ASMJIT_INST_1x(fsub, Fsub, Mem) // FPU + ASMJIT_INST_1x(fsubp, Fsubp, St) // FPU + ASMJIT_INST_0x(fsubp, Fsubp) // FPU + ASMJIT_INST_2x(fsubr, Fsubr, St, St) // FPU + ASMJIT_INST_1x(fsubr, Fsubr, Mem) // FPU + 
ASMJIT_INST_1x(fsubrp, Fsubrp, St) // FPU + ASMJIT_INST_0x(fsubrp, Fsubrp) // FPU + ASMJIT_INST_0x(ftst, Ftst) // FPU + ASMJIT_INST_1x(fucom, Fucom, St) // FPU + ASMJIT_INST_0x(fucom, Fucom) // FPU + ASMJIT_INST_1x(fucomi, Fucomi, St) // FPU + ASMJIT_INST_1x(fucomip, Fucomip, St) // FPU + ASMJIT_INST_1x(fucomp, Fucomp, St) // FPU + ASMJIT_INST_0x(fucomp, Fucomp) // FPU + ASMJIT_INST_0x(fucompp, Fucompp) // FPU + ASMJIT_INST_0x(fwait, Fwait) // FPU + ASMJIT_INST_0x(fxam, Fxam) // FPU + ASMJIT_INST_1x(fxch, Fxch, St) // FPU + ASMJIT_INST_0x(fxtract, Fxtract) // FPU + ASMJIT_INST_0x(fyl2x, Fyl2x) // FPU + ASMJIT_INST_0x(fyl2xp1, Fyl2xp1) // FPU + ASMJIT_INST_1x(fstsw, Fstsw, Gp) // FPU + ASMJIT_INST_1x(fstsw, Fstsw, Mem) // FPU + ASMJIT_INST_1x(fnstsw, Fnstsw, Gp) // FPU + ASMJIT_INST_1x(fnstsw, Fnstsw, Mem) // FPU + + //! \} + + //! \name MMX & SSE+ Instructions + //! \{ + + ASMJIT_INST_2x(addpd, Addpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(addpd, Addpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(addps, Addps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(addps, Addps, Xmm, Mem) // SSE + ASMJIT_INST_2x(addsd, Addsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(addsd, Addsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(addss, Addss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(addss, Addss, Xmm, Mem) // SSE + ASMJIT_INST_2x(addsubpd, Addsubpd, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(addsubpd, Addsubpd, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(addsubps, Addsubps, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(addsubps, Addsubps, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(andnpd, Andnpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(andnpd, Andnpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(andnps, Andnps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(andnps, Andnps, Xmm, Mem) // SSE + ASMJIT_INST_2x(andpd, Andpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(andpd, Andpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(andps, Andps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(andps, Andps, Xmm, Mem) // SSE + ASMJIT_INST_3x(blendpd, Blendpd, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(blendpd, Blendpd, Xmm, Mem, Imm) 
// SSE4_1 + ASMJIT_INST_3x(blendps, Blendps, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(blendps, Blendps, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(blendvpd, Blendvpd, Xmm, Xmm, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(blendvpd, Blendvpd, Xmm, Mem, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(blendvps, Blendvps, Xmm, Xmm, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(blendvps, Blendvps, Xmm, Mem, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(cmppd, Cmppd, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(cmppd, Cmppd, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(cmpps, Cmpps, Xmm, Xmm, Imm) // SSE + ASMJIT_INST_3x(cmpps, Cmpps, Xmm, Mem, Imm) // SSE + ASMJIT_INST_3x(cmpsd, Cmpsd, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(cmpsd, Cmpsd, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(cmpss, Cmpss, Xmm, Xmm, Imm) // SSE + ASMJIT_INST_3x(cmpss, Cmpss, Xmm, Mem, Imm) // SSE + ASMJIT_INST_2x(comisd, Comisd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(comisd, Comisd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(comiss, Comiss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(comiss, Comiss, Xmm, Mem) // SSE + ASMJIT_INST_2x(cvtdq2pd, Cvtdq2pd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtdq2pd, Cvtdq2pd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtdq2ps, Cvtdq2ps, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtdq2ps, Cvtdq2ps, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpd2dq, Cvtpd2dq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtpd2dq, Cvtpd2dq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpd2pi, Cvtpd2pi, Mm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtpd2pi, Cvtpd2pi, Mm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpd2ps, Cvtpd2ps, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtpd2ps, Cvtpd2ps, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpi2pd, Cvtpi2pd, Xmm, Mm) // SSE2 + ASMJIT_INST_2x(cvtpi2pd, Cvtpi2pd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpi2ps, Cvtpi2ps, Xmm, Mm) // SSE + ASMJIT_INST_2x(cvtpi2ps, Cvtpi2ps, Xmm, Mem) // SSE + ASMJIT_INST_2x(cvtps2dq, Cvtps2dq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtps2dq, Cvtps2dq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtps2pd, Cvtps2pd, Xmm, Xmm) // SSE2 + 
ASMJIT_INST_2x(cvtps2pd, Cvtps2pd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtps2pi, Cvtps2pi, Mm, Xmm) // SSE + ASMJIT_INST_2x(cvtps2pi, Cvtps2pi, Mm, Mem) // SSE + ASMJIT_INST_2x(cvtsd2si, Cvtsd2si, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(cvtsd2si, Cvtsd2si, Gp, Mem) // SSE2 + ASMJIT_INST_2x(cvtsd2ss, Cvtsd2ss, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtsd2ss, Cvtsd2ss, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtsi2sd, Cvtsi2sd, Xmm, Gp) // SSE2 + ASMJIT_INST_2x(cvtsi2sd, Cvtsi2sd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtsi2ss, Cvtsi2ss, Xmm, Gp) // SSE + ASMJIT_INST_2x(cvtsi2ss, Cvtsi2ss, Xmm, Mem) // SSE + ASMJIT_INST_2x(cvtss2sd, Cvtss2sd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtss2sd, Cvtss2sd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtss2si, Cvtss2si, Gp, Xmm) // SSE + ASMJIT_INST_2x(cvtss2si, Cvtss2si, Gp, Mem) // SSE + ASMJIT_INST_2x(cvttpd2pi, Cvttpd2pi, Mm, Xmm) // SSE2 + ASMJIT_INST_2x(cvttpd2pi, Cvttpd2pi, Mm, Mem) // SSE2 + ASMJIT_INST_2x(cvttpd2dq, Cvttpd2dq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvttpd2dq, Cvttpd2dq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvttps2dq, Cvttps2dq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvttps2dq, Cvttps2dq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvttps2pi, Cvttps2pi, Mm, Xmm) // SSE + ASMJIT_INST_2x(cvttps2pi, Cvttps2pi, Mm, Mem) // SSE + ASMJIT_INST_2x(cvttsd2si, Cvttsd2si, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(cvttsd2si, Cvttsd2si, Gp, Mem) // SSE2 + ASMJIT_INST_2x(cvttss2si, Cvttss2si, Gp, Xmm) // SSE + ASMJIT_INST_2x(cvttss2si, Cvttss2si, Gp, Mem) // SSE + ASMJIT_INST_2x(divpd, Divpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(divpd, Divpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(divps, Divps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(divps, Divps, Xmm, Mem) // SSE + ASMJIT_INST_2x(divsd, Divsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(divsd, Divsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(divss, Divss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(divss, Divss, Xmm, Mem) // SSE + ASMJIT_INST_3x(dppd, Dppd, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(dppd, Dppd, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(dpps, 
Dpps, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(dpps, Dpps, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(extractps, Extractps, Gp, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(extractps, Extractps, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_2x(extrq, Extrq, Xmm, Xmm) // SSE4A + ASMJIT_INST_3x(extrq, Extrq, Xmm, Imm, Imm) // SSE4A + ASMJIT_INST_2x(haddpd, Haddpd, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(haddpd, Haddpd, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(haddps, Haddps, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(haddps, Haddps, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(hsubpd, Hsubpd, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(hsubpd, Hsubpd, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(hsubps, Hsubps, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(hsubps, Hsubps, Xmm, Mem) // SSE3 + ASMJIT_INST_3x(insertps, Insertps, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(insertps, Insertps, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_2x(insertq, Insertq, Xmm, Xmm) // SSE4A + ASMJIT_INST_4x(insertq, Insertq, Xmm, Xmm, Imm, Imm) // SSE4A + ASMJIT_INST_2x(lddqu, Lddqu, Xmm, Mem) // SSE3 + ASMJIT_INST_3x(maskmovq, Maskmovq, Mm, Mm, DS_ZDI) // SSE [EXPLICIT] + ASMJIT_INST_3x(maskmovdqu, Maskmovdqu, Xmm, Xmm, DS_ZDI) // SSE2 [EXPLICIT] + ASMJIT_INST_2x(maxpd, Maxpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(maxpd, Maxpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(maxps, Maxps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(maxps, Maxps, Xmm, Mem) // SSE + ASMJIT_INST_2x(maxsd, Maxsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(maxsd, Maxsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(maxss, Maxss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(maxss, Maxss, Xmm, Mem) // SSE + ASMJIT_INST_2x(minpd, Minpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(minpd, Minpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(minps, Minps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(minps, Minps, Xmm, Mem) // SSE + ASMJIT_INST_2x(minsd, Minsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(minsd, Minsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(minss, Minss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(minss, Minss, Xmm, Mem) // SSE + ASMJIT_INST_2x(movapd, Movapd, Xmm, Xmm) // SSE2 + 
ASMJIT_INST_2x(movapd, Movapd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movapd, Movapd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movaps, Movaps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movaps, Movaps, Xmm, Mem) // SSE + ASMJIT_INST_2x(movaps, Movaps, Mem, Xmm) // SSE + ASMJIT_INST_2x(movd, Movd, Mem, Mm) // MMX + ASMJIT_INST_2x(movd, Movd, Mem, Xmm) // SSE + ASMJIT_INST_2x(movd, Movd, Gp, Mm) // MMX + ASMJIT_INST_2x(movd, Movd, Gp, Xmm) // SSE + ASMJIT_INST_2x(movd, Movd, Mm, Mem) // MMX + ASMJIT_INST_2x(movd, Movd, Xmm, Mem) // SSE + ASMJIT_INST_2x(movd, Movd, Mm, Gp) // MMX + ASMJIT_INST_2x(movd, Movd, Xmm, Gp) // SSE + ASMJIT_INST_2x(movddup, Movddup, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(movddup, Movddup, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(movdq2q, Movdq2q, Mm, Xmm) // SSE2 + ASMJIT_INST_2x(movdqa, Movdqa, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(movdqa, Movdqa, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movdqa, Movdqa, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movdqu, Movdqu, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(movdqu, Movdqu, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movdqu, Movdqu, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movhlps, Movhlps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movhpd, Movhpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movhpd, Movhpd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movhps, Movhps, Xmm, Mem) // SSE + ASMJIT_INST_2x(movhps, Movhps, Mem, Xmm) // SSE + ASMJIT_INST_2x(movlhps, Movlhps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movlpd, Movlpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movlpd, Movlpd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movlps, Movlps, Xmm, Mem) // SSE + ASMJIT_INST_2x(movlps, Movlps, Mem, Xmm) // SSE + ASMJIT_INST_2x(movmskps, Movmskps, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(movmskpd, Movmskpd, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(movntdq, Movntdq, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movntdqa, Movntdqa, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(movntpd, Movntpd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movntps, Movntps, Mem, Xmm) // SSE + ASMJIT_INST_2x(movntsd, Movntsd, Mem, Xmm) // SSE4A + ASMJIT_INST_2x(movntss, Movntss, Mem, 
Xmm) // SSE4A + ASMJIT_INST_2x(movntq, Movntq, Mem, Mm) // SSE + ASMJIT_INST_2x(movq, Movq, Mm, Mm) // MMX + ASMJIT_INST_2x(movq, Movq, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movq, Movq, Mem, Mm) // MMX + ASMJIT_INST_2x(movq, Movq, Mem, Xmm) // SSE + ASMJIT_INST_2x(movq, Movq, Mm, Mem) // MMX + ASMJIT_INST_2x(movq, Movq, Xmm, Mem) // SSE + ASMJIT_INST_2x(movq, Movq, Gp, Mm) // MMX + ASMJIT_INST_2x(movq, Movq, Gp, Xmm) // SSE+X64. + ASMJIT_INST_2x(movq, Movq, Mm, Gp) // MMX + ASMJIT_INST_2x(movq, Movq, Xmm, Gp) // SSE+X64. + ASMJIT_INST_2x(movq2dq, Movq2dq, Xmm, Mm) // SSE2 + ASMJIT_INST_2x(movsd, Movsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(movsd, Movsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movsd, Movsd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movshdup, Movshdup, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(movshdup, Movshdup, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(movsldup, Movsldup, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(movsldup, Movsldup, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(movss, Movss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movss, Movss, Xmm, Mem) // SSE + ASMJIT_INST_2x(movss, Movss, Mem, Xmm) // SSE + ASMJIT_INST_2x(movupd, Movupd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(movupd, Movupd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movupd, Movupd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movups, Movups, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movups, Movups, Xmm, Mem) // SSE + ASMJIT_INST_2x(movups, Movups, Mem, Xmm) // SSE + ASMJIT_INST_3x(mpsadbw, Mpsadbw, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(mpsadbw, Mpsadbw, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_2x(mulpd, Mulpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(mulpd, Mulpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(mulps, Mulps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(mulps, Mulps, Xmm, Mem) // SSE + ASMJIT_INST_2x(mulsd, Mulsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(mulsd, Mulsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(mulss, Mulss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(mulss, Mulss, Xmm, Mem) // SSE + ASMJIT_INST_2x(orpd, Orpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(orpd, Orpd, Xmm, Mem) // SSE2 + 
ASMJIT_INST_2x(orps, Orps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(orps, Orps, Xmm, Mem) // SSE + ASMJIT_INST_2x(packssdw, Packssdw, Mm, Mm) // MMX + ASMJIT_INST_2x(packssdw, Packssdw, Mm, Mem) // MMX + ASMJIT_INST_2x(packssdw, Packssdw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(packssdw, Packssdw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(packsswb, Packsswb, Mm, Mm) // MMX + ASMJIT_INST_2x(packsswb, Packsswb, Mm, Mem) // MMX + ASMJIT_INST_2x(packsswb, Packsswb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(packsswb, Packsswb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(packusdw, Packusdw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(packusdw, Packusdw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(packuswb, Packuswb, Mm, Mm) // MMX + ASMJIT_INST_2x(packuswb, Packuswb, Mm, Mem) // MMX + ASMJIT_INST_2x(packuswb, Packuswb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(packuswb, Packuswb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pabsb, Pabsb, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pabsb, Pabsb, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsb, Pabsb, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pabsb, Pabsb, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsd, Pabsd, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pabsd, Pabsd, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsd, Pabsd, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pabsd, Pabsd, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsw, Pabsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pabsw, Pabsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsw, Pabsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pabsw, Pabsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(paddb, Paddb, Mm, Mm) // MMX + ASMJIT_INST_2x(paddb, Paddb, Mm, Mem) // MMX + ASMJIT_INST_2x(paddb, Paddb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddb, Paddb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddd, Paddd, Mm, Mm) // MMX + ASMJIT_INST_2x(paddd, Paddd, Mm, Mem) // MMX + ASMJIT_INST_2x(paddd, Paddd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddd, Paddd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddq, Paddq, Mm, Mm) // SSE2 + ASMJIT_INST_2x(paddq, Paddq, Mm, Mem) // SSE2 + ASMJIT_INST_2x(paddq, Paddq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddq, Paddq, 
Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddsb, Paddsb, Mm, Mm) // MMX + ASMJIT_INST_2x(paddsb, Paddsb, Mm, Mem) // MMX + ASMJIT_INST_2x(paddsb, Paddsb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddsb, Paddsb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddsw, Paddsw, Mm, Mm) // MMX + ASMJIT_INST_2x(paddsw, Paddsw, Mm, Mem) // MMX + ASMJIT_INST_2x(paddsw, Paddsw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddsw, Paddsw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddusb, Paddusb, Mm, Mm) // MMX + ASMJIT_INST_2x(paddusb, Paddusb, Mm, Mem) // MMX + ASMJIT_INST_2x(paddusb, Paddusb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddusb, Paddusb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddusw, Paddusw, Mm, Mm) // MMX + ASMJIT_INST_2x(paddusw, Paddusw, Mm, Mem) // MMX + ASMJIT_INST_2x(paddusw, Paddusw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddusw, Paddusw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddw, Paddw, Mm, Mm) // MMX + ASMJIT_INST_2x(paddw, Paddw, Mm, Mem) // MMX + ASMJIT_INST_2x(paddw, Paddw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddw, Paddw, Xmm, Mem) // SSE2 + ASMJIT_INST_3x(palignr, Palignr, Mm, Mm, Imm) // SSSE3 + ASMJIT_INST_3x(palignr, Palignr, Mm, Mem, Imm) // SSSE3 + ASMJIT_INST_3x(palignr, Palignr, Xmm, Xmm, Imm) // SSSE3 + ASMJIT_INST_3x(palignr, Palignr, Xmm, Mem, Imm) // SSSE3 + ASMJIT_INST_2x(pand, Pand, Mm, Mm) // MMX + ASMJIT_INST_2x(pand, Pand, Mm, Mem) // MMX + ASMJIT_INST_2x(pand, Pand, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pand, Pand, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pandn, Pandn, Mm, Mm) // MMX + ASMJIT_INST_2x(pandn, Pandn, Mm, Mem) // MMX + ASMJIT_INST_2x(pandn, Pandn, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pandn, Pandn, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pavgb, Pavgb, Mm, Mm) // SSE + ASMJIT_INST_2x(pavgb, Pavgb, Mm, Mem) // SSE + ASMJIT_INST_2x(pavgb, Pavgb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pavgb, Pavgb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pavgw, Pavgw, Mm, Mm) // SSE + ASMJIT_INST_2x(pavgw, Pavgw, Mm, Mem) // SSE + ASMJIT_INST_2x(pavgw, Pavgw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pavgw, Pavgw, Xmm, 
Mem) // SSE2 + ASMJIT_INST_3x(pblendvb, Pblendvb, Xmm, Xmm, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(pblendvb, Pblendvb, Xmm, Mem, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(pblendw, Pblendw, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pblendw, Pblendw, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(pclmulqdq, Pclmulqdq, Xmm, Xmm, Imm) // PCLMULQDQ. + ASMJIT_INST_3x(pclmulqdq, Pclmulqdq, Xmm, Mem, Imm) // PCLMULQDQ. + ASMJIT_INST_6x(pcmpestri, Pcmpestri, Xmm, Xmm, Imm, Gp_ECX, Gp_EAX, Gp_EDX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_6x(pcmpestri, Pcmpestri, Xmm, Mem, Imm, Gp_ECX, Gp_EAX, Gp_EDX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_6x(pcmpestrm, Pcmpestrm, Xmm, Xmm, Imm, XMM0, Gp_EAX, Gp_EDX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_6x(pcmpestrm, Pcmpestrm, Xmm, Mem, Imm, XMM0, Gp_EAX, Gp_EDX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_2x(pcmpeqb, Pcmpeqb, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpeqb, Pcmpeqb, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpeqb, Pcmpeqb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpeqb, Pcmpeqb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpeqd, Pcmpeqd, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpeqd, Pcmpeqd, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpeqd, Pcmpeqd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpeqd, Pcmpeqd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpeqq, Pcmpeqq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pcmpeqq, Pcmpeqq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pcmpeqw, Pcmpeqw, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpeqw, Pcmpeqw, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpeqw, Pcmpeqw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpeqw, Pcmpeqw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpgtb, Pcmpgtb, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpgtb, Pcmpgtb, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpgtb, Pcmpgtb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpgtb, Pcmpgtb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpgtd, Pcmpgtd, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpgtd, Pcmpgtd, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpgtd, Pcmpgtd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpgtd, Pcmpgtd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpgtq, Pcmpgtq, Xmm, Xmm) 
// SSE4_2. + ASMJIT_INST_2x(pcmpgtq, Pcmpgtq, Xmm, Mem) // SSE4_2. + ASMJIT_INST_2x(pcmpgtw, Pcmpgtw, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpgtw, Pcmpgtw, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpgtw, Pcmpgtw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpgtw, Pcmpgtw, Xmm, Mem) // SSE2 + ASMJIT_INST_4x(pcmpistri, Pcmpistri, Xmm, Xmm, Imm, Gp_ECX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_4x(pcmpistri, Pcmpistri, Xmm, Mem, Imm, Gp_ECX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_4x(pcmpistrm, Pcmpistrm, Xmm, Xmm, Imm, XMM0) // SSE4_2 [EXPLICIT] + ASMJIT_INST_4x(pcmpistrm, Pcmpistrm, Xmm, Mem, Imm, XMM0) // SSE4_2 [EXPLICIT] + ASMJIT_INST_3x(pextrb, Pextrb, Gp, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrb, Pextrb, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrd, Pextrd, Gp, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrd, Pextrd, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrq, Pextrq, Gp, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrq, Pextrq, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrw, Pextrw, Gp, Mm, Imm) // SSE + ASMJIT_INST_3x(pextrw, Pextrw, Gp, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(pextrw, Pextrw, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_2x(phaddd, Phaddd, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phaddd, Phaddd, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddd, Phaddd, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phaddd, Phaddd, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddsw, Phaddsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phaddsw, Phaddsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddsw, Phaddsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phaddsw, Phaddsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddw, Phaddw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phaddw, Phaddw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddw, Phaddw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phaddw, Phaddw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phminposuw, Phminposuw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(phminposuw, Phminposuw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(phsubd, Phsubd, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phsubd, Phsubd, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubd, Phsubd, Xmm, 
Xmm) // SSSE3 + ASMJIT_INST_2x(phsubd, Phsubd, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubsw, Phsubsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phsubsw, Phsubsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubsw, Phsubsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phsubsw, Phsubsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubw, Phsubw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phsubw, Phsubw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubw, Phsubw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phsubw, Phsubw, Xmm, Mem) // SSSE3 + ASMJIT_INST_3x(pinsrb, Pinsrb, Xmm, Gp, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrb, Pinsrb, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrd, Pinsrd, Xmm, Gp, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrd, Pinsrd, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrq, Pinsrq, Xmm, Gp, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrq, Pinsrq, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrw, Pinsrw, Mm, Gp, Imm) // SSE + ASMJIT_INST_3x(pinsrw, Pinsrw, Mm, Mem, Imm) // SSE + ASMJIT_INST_3x(pinsrw, Pinsrw, Xmm, Gp, Imm) // SSE2 + ASMJIT_INST_3x(pinsrw, Pinsrw, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_2x(pmaddubsw, Pmaddubsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pmaddubsw, Pmaddubsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pmaddubsw, Pmaddubsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pmaddubsw, Pmaddubsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(pmaddwd, Pmaddwd, Mm, Mm) // MMX + ASMJIT_INST_2x(pmaddwd, Pmaddwd, Mm, Mem) // MMX + ASMJIT_INST_2x(pmaddwd, Pmaddwd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmaddwd, Pmaddwd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmaxsb, Pmaxsb, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmaxsb, Pmaxsb, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmaxsd, Pmaxsd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmaxsd, Pmaxsd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmaxsw, Pmaxsw, Mm, Mm) // SSE + ASMJIT_INST_2x(pmaxsw, Pmaxsw, Mm, Mem) // SSE + ASMJIT_INST_2x(pmaxsw, Pmaxsw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmaxsw, Pmaxsw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmaxub, Pmaxub, Mm, Mm) // SSE + ASMJIT_INST_2x(pmaxub, Pmaxub, Mm, Mem) // SSE + 
ASMJIT_INST_2x(pmaxub, Pmaxub, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmaxub, Pmaxub, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmaxud, Pmaxud, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmaxud, Pmaxud, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmaxuw, Pmaxuw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmaxuw, Pmaxuw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pminsb, Pminsb, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pminsb, Pminsb, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pminsd, Pminsd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pminsd, Pminsd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pminsw, Pminsw, Mm, Mm) // SSE + ASMJIT_INST_2x(pminsw, Pminsw, Mm, Mem) // SSE + ASMJIT_INST_2x(pminsw, Pminsw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pminsw, Pminsw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pminub, Pminub, Mm, Mm) // SSE + ASMJIT_INST_2x(pminub, Pminub, Mm, Mem) // SSE + ASMJIT_INST_2x(pminub, Pminub, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pminub, Pminub, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pminud, Pminud, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pminud, Pminud, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pminuw, Pminuw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pminuw, Pminuw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovmskb, Pmovmskb, Gp, Mm) // SSE + ASMJIT_INST_2x(pmovmskb, Pmovmskb, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(pmovsxbd, Pmovsxbd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxbd, Pmovsxbd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxbq, Pmovsxbq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxbq, Pmovsxbq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxbw, Pmovsxbw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxbw, Pmovsxbw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxdq, Pmovsxdq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxdq, Pmovsxdq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxwd, Pmovsxwd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxwd, Pmovsxwd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxwq, Pmovsxwq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxwq, Pmovsxwq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxbd, Pmovzxbd, Xmm, Xmm) // SSE4_1 + 
ASMJIT_INST_2x(pmovzxbd, Pmovzxbd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxbq, Pmovzxbq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxbq, Pmovzxbq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxbw, Pmovzxbw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxbw, Pmovzxbw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxdq, Pmovzxdq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxdq, Pmovzxdq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxwd, Pmovzxwd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxwd, Pmovzxwd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxwq, Pmovzxwq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxwq, Pmovzxwq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmuldq, Pmuldq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmuldq, Pmuldq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmulhrsw, Pmulhrsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pmulhrsw, Pmulhrsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pmulhrsw, Pmulhrsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pmulhrsw, Pmulhrsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(pmulhw, Pmulhw, Mm, Mm) // MMX + ASMJIT_INST_2x(pmulhw, Pmulhw, Mm, Mem) // MMX + ASMJIT_INST_2x(pmulhw, Pmulhw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmulhw, Pmulhw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmulhuw, Pmulhuw, Mm, Mm) // SSE + ASMJIT_INST_2x(pmulhuw, Pmulhuw, Mm, Mem) // SSE + ASMJIT_INST_2x(pmulhuw, Pmulhuw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmulhuw, Pmulhuw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmulld, Pmulld, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmulld, Pmulld, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmullw, Pmullw, Mm, Mm) // MMX + ASMJIT_INST_2x(pmullw, Pmullw, Mm, Mem) // MMX + ASMJIT_INST_2x(pmullw, Pmullw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmullw, Pmullw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmuludq, Pmuludq, Mm, Mm) // SSE2 + ASMJIT_INST_2x(pmuludq, Pmuludq, Mm, Mem) // SSE2 + ASMJIT_INST_2x(pmuludq, Pmuludq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmuludq, Pmuludq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(por, Por, Mm, Mm) // MMX + ASMJIT_INST_2x(por, Por, Mm, Mem) // MMX + ASMJIT_INST_2x(por, Por, Xmm, Xmm) // 
SSE2 + ASMJIT_INST_2x(por, Por, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psadbw, Psadbw, Mm, Mm) // SSE + ASMJIT_INST_2x(psadbw, Psadbw, Mm, Mem) // SSE + ASMJIT_INST_2x(psadbw, Psadbw, Xmm, Xmm) // SSE + ASMJIT_INST_2x(psadbw, Psadbw, Xmm, Mem) // SSE + ASMJIT_INST_2x(pslld, Pslld, Mm, Mm) // MMX + ASMJIT_INST_2x(pslld, Pslld, Mm, Mem) // MMX + ASMJIT_INST_2x(pslld, Pslld, Mm, Imm) // MMX + ASMJIT_INST_2x(pslld, Pslld, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pslld, Pslld, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pslld, Pslld, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(pslldq, Pslldq, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psllq, Psllq, Mm, Mm) // MMX + ASMJIT_INST_2x(psllq, Psllq, Mm, Mem) // MMX + ASMJIT_INST_2x(psllq, Psllq, Mm, Imm) // MMX + ASMJIT_INST_2x(psllq, Psllq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psllq, Psllq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psllq, Psllq, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psllw, Psllw, Mm, Mm) // MMX + ASMJIT_INST_2x(psllw, Psllw, Mm, Mem) // MMX + ASMJIT_INST_2x(psllw, Psllw, Mm, Imm) // MMX + ASMJIT_INST_2x(psllw, Psllw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psllw, Psllw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psllw, Psllw, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psrad, Psrad, Mm, Mm) // MMX + ASMJIT_INST_2x(psrad, Psrad, Mm, Mem) // MMX + ASMJIT_INST_2x(psrad, Psrad, Mm, Imm) // MMX + ASMJIT_INST_2x(psrad, Psrad, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psrad, Psrad, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psrad, Psrad, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psraw, Psraw, Mm, Mm) // MMX + ASMJIT_INST_2x(psraw, Psraw, Mm, Mem) // MMX + ASMJIT_INST_2x(psraw, Psraw, Mm, Imm) // MMX + ASMJIT_INST_2x(psraw, Psraw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psraw, Psraw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psraw, Psraw, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(pshufb, Pshufb, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pshufb, Pshufb, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pshufb, Pshufb, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pshufb, Pshufb, Xmm, Mem) // SSSE3 + ASMJIT_INST_3x(pshufd, Pshufd, Xmm, Xmm, Imm) // SSE2 + 
ASMJIT_INST_3x(pshufd, Pshufd, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(pshufhw, Pshufhw, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(pshufhw, Pshufhw, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(pshuflw, Pshuflw, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(pshuflw, Pshuflw, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(pshufw, Pshufw, Mm, Mm, Imm) // SSE + ASMJIT_INST_3x(pshufw, Pshufw, Mm, Mem, Imm) // SSE + ASMJIT_INST_2x(psignb, Psignb, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(psignb, Psignb, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(psignb, Psignb, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(psignb, Psignb, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(psignd, Psignd, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(psignd, Psignd, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(psignd, Psignd, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(psignd, Psignd, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(psignw, Psignw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(psignw, Psignw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(psignw, Psignw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(psignw, Psignw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(psrld, Psrld, Mm, Mm) // MMX + ASMJIT_INST_2x(psrld, Psrld, Mm, Mem) // MMX + ASMJIT_INST_2x(psrld, Psrld, Mm, Imm) // MMX + ASMJIT_INST_2x(psrld, Psrld, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psrld, Psrld, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psrld, Psrld, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psrldq, Psrldq, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psrlq, Psrlq, Mm, Mm) // MMX + ASMJIT_INST_2x(psrlq, Psrlq, Mm, Mem) // MMX + ASMJIT_INST_2x(psrlq, Psrlq, Mm, Imm) // MMX + ASMJIT_INST_2x(psrlq, Psrlq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psrlq, Psrlq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psrlq, Psrlq, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psrlw, Psrlw, Mm, Mm) // MMX + ASMJIT_INST_2x(psrlw, Psrlw, Mm, Mem) // MMX + ASMJIT_INST_2x(psrlw, Psrlw, Mm, Imm) // MMX + ASMJIT_INST_2x(psrlw, Psrlw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psrlw, Psrlw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psrlw, Psrlw, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psubb, Psubb, Mm, Mm) // MMX + ASMJIT_INST_2x(psubb, Psubb, 
Mm, Mem) // MMX + ASMJIT_INST_2x(psubb, Psubb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubb, Psubb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubd, Psubd, Mm, Mm) // MMX + ASMJIT_INST_2x(psubd, Psubd, Mm, Mem) // MMX + ASMJIT_INST_2x(psubd, Psubd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubd, Psubd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubq, Psubq, Mm, Mm) // SSE2 + ASMJIT_INST_2x(psubq, Psubq, Mm, Mem) // SSE2 + ASMJIT_INST_2x(psubq, Psubq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubq, Psubq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubsb, Psubsb, Mm, Mm) // MMX + ASMJIT_INST_2x(psubsb, Psubsb, Mm, Mem) // MMX + ASMJIT_INST_2x(psubsb, Psubsb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubsb, Psubsb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubsw, Psubsw, Mm, Mm) // MMX + ASMJIT_INST_2x(psubsw, Psubsw, Mm, Mem) // MMX + ASMJIT_INST_2x(psubsw, Psubsw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubsw, Psubsw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubusb, Psubusb, Mm, Mm) // MMX + ASMJIT_INST_2x(psubusb, Psubusb, Mm, Mem) // MMX + ASMJIT_INST_2x(psubusb, Psubusb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubusb, Psubusb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubusw, Psubusw, Mm, Mm) // MMX + ASMJIT_INST_2x(psubusw, Psubusw, Mm, Mem) // MMX + ASMJIT_INST_2x(psubusw, Psubusw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubusw, Psubusw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubw, Psubw, Mm, Mm) // MMX + ASMJIT_INST_2x(psubw, Psubw, Mm, Mem) // MMX + ASMJIT_INST_2x(psubw, Psubw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubw, Psubw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(ptest, Ptest, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(ptest, Ptest, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(punpckhbw, Punpckhbw, Mm, Mm) // MMX + ASMJIT_INST_2x(punpckhbw, Punpckhbw, Mm, Mem) // MMX + ASMJIT_INST_2x(punpckhbw, Punpckhbw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpckhbw, Punpckhbw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpckhdq, Punpckhdq, Mm, Mm) // MMX + ASMJIT_INST_2x(punpckhdq, Punpckhdq, Mm, Mem) // MMX + ASMJIT_INST_2x(punpckhdq, Punpckhdq, Xmm, Xmm) // SSE2 + 
ASMJIT_INST_2x(punpckhdq, Punpckhdq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpckhqdq, Punpckhqdq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpckhqdq, Punpckhqdq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpckhwd, Punpckhwd, Mm, Mm) // MMX + ASMJIT_INST_2x(punpckhwd, Punpckhwd, Mm, Mem) // MMX + ASMJIT_INST_2x(punpckhwd, Punpckhwd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpckhwd, Punpckhwd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpcklbw, Punpcklbw, Mm, Mm) // MMX + ASMJIT_INST_2x(punpcklbw, Punpcklbw, Mm, Mem) // MMX + ASMJIT_INST_2x(punpcklbw, Punpcklbw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpcklbw, Punpcklbw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpckldq, Punpckldq, Mm, Mm) // MMX + ASMJIT_INST_2x(punpckldq, Punpckldq, Mm, Mem) // MMX + ASMJIT_INST_2x(punpckldq, Punpckldq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpckldq, Punpckldq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpcklqdq, Punpcklqdq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpcklqdq, Punpcklqdq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpcklwd, Punpcklwd, Mm, Mm) // MMX + ASMJIT_INST_2x(punpcklwd, Punpcklwd, Mm, Mem) // MMX + ASMJIT_INST_2x(punpcklwd, Punpcklwd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpcklwd, Punpcklwd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pxor, Pxor, Mm, Mm) // MMX + ASMJIT_INST_2x(pxor, Pxor, Mm, Mem) // MMX + ASMJIT_INST_2x(pxor, Pxor, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pxor, Pxor, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(rcpps, Rcpps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(rcpps, Rcpps, Xmm, Mem) // SSE + ASMJIT_INST_2x(rcpss, Rcpss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(rcpss, Rcpss, Xmm, Mem) // SSE + ASMJIT_INST_3x(roundpd, Roundpd, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(roundpd, Roundpd, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(roundps, Roundps, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(roundps, Roundps, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(roundsd, Roundsd, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(roundsd, Roundsd, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(roundss, Roundss, Xmm, Xmm, Imm) // SSE4_1 + 
ASMJIT_INST_3x(roundss, Roundss, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_2x(rsqrtps, Rsqrtps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(rsqrtps, Rsqrtps, Xmm, Mem) // SSE + ASMJIT_INST_2x(rsqrtss, Rsqrtss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(rsqrtss, Rsqrtss, Xmm, Mem) // SSE + ASMJIT_INST_3x(shufpd, Shufpd, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(shufpd, Shufpd, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(shufps, Shufps, Xmm, Xmm, Imm) // SSE + ASMJIT_INST_3x(shufps, Shufps, Xmm, Mem, Imm) // SSE + ASMJIT_INST_2x(sqrtpd, Sqrtpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(sqrtpd, Sqrtpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(sqrtps, Sqrtps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(sqrtps, Sqrtps, Xmm, Mem) // SSE + ASMJIT_INST_2x(sqrtsd, Sqrtsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(sqrtsd, Sqrtsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(sqrtss, Sqrtss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(sqrtss, Sqrtss, Xmm, Mem) // SSE + ASMJIT_INST_2x(subpd, Subpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(subpd, Subpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(subps, Subps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(subps, Subps, Xmm, Mem) // SSE + ASMJIT_INST_2x(subsd, Subsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(subsd, Subsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(subss, Subss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(subss, Subss, Xmm, Mem) // SSE + ASMJIT_INST_2x(ucomisd, Ucomisd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(ucomisd, Ucomisd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(ucomiss, Ucomiss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(ucomiss, Ucomiss, Xmm, Mem) // SSE + ASMJIT_INST_2x(unpckhpd, Unpckhpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(unpckhpd, Unpckhpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(unpckhps, Unpckhps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(unpckhps, Unpckhps, Xmm, Mem) // SSE + ASMJIT_INST_2x(unpcklpd, Unpcklpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(unpcklpd, Unpcklpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(unpcklps, Unpcklps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(unpcklps, Unpcklps, Xmm, Mem) // SSE + ASMJIT_INST_2x(xorpd, Xorpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(xorpd, 
Xorpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(xorps, Xorps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(xorps, Xorps, Xmm, Mem) // SSE + + //! \} + + //! \name 3DNOW and GEODE Instructions (Deprecated) + //! \{ + + ASMJIT_INST_2x(pavgusb, Pavgusb, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pavgusb, Pavgusb, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pf2id, Pf2id, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pf2id, Pf2id, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pf2iw, Pf2iw, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pf2iw, Pf2iw, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfacc, Pfacc, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfacc, Pfacc, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfadd, Pfadd, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfadd, Pfadd, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfcmpeq, Pfcmpeq, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfcmpeq, Pfcmpeq, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfcmpge, Pfcmpge, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfcmpge, Pfcmpge, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfcmpgt, Pfcmpgt, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfcmpgt, Pfcmpgt, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfmax, Pfmax, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfmax, Pfmax, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfmin, Pfmin, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfmin, Pfmin, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfmul, Pfmul, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfmul, Pfmul, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfnacc, Pfnacc, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfnacc, Pfnacc, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfpnacc, Pfpnacc, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfpnacc, Pfpnacc, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrcp, Pfrcp, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfrcp, Pfrcp, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrcpit1, Pfrcpit1, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfrcpit1, Pfrcpit1, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrcpit2, Pfrcpit2, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfrcpit2, Pfrcpit2, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrcpv, Pfrcpv, Mm, Mm) // GEODE + ASMJIT_INST_2x(pfrcpv, Pfrcpv, Mm, Mem) // GEODE + ASMJIT_INST_2x(pfrsqit1, Pfrsqit1, Mm, Mm) // 3DNOW + 
ASMJIT_INST_2x(pfrsqit1, Pfrsqit1, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrsqrt, Pfrsqrt, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfrsqrt, Pfrsqrt, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrsqrtv, Pfrsqrtv, Mm, Mm) // GEODE + ASMJIT_INST_2x(pfrsqrtv, Pfrsqrtv, Mm, Mem) // GEODE + ASMJIT_INST_2x(pfsub, Pfsub, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfsub, Pfsub, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfsubr, Pfsubr, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfsubr, Pfsubr, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pi2fd, Pi2fd, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pi2fd, Pi2fd, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pi2fw, Pi2fw, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pi2fw, Pi2fw, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pmulhrw, Pmulhrw, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pmulhrw, Pmulhrw, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pswapd, Pswapd, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pswapd, Pswapd, Mm, Mem) // 3DNOW + + //! \} + + //! \name EMMS/FEMMS Instructions + //! \{ + + ASMJIT_INST_0x(emms, Emms) // MMX + ASMJIT_INST_0x(femms, Femms) // 3DNOW + + //! \} + + //! \name AESNI Instructions + //! \{ + + ASMJIT_INST_2x(aesdec, Aesdec, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesdec, Aesdec, Xmm, Mem) // AESNI + ASMJIT_INST_2x(aesdeclast, Aesdeclast, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesdeclast, Aesdeclast, Xmm, Mem) // AESNI + ASMJIT_INST_2x(aesenc, Aesenc, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesenc, Aesenc, Xmm, Mem) // AESNI + ASMJIT_INST_2x(aesenclast, Aesenclast, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesenclast, Aesenclast, Xmm, Mem) // AESNI + ASMJIT_INST_2x(aesimc, Aesimc, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesimc, Aesimc, Xmm, Mem) // AESNI + ASMJIT_INST_3x(aeskeygenassist, Aeskeygenassist, Xmm, Xmm, Imm) // AESNI + ASMJIT_INST_3x(aeskeygenassist, Aeskeygenassist, Xmm, Mem, Imm) // AESNI + + //! \} + + //! \name SHA Instructions + //! 
\{ + + ASMJIT_INST_2x(sha1msg1, Sha1msg1, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha1msg1, Sha1msg1, Xmm, Mem) // SHA + ASMJIT_INST_2x(sha1msg2, Sha1msg2, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha1msg2, Sha1msg2, Xmm, Mem) // SHA + ASMJIT_INST_2x(sha1nexte, Sha1nexte, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha1nexte, Sha1nexte, Xmm, Mem) // SHA + ASMJIT_INST_3x(sha1rnds4, Sha1rnds4, Xmm, Xmm, Imm) // SHA + ASMJIT_INST_3x(sha1rnds4, Sha1rnds4, Xmm, Mem, Imm) // SHA + ASMJIT_INST_2x(sha256msg1, Sha256msg1, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha256msg1, Sha256msg1, Xmm, Mem) // SHA + ASMJIT_INST_2x(sha256msg2, Sha256msg2, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha256msg2, Sha256msg2, Xmm, Mem) // SHA + ASMJIT_INST_3x(sha256rnds2, Sha256rnds2, Xmm, Xmm, XMM0) // SHA [EXPLICIT] + ASMJIT_INST_3x(sha256rnds2, Sha256rnds2, Xmm, Mem, XMM0) // SHA [EXPLICIT] + + //! \} + + //! \name GFNI Instructions + //! \{ + + ASMJIT_INST_3x(gf2p8affineinvqb, Gf2p8affineinvqb, Xmm, Xmm, Imm) // GFNI + ASMJIT_INST_3x(gf2p8affineinvqb, Gf2p8affineinvqb, Xmm, Mem, Imm) // GFNI + ASMJIT_INST_3x(gf2p8affineqb, Gf2p8affineqb, Xmm, Xmm, Imm) // GFNI + ASMJIT_INST_3x(gf2p8affineqb, Gf2p8affineqb, Xmm, Mem, Imm) // GFNI + ASMJIT_INST_2x(gf2p8mulb, Gf2p8mulb, Xmm, Xmm) // GFNI + ASMJIT_INST_2x(gf2p8mulb, Gf2p8mulb, Xmm, Mem) // GFNI + + //! \} + + //! \name AVX, FMA, and AVX512 Instructions + //! 
\{ + + ASMJIT_INST_3x(kaddb, Kaddb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kaddd, Kaddd, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kaddq, Kaddq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kaddw, Kaddw, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kandb, Kandb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kandd, Kandd, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kandnb, Kandnb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kandnd, Kandnd, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kandnq, Kandnq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kandnw, Kandnw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(kandq, Kandq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kandw, Kandw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_2x(kmovb, Kmovb, KReg, KReg) // AVX512_DQ + ASMJIT_INST_2x(kmovb, Kmovb, KReg, Mem) // AVX512_DQ + ASMJIT_INST_2x(kmovb, Kmovb, KReg, Gp) // AVX512_DQ + ASMJIT_INST_2x(kmovb, Kmovb, Mem, KReg) // AVX512_DQ + ASMJIT_INST_2x(kmovb, Kmovb, Gp, KReg) // AVX512_DQ + ASMJIT_INST_2x(kmovd, Kmovd, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovd, Kmovd, KReg, Mem) // AVX512_BW + ASMJIT_INST_2x(kmovd, Kmovd, KReg, Gp) // AVX512_BW + ASMJIT_INST_2x(kmovd, Kmovd, Mem, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovd, Kmovd, Gp, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, KReg, Mem) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, KReg, Gp) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, Mem, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, Gp, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovw, Kmovw, KReg, KReg) // AVX512_F + ASMJIT_INST_2x(kmovw, Kmovw, KReg, Mem) // AVX512_F + ASMJIT_INST_2x(kmovw, Kmovw, KReg, Gp) // AVX512_F + ASMJIT_INST_2x(kmovw, Kmovw, Mem, KReg) // AVX512_F + ASMJIT_INST_2x(kmovw, Kmovw, Gp, KReg) // AVX512_F + ASMJIT_INST_2x(knotb, Knotb, KReg, KReg) // AVX512_DQ + ASMJIT_INST_2x(knotd, Knotd, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(knotq, Knotq, KReg, 
KReg) // AVX512_BW + ASMJIT_INST_2x(knotw, Knotw, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(korb, Korb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kord, Kord, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(korq, Korq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kortestb, Kortestb, KReg, KReg) // AVX512_DQ + ASMJIT_INST_2x(kortestd, Kortestd, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kortestq, Kortestq, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kortestw, Kortestw, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(korw, Korw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(kshiftlb, Kshiftlb, KReg, KReg, Imm) // AVX512_DQ + ASMJIT_INST_3x(kshiftld, Kshiftld, KReg, KReg, Imm) // AVX512_BW + ASMJIT_INST_3x(kshiftlq, Kshiftlq, KReg, KReg, Imm) // AVX512_BW + ASMJIT_INST_3x(kshiftlw, Kshiftlw, KReg, KReg, Imm) // AVX512_F + ASMJIT_INST_3x(kshiftrb, Kshiftrb, KReg, KReg, Imm) // AVX512_DQ + ASMJIT_INST_3x(kshiftrd, Kshiftrd, KReg, KReg, Imm) // AVX512_BW + ASMJIT_INST_3x(kshiftrq, Kshiftrq, KReg, KReg, Imm) // AVX512_BW + ASMJIT_INST_3x(kshiftrw, Kshiftrw, KReg, KReg, Imm) // AVX512_F + ASMJIT_INST_2x(ktestb, Ktestb, KReg, KReg) // AVX512_DQ + ASMJIT_INST_2x(ktestd, Ktestd, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(ktestq, Ktestq, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(ktestw, Ktestw, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kunpckbw, Kunpckbw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(kunpckdq, Kunpckdq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kunpckwd, Kunpckwd, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxnorb, Kxnorb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kxnord, Kxnord, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxnorq, Kxnorq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxnorw, Kxnorw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(kxorb, Kxorb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kxord, Kxord, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxorq, Kxorq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxorw, 
Kxorw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_6x(v4fmaddps, V4fmaddps, Zmm, Zmm, Zmm, Zmm, Zmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_6x(v4fmaddss, V4fmaddss, Xmm, Xmm, Xmm, Xmm, Xmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_6x(v4fnmaddps, V4fnmaddps, Zmm, Zmm, Zmm, Zmm, Zmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_6x(v4fnmaddss, V4fnmaddss, Xmm, Xmm, Xmm, Xmm, Xmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_3x(vaddpd, Vaddpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vaddpd, Vaddpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vaddps, Vaddps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vaddps, Vaddps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vaddsd, Vaddsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vaddsd, Vaddsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vaddss, Vaddss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vaddss, Vaddss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vaddsubpd, Vaddsubpd, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vaddsubpd, Vaddsubpd, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vaddsubps, Vaddsubps, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vaddsubps, Vaddsubps, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vaesdec, Vaesdec, Vec, Vec, Vec) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesdec, Vaesdec, Vec, Vec, Mem) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesdeclast, Vaesdeclast, Vec, Vec, Vec) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesdeclast, Vaesdeclast, Vec, Vec, Mem) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesenc, Vaesenc, Vec, Vec, Vec) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesenc, Vaesenc, Vec, Vec, Mem) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesenclast, Vaesenclast, Vec, Vec, Vec) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesenclast, Vaesenclast, Vec, Vec, Mem) // AVX+AESNI VAES + ASMJIT_INST_2x(vaesimc, Vaesimc, Xmm, Xmm) // AVX+AESNI + ASMJIT_INST_2x(vaesimc, Vaesimc, Xmm, Mem) // AVX+AESNI + ASMJIT_INST_3x(vaeskeygenassist, Vaeskeygenassist, Xmm, Xmm, Imm) // AVX+AESNI + 
ASMJIT_INST_3x(vaeskeygenassist, Vaeskeygenassist, Xmm, Mem, Imm) // AVX+AESNI + ASMJIT_INST_4x(valignd, Valignd, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(valignd, Valignd, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(valignq, Valignq, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(valignq, Valignq, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vandnpd, Vandnpd, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vandnpd, Vandnpd, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vandnps, Vandnps, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vandnps, Vandnps, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vandpd, Vandpd, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vandpd, Vandpd, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vandps, Vandps, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vandps, Vandps, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vblendmpd, Vblendmpd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vblendmpd, Vblendmpd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vblendmps, Vblendmps, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vblendmps, Vblendmps, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vblendpd, Vblendpd, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vblendpd, Vblendpd, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vblendps, Vblendps, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vblendps, Vblendps, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vblendvpd, Vblendvpd, Vec, Vec, Vec, Vec) // AVX + ASMJIT_INST_4x(vblendvpd, Vblendvpd, Vec, Vec, Mem, Vec) // AVX + ASMJIT_INST_4x(vblendvps, Vblendvps, Vec, Vec, Vec, Vec) // AVX + ASMJIT_INST_4x(vblendvps, Vblendvps, Vec, Vec, Mem, Vec) // AVX + ASMJIT_INST_2x(vbroadcastf128, Vbroadcastf128, Vec, Mem) // AVX + ASMJIT_INST_2x(vbroadcastf32x2, Vbroadcastf32x2, Vec, Vec) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcastf32x2, Vbroadcastf32x2, Vec, Mem) // 
AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcastf32x4, Vbroadcastf32x4, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastf32x8, Vbroadcastf32x8, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcastf64x2, Vbroadcastf64x2, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcastf64x4, Vbroadcastf64x4, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcasti128, Vbroadcasti128, Vec, Mem) // AVX2 + ASMJIT_INST_2x(vbroadcasti32x2, Vbroadcasti32x2, Vec, Vec) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti32x2, Vbroadcasti32x2, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti32x4, Vbroadcasti32x4, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcasti32x8, Vbroadcasti32x8, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti64x2, Vbroadcasti64x2, Vec, Vec) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti64x2, Vbroadcasti64x2, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti64x4, Vbroadcasti64x4, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcasti64x4, Vbroadcasti64x4, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastsd, Vbroadcastsd, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastsd, Vbroadcastsd, Vec, Xmm) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastss, Vbroadcastss, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastss, Vbroadcastss, Vec, Xmm) // AVX2 AVX512_F{kz} + ASMJIT_INST_4x(vcmppd, Vcmppd, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vcmppd, Vcmppd, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vcmppd, Vcmppd, KReg, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vcmppd, Vcmppd, KReg, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vcmpps, Vcmpps, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vcmpps, Vcmpps, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vcmpps, Vcmpps, KReg, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vcmpps, Vcmpps, KReg, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vcmpsd, Vcmpsd, Xmm, Xmm, Xmm, Imm) // AVX + ASMJIT_INST_4x(vcmpsd, Vcmpsd, Xmm, Xmm, Mem, Imm) // AVX + ASMJIT_INST_4x(vcmpsd, 
Vcmpsd, KReg, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vcmpsd, Vcmpsd, KReg, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vcmpss, Vcmpss, Xmm, Xmm, Xmm, Imm) // AVX + ASMJIT_INST_4x(vcmpss, Vcmpss, Xmm, Xmm, Mem, Imm) // AVX + ASMJIT_INST_4x(vcmpss, Vcmpss, KReg, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vcmpss, Vcmpss, KReg, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_2x(vcomisd, Vcomisd, Xmm, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcomisd, Vcomisd, Xmm, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcomiss, Vcomiss, Xmm, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcomiss, Vcomiss, Xmm, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcompresspd, Vcompresspd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vcompresspd, Vcompresspd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vcompressps, Vcompressps, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vcompressps, Vcompressps, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vcvtdq2pd, Vcvtdq2pd, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtdq2pd, Vcvtdq2pd, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtdq2ps, Vcvtdq2ps, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtdq2ps, Vcvtdq2ps, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vcvtne2ps2bf16, Vcvtne2ps2bf16, Vec, Vec, Vec) // AVX512_BF16{kz|b32} + ASMJIT_INST_3x(vcvtne2ps2bf16, Vcvtne2ps2bf16, Vec, Vec, Mem) // AVX512_BF16{kz|b32} + ASMJIT_INST_2x(vcvtneps2bf16, Vcvtneps2bf16, Vec, Vec) // AVX512_BF16{kz|b32} + ASMJIT_INST_2x(vcvtneps2bf16, Vcvtneps2bf16, Vec, Mem) // AVX512_BF16{kz|b32} + ASMJIT_INST_2x(vcvtpd2dq, Vcvtpd2dq, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2dq, Vcvtpd2dq, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2ps, Vcvtpd2ps, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2ps, Vcvtpd2ps, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2qq, Vcvtpd2qq, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtpd2qq, Vcvtpd2qq, Vec, Mem) // AVX512_DQ{kz|b64} + 
ASMJIT_INST_2x(vcvtpd2udq, Vcvtpd2udq, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2udq, Vcvtpd2udq, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2uqq, Vcvtpd2uqq, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtpd2uqq, Vcvtpd2uqq, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtph2ps, Vcvtph2ps, Vec, Vec) // F16C AVX512_F{kz} + ASMJIT_INST_2x(vcvtph2ps, Vcvtph2ps, Vec, Mem) // F16C AVX512_F{kz} + ASMJIT_INST_2x(vcvtps2dq, Vcvtps2dq, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2dq, Vcvtps2dq, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2pd, Vcvtps2pd, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2pd, Vcvtps2pd, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vcvtps2ph, Vcvtps2ph, Vec, Vec, Imm) // F16C AVX512_F{kz} + ASMJIT_INST_3x(vcvtps2ph, Vcvtps2ph, Mem, Vec, Imm) // F16C AVX512_F{kz} + ASMJIT_INST_2x(vcvtps2qq, Vcvtps2qq, Vec, Vec) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvtps2qq, Vcvtps2qq, Vec, Mem) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvtps2udq, Vcvtps2udq, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2udq, Vcvtps2udq, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2uqq, Vcvtps2uqq, Vec, Vec) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvtps2uqq, Vcvtps2uqq, Vec, Mem) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvtqq2pd, Vcvtqq2pd, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtqq2pd, Vcvtqq2pd, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtqq2ps, Vcvtqq2ps, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtqq2ps, Vcvtqq2ps, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtsd2si, Vcvtsd2si, Gp, Xmm) // AVX AVX512_F{er} + ASMJIT_INST_2x(vcvtsd2si, Vcvtsd2si, Gp, Mem) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtsd2ss, Vcvtsd2ss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vcvtsd2ss, Vcvtsd2ss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_2x(vcvtsd2usi, Vcvtsd2usi, Gp, Xmm) // AVX512_F{er} + ASMJIT_INST_2x(vcvtsd2usi, Vcvtsd2usi, Gp, Mem) // AVX512_F{er} + 
ASMJIT_INST_3x(vcvtsi2sd, Vcvtsi2sd, Xmm, Xmm, Gp) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtsi2sd, Vcvtsi2sd, Xmm, Xmm, Mem) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtsi2ss, Vcvtsi2ss, Xmm, Xmm, Gp) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtsi2ss, Vcvtsi2ss, Xmm, Xmm, Mem) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtss2sd, Vcvtss2sd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vcvtss2sd, Vcvtss2sd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} + ASMJIT_INST_2x(vcvtss2si, Vcvtss2si, Gp, Xmm) // AVX AVX512_F{er} + ASMJIT_INST_2x(vcvtss2si, Vcvtss2si, Gp, Mem) // AVX AVX512_F{er} + ASMJIT_INST_2x(vcvtss2usi, Vcvtss2usi, Gp, Xmm) // AVX512_F{er} + ASMJIT_INST_2x(vcvtss2usi, Vcvtss2usi, Gp, Mem) // AVX512_F{er} + ASMJIT_INST_2x(vcvttpd2dq, Vcvttpd2dq, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2dq, Vcvttpd2dq, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2qq, Vcvttpd2qq, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2qq, Vcvttpd2qq, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2udq, Vcvttpd2udq, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2udq, Vcvttpd2udq, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2uqq, Vcvttpd2uqq, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvttpd2uqq, Vcvttpd2uqq, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvttps2dq, Vcvttps2dq, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvttps2dq, Vcvttps2dq, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvttps2qq, Vcvttps2qq, Vec, Vec) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvttps2qq, Vcvttps2qq, Vec, Mem) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvttps2udq, Vcvttps2udq, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvttps2udq, Vcvttps2udq, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvttps2uqq, Vcvttps2uqq, Vec, Vec) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvttps2uqq, Vcvttps2uqq, Vec, Mem) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvttsd2si, Vcvttsd2si, Gp, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcvttsd2si, 
Vcvttsd2si, Gp, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcvttsd2usi, Vcvttsd2usi, Gp, Xmm) // AVX512_F{sae} + ASMJIT_INST_2x(vcvttsd2usi, Vcvttsd2usi, Gp, Mem) // AVX512_F{sae} + ASMJIT_INST_2x(vcvttss2si, Vcvttss2si, Gp, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcvttss2si, Vcvttss2si, Gp, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcvttss2usi, Vcvttss2usi, Gp, Xmm) // AVX512_F{sae} + ASMJIT_INST_2x(vcvttss2usi, Vcvttss2usi, Gp, Mem) // AVX512_F{sae} + ASMJIT_INST_2x(vcvtudq2pd, Vcvtudq2pd, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtudq2pd, Vcvtudq2pd, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtudq2ps, Vcvtudq2ps, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtudq2ps, Vcvtudq2ps, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtuqq2pd, Vcvtuqq2pd, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtuqq2pd, Vcvtuqq2pd, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtuqq2ps, Vcvtuqq2ps, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtuqq2ps, Vcvtuqq2ps, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vcvtusi2sd, Vcvtusi2sd, Xmm, Xmm, Gp) // AVX512_F{er} + ASMJIT_INST_3x(vcvtusi2sd, Vcvtusi2sd, Xmm, Xmm, Mem) // AVX512_F{er} + ASMJIT_INST_3x(vcvtusi2ss, Vcvtusi2ss, Xmm, Xmm, Gp) // AVX512_F{er} + ASMJIT_INST_3x(vcvtusi2ss, Vcvtusi2ss, Xmm, Xmm, Mem) // AVX512_F{er} + ASMJIT_INST_4x(vdbpsadbw, Vdbpsadbw, Vec, Vec, Vec, Imm) // AVX512_BW{kz} + ASMJIT_INST_4x(vdbpsadbw, Vdbpsadbw, Vec, Vec, Mem, Imm) // AVX512_BW{kz} + ASMJIT_INST_3x(vdivpd, Vdivpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vdivpd, Vdivpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vdivps, Vdivps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vdivps, Vdivps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vdivsd, Vdivsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vdivsd, Vdivsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vdivss, Vdivss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vdivss, 
Vdivss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vdpbf16ps, Vdpbf16ps, Vec, Vec, Vec) // AVX512_BF16{kz|b32} + ASMJIT_INST_3x(vdpbf16ps, Vdpbf16ps, Vec, Vec, Mem) // AVX512_BF16{kz|b32} + ASMJIT_INST_4x(vdppd, Vdppd, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vdppd, Vdppd, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vdpps, Vdpps, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vdpps, Vdpps, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_2x(vexp2pd, Vexp2pd, Vec, Vec) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vexp2pd, Vexp2pd, Vec, Mem) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vexp2ps, Vexp2ps, Vec, Vec) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_2x(vexp2ps, Vexp2ps, Vec, Mem) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_2x(vexpandpd, Vexpandpd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vexpandpd, Vexpandpd, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vexpandps, Vexpandps, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vexpandps, Vexpandps, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_3x(vextractf128, Vextractf128, Vec, Vec, Imm) // AVX + ASMJIT_INST_3x(vextractf128, Vextractf128, Mem, Vec, Imm) // AVX + ASMJIT_INST_3x(vextractf32x4, Vextractf32x4, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextractf32x4, Vextractf32x4, Mem, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextractf32x8, Vextractf32x8, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextractf32x8, Vextractf32x8, Mem, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextractf64x2, Vextractf64x2, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextractf64x2, Vextractf64x2, Mem, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextractf64x4, Vextractf64x4, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextractf64x4, Vextractf64x4, Mem, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextracti128, Vextracti128, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_3x(vextracti128, Vextracti128, Mem, Vec, Imm) // AVX2 + ASMJIT_INST_3x(vextracti32x4, Vextracti32x4, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextracti32x4, Vextracti32x4, Mem, 
Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextracti32x8, Vextracti32x8, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextracti32x8, Vextracti32x8, Mem, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextracti64x2, Vextracti64x2, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextracti64x2, Vextracti64x2, Mem, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextracti64x4, Vextracti64x4, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextracti64x4, Vextracti64x4, Mem, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextractps, Vextractps, Gp, Xmm, Imm) // AVX AVX512_F + ASMJIT_INST_3x(vextractps, Vextractps, Mem, Xmm, Imm) // AVX AVX512_F + ASMJIT_INST_4x(vfixupimmpd, Vfixupimmpd, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vfixupimmpd, Vfixupimmpd, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vfixupimmps, Vfixupimmps, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vfixupimmps, Vfixupimmps, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vfixupimmsd, Vfixupimmsd, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vfixupimmsd, Vfixupimmsd, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vfixupimmss, Vfixupimmss, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vfixupimmss, Vfixupimmss, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vfmadd132pd, Vfmadd132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd132pd, Vfmadd132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd132ps, Vfmadd132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd132ps, Vfmadd132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd132sd, Vfmadd132sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd132sd, Vfmadd132sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd132ss, Vfmadd132ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd132ss, Vfmadd132ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd213pd, Vfmadd213pd, 
Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd213pd, Vfmadd213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd213ps, Vfmadd213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd213ps, Vfmadd213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd213sd, Vfmadd213sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd213sd, Vfmadd213sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd213ss, Vfmadd213ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd213ss, Vfmadd213ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd231pd, Vfmadd231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd231pd, Vfmadd231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd231ps, Vfmadd231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd231ps, Vfmadd231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd231sd, Vfmadd231sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd231sd, Vfmadd231sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd231ss, Vfmadd231ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd231ss, Vfmadd231ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmaddsub132pd, Vfmaddsub132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub132pd, Vfmaddsub132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub132ps, Vfmaddsub132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub132ps, Vfmaddsub132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub213pd, Vfmaddsub213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub213pd, Vfmaddsub213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub213ps, Vfmaddsub213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub213ps, Vfmaddsub213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub231pd, 
Vfmaddsub231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub231pd, Vfmaddsub231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub231ps, Vfmaddsub231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub231ps, Vfmaddsub231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub132pd, Vfmsub132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub132pd, Vfmsub132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub132ps, Vfmsub132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub132ps, Vfmsub132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub132sd, Vfmsub132sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub132sd, Vfmsub132sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub132ss, Vfmsub132ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub132ss, Vfmsub132ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub213pd, Vfmsub213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub213pd, Vfmsub213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub213ps, Vfmsub213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub213ps, Vfmsub213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub213sd, Vfmsub213sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub213sd, Vfmsub213sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub213ss, Vfmsub213ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub213ss, Vfmsub213ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub231pd, Vfmsub231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub231pd, Vfmsub231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub231ps, Vfmsub231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub231ps, Vfmsub231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub231sd, Vfmsub231sd, Xmm, 
Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub231sd, Vfmsub231sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub231ss, Vfmsub231ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub231ss, Vfmsub231ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsubadd132pd, Vfmsubadd132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd132pd, Vfmsubadd132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd132ps, Vfmsubadd132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd132ps, Vfmsubadd132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd213pd, Vfmsubadd213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd213pd, Vfmsubadd213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd213ps, Vfmsubadd213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd213ps, Vfmsubadd213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd231pd, Vfmsubadd231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd231pd, Vfmsubadd231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd231ps, Vfmsubadd231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd231ps, Vfmsubadd231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd132pd, Vfnmadd132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd132pd, Vfnmadd132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd132ps, Vfnmadd132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd132ps, Vfnmadd132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd132sd, Vfnmadd132sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd132sd, Vfnmadd132sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd132ss, Vfnmadd132ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd132ss, Vfnmadd132ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} 
+ ASMJIT_INST_3x(vfnmadd213pd, Vfnmadd213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd213pd, Vfnmadd213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd213ps, Vfnmadd213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd213ps, Vfnmadd213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd213sd, Vfnmadd213sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd213sd, Vfnmadd213sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd213ss, Vfnmadd213ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd213ss, Vfnmadd213ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd231pd, Vfnmadd231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd231pd, Vfnmadd231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd231ps, Vfnmadd231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd231ps, Vfnmadd231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd231sd, Vfnmadd231sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd231sd, Vfnmadd231sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd231ss, Vfnmadd231ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd231ss, Vfnmadd231ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub132pd, Vfnmsub132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub132pd, Vfnmsub132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub132ps, Vfnmsub132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub132ps, Vfnmsub132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub132sd, Vfnmsub132sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub132sd, Vfnmsub132sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub132ss, Vfnmsub132ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub132ss, Vfnmsub132ss, Xmm, Xmm, Mem) // FMA 
AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub213pd, Vfnmsub213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub213pd, Vfnmsub213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub213ps, Vfnmsub213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub213ps, Vfnmsub213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub213sd, Vfnmsub213sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub213sd, Vfnmsub213sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub213ss, Vfnmsub213ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub213ss, Vfnmsub213ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub231pd, Vfnmsub231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub231pd, Vfnmsub231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub231ps, Vfnmsub231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub231ps, Vfnmsub231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub231sd, Vfnmsub231sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub231sd, Vfnmsub231sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub231ss, Vfnmsub231ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub231ss, Vfnmsub231ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfpclasspd, Vfpclasspd, KReg, Vec, Imm) // AVX512_DQ{k|b64} + ASMJIT_INST_3x(vfpclasspd, Vfpclasspd, KReg, Mem, Imm) // AVX512_DQ{k|b64} + ASMJIT_INST_3x(vfpclassps, Vfpclassps, KReg, Vec, Imm) // AVX512_DQ{k|b32} + ASMJIT_INST_3x(vfpclassps, Vfpclassps, KReg, Mem, Imm) // AVX512_DQ{k|b32} + ASMJIT_INST_3x(vfpclasssd, Vfpclasssd, KReg, Xmm, Imm) // AVX512_DQ{k} + ASMJIT_INST_3x(vfpclasssd, Vfpclasssd, KReg, Mem, Imm) // AVX512_DQ{k} + ASMJIT_INST_3x(vfpclassss, Vfpclassss, KReg, Xmm, Imm) // AVX512_DQ{k} + ASMJIT_INST_3x(vfpclassss, Vfpclassss, KReg, Mem, Imm) // AVX512_DQ{k} + ASMJIT_INST_2x(vgatherdpd, Vgatherdpd, Vec, 
Mem) // AVX512_F{k} + ASMJIT_INST_3x(vgatherdpd, Vgatherdpd, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vgatherdps, Vgatherdps, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vgatherdps, Vgatherdps, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_1x(vgatherpf0dpd, Vgatherpf0dpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf0dps, Vgatherpf0dps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf0qpd, Vgatherpf0qpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf0qps, Vgatherpf0qps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf1dpd, Vgatherpf1dpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf1dps, Vgatherpf1dps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf1qpd, Vgatherpf1qpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf1qps, Vgatherpf1qps, Mem) // AVX512_PF{k} + ASMJIT_INST_2x(vgatherqpd, Vgatherqpd, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vgatherqpd, Vgatherqpd, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vgatherqps, Vgatherqps, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vgatherqps, Vgatherqps, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vgetexppd, Vgetexppd, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vgetexppd, Vgetexppd, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vgetexpps, Vgetexpps, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vgetexpps, Vgetexpps, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vgetexpsd, Vgetexpsd, Xmm, Xmm, Xmm) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vgetexpsd, Vgetexpsd, Xmm, Xmm, Mem) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vgetexpss, Vgetexpss, Xmm, Xmm, Xmm) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vgetexpss, Vgetexpss, Xmm, Xmm, Mem) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vgetmantpd, Vgetmantpd, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vgetmantpd, Vgetmantpd, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vgetmantps, Vgetmantps, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vgetmantps, Vgetmantps, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vgetmantsd, Vgetmantsd, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + 
ASMJIT_INST_4x(vgetmantsd, Vgetmantsd, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vgetmantss, Vgetmantss, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vgetmantss, Vgetmantss, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vgf2p8affineinvqb, Vgf2p8affineinvqb,Vec,Vec,Vec,Imm) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_4x(vgf2p8affineinvqb, Vgf2p8affineinvqb,Vec,Vec,Mem,Imm) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_4x(vgf2p8affineqb, Vgf2p8affineqb, Vec, Vec, Vec, Imm) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_4x(vgf2p8affineqb, Vgf2p8affineqb, Vec, Vec, Mem, Imm) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_3x(vgf2p8mulb, Vgf2p8mulb, Vec, Vec, Vec) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_3x(vgf2p8mulb, Vgf2p8mulb, Vec, Vec, Mem) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_3x(vhaddpd, Vhaddpd, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vhaddpd, Vhaddpd, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vhaddps, Vhaddps, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vhaddps, Vhaddps, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vhsubpd, Vhsubpd, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vhsubpd, Vhsubpd, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vhsubps, Vhsubps, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vhsubps, Vhsubps, Vec, Vec, Mem) // AVX + ASMJIT_INST_4x(vinsertf128, Vinsertf128, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vinsertf128, Vinsertf128, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vinsertf32x4, Vinsertf32x4, Vec, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinsertf32x4, Vinsertf32x4, Vec, Vec, Mem, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinsertf32x8, Vinsertf32x8, Vec, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinsertf32x8, Vinsertf32x8, Vec, Vec, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinsertf64x2, Vinsertf64x2, Vec, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinsertf64x2, Vinsertf64x2, Vec, Vec, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinsertf64x4, Vinsertf64x4, Vec, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinsertf64x4, 
Vinsertf64x4, Vec, Vec, Mem, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinserti128, Vinserti128, Vec, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_4x(vinserti128, Vinserti128, Vec, Vec, Mem, Imm) // AVX2 + ASMJIT_INST_4x(vinserti32x4, Vinserti32x4, Vec, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinserti32x4, Vinserti32x4, Vec, Vec, Mem, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinserti32x8, Vinserti32x8, Vec, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinserti32x8, Vinserti32x8, Vec, Vec, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinserti64x2, Vinserti64x2, Vec, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinserti64x2, Vinserti64x2, Vec, Vec, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinserti64x4, Vinserti64x4, Vec, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinserti64x4, Vinserti64x4, Vec, Vec, Mem, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinsertps, Vinsertps, Xmm, Xmm, Xmm, Imm) // AVX AVX512_F + ASMJIT_INST_4x(vinsertps, Vinsertps, Xmm, Xmm, Mem, Imm) // AVX AVX512_F + ASMJIT_INST_2x(vlddqu, Vlddqu, Vec, Mem) // AVX + ASMJIT_INST_1x(vldmxcsr, Vldmxcsr, Mem) // AVX + ASMJIT_INST_3x(vmaskmovdqu, Vmaskmovdqu, Vec, Vec, DS_ZDI) // AVX [EXPLICIT] + ASMJIT_INST_3x(vmaskmovpd, Vmaskmovpd, Mem, Vec, Vec) // AVX + ASMJIT_INST_3x(vmaskmovpd, Vmaskmovpd, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vmaskmovps, Vmaskmovps, Mem, Vec, Vec) // AVX + ASMJIT_INST_3x(vmaskmovps, Vmaskmovps, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vmaxpd, Vmaxpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vmaxpd, Vmaxpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vmaxps, Vmaxps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vmaxps, Vmaxps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vmaxsd, Vmaxsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vmaxsd, Vmaxsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vmaxss, Vmaxss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vmaxss, Vmaxss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} 
+ ASMJIT_INST_3x(vminpd, Vminpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vminpd, Vminpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vminps, Vminps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vminps, Vminps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vminsd, Vminsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vminsd, Vminsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vminss, Vminss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vminss, Vminss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} + ASMJIT_INST_2x(vmovapd, Vmovapd, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovapd, Vmovapd, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovapd, Vmovapd, Mem, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovaps, Vmovaps, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovaps, Vmovaps, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovaps, Vmovaps, Mem, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovd, Vmovd, Gp, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovd, Vmovd, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovd, Vmovd, Xmm, Gp) // AVX AVX512_F + ASMJIT_INST_2x(vmovd, Vmovd, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_2x(vmovddup, Vmovddup, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovddup, Vmovddup, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa, Vmovdqa, Vec, Vec) // AVX + ASMJIT_INST_2x(vmovdqa, Vmovdqa, Vec, Mem) // AVX + ASMJIT_INST_2x(vmovdqa, Vmovdqa, Mem, Vec) // AVX + ASMJIT_INST_2x(vmovdqa32, Vmovdqa32, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa32, Vmovdqa32, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa32, Vmovdqa32, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa64, Vmovdqa64, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa64, Vmovdqa64, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa64, Vmovdqa64, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu, Vmovdqu, Vec, Vec) // AVX + ASMJIT_INST_2x(vmovdqu, Vmovdqu, Vec, Mem) // AVX + 
ASMJIT_INST_2x(vmovdqu, Vmovdqu, Mem, Vec) // AVX + ASMJIT_INST_2x(vmovdqu16, Vmovdqu16, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu16, Vmovdqu16, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu16, Vmovdqu16, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu32, Vmovdqu32, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu32, Vmovdqu32, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu32, Vmovdqu32, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu64, Vmovdqu64, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu64, Vmovdqu64, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu64, Vmovdqu64, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu8, Vmovdqu8, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu8, Vmovdqu8, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu8, Vmovdqu8, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vmovhlps, Vmovhlps, Xmm, Xmm, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovhpd, Vmovhpd, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_3x(vmovhpd, Vmovhpd, Xmm, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_2x(vmovhps, Vmovhps, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_3x(vmovhps, Vmovhps, Xmm, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_3x(vmovlhps, Vmovlhps, Xmm, Xmm, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovlpd, Vmovlpd, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_3x(vmovlpd, Vmovlpd, Xmm, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_2x(vmovlps, Vmovlps, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_3x(vmovlps, Vmovlps, Xmm, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_2x(vmovmskpd, Vmovmskpd, Gp, Vec) // AVX + ASMJIT_INST_2x(vmovmskps, Vmovmskps, Gp, Vec) // AVX + ASMJIT_INST_2x(vmovntdq, Vmovntdq, Mem, Vec) // AVX+ AVX512_F + ASMJIT_INST_2x(vmovntdqa, Vmovntdqa, Vec, Mem) // AVX+ AVX512_F + ASMJIT_INST_2x(vmovntpd, Vmovntpd, Mem, Vec) // AVX AVX512_F + ASMJIT_INST_2x(vmovntps, Vmovntps, Mem, Vec) // AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Gp, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Xmm, Mem) // 
AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Xmm, Gp) // AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Xmm, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovsd, Vmovsd, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovsd, Vmovsd, Xmm, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_3x(vmovsd, Vmovsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovshdup, Vmovshdup, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovshdup, Vmovshdup, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovsldup, Vmovsldup, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovsldup, Vmovsldup, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovss, Vmovss, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovss, Vmovss, Xmm, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_3x(vmovss, Vmovss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovupd, Vmovupd, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovupd, Vmovupd, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovupd, Vmovupd, Mem, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovups, Vmovups, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovups, Vmovups, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovups, Vmovups, Mem, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_4x(vmpsadbw, Vmpsadbw, Vec, Vec, Vec, Imm) // AVX+ + ASMJIT_INST_4x(vmpsadbw, Vmpsadbw, Vec, Vec, Mem, Imm) // AVX+ + ASMJIT_INST_3x(vmulpd, Vmulpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vmulpd, Vmulpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vmulps, Vmulps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vmulps, Vmulps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vmulsd, Vmulsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vmulsd, Vmulsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vmulss, Vmulss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vmulss, Vmulss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vorpd, Vorpd, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vorpd, Vorpd, Vec, Vec, Mem) // AVX 
AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vorps, Vorps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vorps, Vorps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_4x(vp2intersectd, Vp2intersectd, KReg, KReg, Vec, Vec) // AVX512_VP2INTERSECT{kz} + ASMJIT_INST_4x(vp2intersectd, Vp2intersectd, KReg, KReg, Vec, Mem) // AVX512_VP2INTERSECT{kz} + ASMJIT_INST_4x(vp2intersectq, Vp2intersectq, KReg, KReg, Vec, Vec) // AVX512_VP2INTERSECT{kz} + ASMJIT_INST_4x(vp2intersectq, Vp2intersectq, KReg, KReg, Vec, Mem) // AVX512_VP2INTERSECT{kz} + ASMJIT_INST_6x(vp4dpwssd, Vp4dpwssd, Zmm, Zmm, Zmm, Zmm, Zmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_6x(vp4dpwssds, Vp4dpwssds, Zmm, Zmm, Zmm, Zmm, Zmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_2x(vpabsb, Vpabsb, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpabsb, Vpabsb, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpabsd, Vpabsd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpabsd, Vpabsd, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpabsq, Vpabsq, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpabsq, Vpabsq, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vpabsw, Vpabsw, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpabsw, Vpabsw, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpackssdw, Vpackssdw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz|b32} + ASMJIT_INST_3x(vpackssdw, Vpackssdw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz|b32} + ASMJIT_INST_3x(vpacksswb, Vpacksswb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpacksswb, Vpacksswb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpackusdw, Vpackusdw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz|b32} + ASMJIT_INST_3x(vpackusdw, Vpackusdw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz|b32} + ASMJIT_INST_3x(vpackuswb, Vpackuswb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpackuswb, Vpackuswb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddb, Vpaddb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddb, Vpaddb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + 
ASMJIT_INST_3x(vpaddd, Vpaddd, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpaddd, Vpaddd, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpaddq, Vpaddq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpaddq, Vpaddq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpaddsb, Vpaddsb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddsb, Vpaddsb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddsw, Vpaddsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddsw, Vpaddsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddusb, Vpaddusb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddusb, Vpaddusb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddusw, Vpaddusw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddusw, Vpaddusw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddw, Vpaddw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddw, Vpaddw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_4x(vpalignr, Vpalignr, Vec, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_4x(vpalignr, Vpalignr, Vec, Vec, Mem, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpand, Vpand, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpand, Vpand, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpandd, Vpandd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpandd, Vpandd, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpandn, Vpandn, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpandn, Vpandn, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpandnd, Vpandnd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpandnd, Vpandnd, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpandnq, Vpandnq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpandnq, Vpandnq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpandq, Vpandq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpandq, Vpandq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpavgb, Vpavgb, Vec, Vec, Vec) // AVX+ 
AVX512_BW{kz} + ASMJIT_INST_3x(vpavgb, Vpavgb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpavgw, Vpavgw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpavgw, Vpavgw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_4x(vpblendd, Vpblendd, Vec, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_4x(vpblendd, Vpblendd, Vec, Vec, Mem, Imm) // AVX2 + ASMJIT_INST_3x(vpblendmb, Vpblendmb, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpblendmb, Vpblendmb, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpblendmd, Vpblendmd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpblendmd, Vpblendmd, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpblendmq, Vpblendmq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpblendmq, Vpblendmq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpblendmw, Vpblendmw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpblendmw, Vpblendmw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_4x(vpblendvb, Vpblendvb, Vec, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_4x(vpblendvb, Vpblendvb, Vec, Vec, Mem, Vec) // AVX+ + ASMJIT_INST_4x(vpblendw, Vpblendw, Vec, Vec, Vec, Imm) // AVX+ + ASMJIT_INST_4x(vpblendw, Vpblendw, Vec, Vec, Mem, Imm) // AVX+ + ASMJIT_INST_2x(vpbroadcastb, Vpbroadcastb, Vec, Vec) // AVX2 AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastb, Vpbroadcastb, Vec, Mem) // AVX2 AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastb, Vpbroadcastb, Vec, Gp) // AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastd, Vpbroadcastd, Vec, Vec) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastd, Vpbroadcastd, Vec, Mem) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastd, Vpbroadcastd, Vec, Gp) // AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastmb2q, Vpbroadcastmb2q, Vec, KReg) // AVX512_CD + ASMJIT_INST_2x(vpbroadcastmw2d, Vpbroadcastmw2d, Vec, KReg) // AVX512_CD + ASMJIT_INST_2x(vpbroadcastq, Vpbroadcastq, Vec, Vec) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastq, Vpbroadcastq, Vec, Mem) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastq, Vpbroadcastq, Vec, 
Gp) // AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastw, Vpbroadcastw, Vec, Vec) // AVX2 AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastw, Vpbroadcastw, Vec, Mem) // AVX2 AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastw, Vpbroadcastw, Vec, Gp) // AVX512_BW{kz} + ASMJIT_INST_4x(vpclmulqdq, Vpclmulqdq, Vec, Vec, Vec, Imm) // AVX VPCLMULQDQ AVX512_F + ASMJIT_INST_4x(vpclmulqdq, Vpclmulqdq, Vec, Vec, Mem, Imm) // AVX VPCLMULQDQ AVX512_F + ASMJIT_INST_4x(vpcmpb, Vpcmpb, KReg, Vec, Vec, Imm) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpb, Vpcmpb, KReg, Vec, Mem, Imm) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpd, Vpcmpd, KReg, Vec, Vec, Imm) // AVX512_F{k|b32} + ASMJIT_INST_4x(vpcmpd, Vpcmpd, KReg, Vec, Mem, Imm) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpeqb, Vpcmpeqb, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpeqb, Vpcmpeqb, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpeqb, Vpcmpeqb, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpeqb, Vpcmpeqb, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpeqd, Vpcmpeqd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpeqd, Vpcmpeqd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpeqd, Vpcmpeqd, KReg, Vec, Vec) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpeqd, Vpcmpeqd, KReg, Vec, Mem) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpeqq, Vpcmpeqq, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpeqq, Vpcmpeqq, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpeqq, Vpcmpeqq, KReg, Vec, Vec) // AVX512_F{k|b64} + ASMJIT_INST_3x(vpcmpeqq, Vpcmpeqq, KReg, Vec, Mem) // AVX512_F{k|b64} + ASMJIT_INST_3x(vpcmpeqw, Vpcmpeqw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpeqw, Vpcmpeqw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpeqw, Vpcmpeqw, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpeqw, Vpcmpeqw, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_6x(vpcmpestri, Vpcmpestri, Vec, Vec, Imm, Gp_ECX, Gp_EAX, Gp_EDX) // AVX [EXPLICIT] + ASMJIT_INST_6x(vpcmpestri, Vpcmpestri, Vec, Mem, Imm, Gp_ECX, Gp_EAX, Gp_EDX) // AVX [EXPLICIT] + ASMJIT_INST_6x(vpcmpestrm, Vpcmpestrm, Vec, Vec, 
Imm, XMM0, Gp_EAX, Gp_EDX) // AVX [EXPLICIT] + ASMJIT_INST_6x(vpcmpestrm, Vpcmpestrm, Vec, Mem, Imm, XMM0, Gp_EAX, Gp_EDX) // AVX [EXPLICIT] + ASMJIT_INST_3x(vpcmpgtb, Vpcmpgtb, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpgtb, Vpcmpgtb, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpgtb, Vpcmpgtb, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpgtb, Vpcmpgtb, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpgtd, Vpcmpgtd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpgtd, Vpcmpgtd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpgtd, Vpcmpgtd, KReg, Vec, Vec) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpgtd, Vpcmpgtd, KReg, Vec, Mem) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpgtq, Vpcmpgtq, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpgtq, Vpcmpgtq, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpgtq, Vpcmpgtq, KReg, Vec, Vec) // AVX512_F{k|b64} + ASMJIT_INST_3x(vpcmpgtq, Vpcmpgtq, KReg, Vec, Mem) // AVX512_F{k|b64} + ASMJIT_INST_3x(vpcmpgtw, Vpcmpgtw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpgtw, Vpcmpgtw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpgtw, Vpcmpgtw, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpgtw, Vpcmpgtw, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpistri, Vpcmpistri, Vec, Vec, Imm, Gp_ECX) // AVX [EXPLICIT] + ASMJIT_INST_4x(vpcmpistri, Vpcmpistri, Vec, Mem, Imm, Gp_ECX) // AVX [EXPLICIT] + ASMJIT_INST_4x(vpcmpistrm, Vpcmpistrm, Vec, Vec, Imm, XMM0) // AVX [EXPLICIT] + ASMJIT_INST_4x(vpcmpistrm, Vpcmpistrm, Vec, Mem, Imm, XMM0) // AVX [EXPLICIT] + ASMJIT_INST_4x(vpcmpq, Vpcmpq, KReg, Vec, Vec, Imm) // AVX512_F{k|b64} + ASMJIT_INST_4x(vpcmpq, Vpcmpq, KReg, Vec, Mem, Imm) // AVX512_F{k|b64} + ASMJIT_INST_4x(vpcmpub, Vpcmpub, KReg, Vec, Vec, Imm) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpub, Vpcmpub, KReg, Vec, Mem, Imm) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpud, Vpcmpud, KReg, Vec, Vec, Imm) // AVX512_F{k|b32} + ASMJIT_INST_4x(vpcmpud, Vpcmpud, KReg, Vec, Mem, Imm) // AVX512_F{k|b32} + ASMJIT_INST_4x(vpcmpuq, Vpcmpuq, KReg, Vec, Vec, 
Imm) // AVX512_F{k|b64} + ASMJIT_INST_4x(vpcmpuq, Vpcmpuq, KReg, Vec, Mem, Imm) // AVX512_F{k|b64} + ASMJIT_INST_4x(vpcmpuw, Vpcmpuw, KReg, Vec, Vec, Imm) // AVX512_BW{k|b64} + ASMJIT_INST_4x(vpcmpuw, Vpcmpuw, KReg, Vec, Mem, Imm) // AVX512_BW{k|b64} + ASMJIT_INST_4x(vpcmpw, Vpcmpw, KReg, Vec, Vec, Imm) // AVX512_BW{k|b64} + ASMJIT_INST_4x(vpcmpw, Vpcmpw, KReg, Vec, Mem, Imm) // AVX512_BW{k|b64} + ASMJIT_INST_2x(vpcompressb, Vpcompressb, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpcompressb, Vpcompressb, Mem, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpcompressd, Vpcompressd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpcompressd, Vpcompressd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpcompressq, Vpcompressq, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpcompressq, Vpcompressq, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpcompressw, Vpcompressw, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpcompressw, Vpcompressw, Mem, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpconflictd, Vpconflictd, Vec, Vec) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vpconflictd, Vpconflictd, Vec, Mem) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vpconflictq, Vpconflictq, Vec, Vec) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vpconflictq, Vpconflictq, Vec, Mem) // AVX512_CD{kz|b32} + ASMJIT_INST_3x(vpdpbusd, Vpdpbusd, Vec, Vec, Vec) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpbusd, Vpdpbusd, Vec, Vec, Mem) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpbusds, Vpdpbusds, Vec, Vec, Vec) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpbusds, Vpdpbusds, Vec, Vec, Mem) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpwssd, Vpdpwssd, Vec, Vec, Vec) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpwssd, Vpdpwssd, Vec, Vec, Mem) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpwssds, Vpdpwssds, Vec, Vec, Vec) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpwssds, Vpdpwssds, Vec, Vec, Mem) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_4x(vperm2f128, Vperm2f128, Vec, Vec, Vec, 
Imm) // AVX + ASMJIT_INST_4x(vperm2f128, Vperm2f128, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vperm2i128, Vperm2i128, Vec, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_4x(vperm2i128, Vperm2i128, Vec, Vec, Mem, Imm) // AVX2 + ASMJIT_INST_3x(vpermb, Vpermb, Vec, Vec, Vec) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermb, Vpermb, Vec, Vec, Mem) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermd, Vpermd, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermd, Vpermd, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2b, Vpermi2b, Vec, Vec, Vec) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermi2b, Vpermi2b, Vec, Vec, Mem) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermi2d, Vpermi2d, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2d, Vpermi2d, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2pd, Vpermi2pd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermi2pd, Vpermi2pd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermi2ps, Vpermi2ps, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2ps, Vpermi2ps, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2q, Vpermi2q, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermi2q, Vpermi2q, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermi2w, Vpermi2w, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermi2w, Vpermi2w, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermilpd, Vpermilpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilpd, Vpermilpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilpd, Vpermilpd, Vec, Vec, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilpd, Vpermilpd, Vec, Mem, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilps, Vpermilps, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilps, Vpermilps, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilps, Vpermilps, Vec, Vec, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilps, Vpermilps, Vec, Mem, Imm) // AVX AVX512_F{kz|b64} + 
ASMJIT_INST_3x(vpermpd, Vpermpd, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_3x(vpermpd, Vpermpd, Vec, Mem, Imm) // AVX2 + ASMJIT_INST_3x(vpermpd, Vpermpd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermpd, Vpermpd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermps, Vpermps, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermps, Vpermps, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermq, Vpermq, Vec, Vec, Imm) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermq, Vpermq, Vec, Mem, Imm) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermq, Vpermq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermq, Vpermq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2b, Vpermt2b, Vec, Vec, Vec) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermt2b, Vpermt2b, Vec, Vec, Mem) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermt2d, Vpermt2d, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermt2d, Vpermt2d, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermt2pd, Vpermt2pd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2pd, Vpermt2pd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2ps, Vpermt2ps, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermt2ps, Vpermt2ps, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermt2q, Vpermt2q, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2q, Vpermt2q, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2w, Vpermt2w, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermt2w, Vpermt2w, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermw, Vpermw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermw, Vpermw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_2x(vpexpandb, Vpexpandb, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpexpandb, Vpexpandb, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpexpandd, Vpexpandd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpexpandd, Vpexpandd, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vpexpandq, Vpexpandq, Vec, Vec) // 
AVX512_F{kz} + ASMJIT_INST_2x(vpexpandq, Vpexpandq, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vpexpandw, Vpexpandw, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpexpandw, Vpexpandw, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpextrb, Vpextrb, Gp, Xmm, Imm) // AVX AVX512_BW + ASMJIT_INST_3x(vpextrb, Vpextrb, Mem, Xmm, Imm) // AVX AVX512_BW + ASMJIT_INST_3x(vpextrd, Vpextrd, Gp, Xmm, Imm) // AVX AVX512_DQ + ASMJIT_INST_3x(vpextrd, Vpextrd, Mem, Xmm, Imm) // AVX AVX512_DQ + ASMJIT_INST_3x(vpextrq, Vpextrq, Gp, Xmm, Imm) // AVX AVX512_DQ + ASMJIT_INST_3x(vpextrq, Vpextrq, Mem, Xmm, Imm) // AVX AVX512_DQ + ASMJIT_INST_3x(vpextrw, Vpextrw, Gp, Xmm, Imm) // AVX AVX512_BW + ASMJIT_INST_3x(vpextrw, Vpextrw, Mem, Xmm, Imm) // AVX AVX512_BW + ASMJIT_INST_2x(vpgatherdd, Vpgatherdd, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vpgatherdd, Vpgatherdd, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vpgatherdq, Vpgatherdq, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vpgatherdq, Vpgatherdq, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vpgatherqd, Vpgatherqd, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vpgatherqd, Vpgatherqd, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vpgatherqq, Vpgatherqq, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vpgatherqq, Vpgatherqq, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_3x(vphaddd, Vphaddd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphaddd, Vphaddd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vphaddsw, Vphaddsw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphaddsw, Vphaddsw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vphaddw, Vphaddw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphaddw, Vphaddw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_2x(vphminposuw, Vphminposuw, Vec, Vec) // AVX + ASMJIT_INST_2x(vphminposuw, Vphminposuw, Vec, Mem) // AVX + ASMJIT_INST_3x(vphsubd, Vphsubd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphsubd, Vphsubd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vphsubsw, Vphsubsw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphsubsw, Vphsubsw, Vec, Vec, Mem) // AVX+ + 
ASMJIT_INST_3x(vphsubw, Vphsubw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphsubw, Vphsubw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_4x(vpinsrb, Vpinsrb, Xmm, Xmm, Gp, Imm) // AVX AVX512_BW{kz} + ASMJIT_INST_4x(vpinsrb, Vpinsrb, Xmm, Xmm, Mem, Imm) // AVX AVX512_BW{kz} + ASMJIT_INST_4x(vpinsrd, Vpinsrd, Xmm, Xmm, Gp, Imm) // AVX AVX512_DQ{kz} + ASMJIT_INST_4x(vpinsrd, Vpinsrd, Xmm, Xmm, Mem, Imm) // AVX AVX512_DQ{kz} + ASMJIT_INST_4x(vpinsrq, Vpinsrq, Xmm, Xmm, Gp, Imm) // AVX AVX512_DQ{kz} + ASMJIT_INST_4x(vpinsrq, Vpinsrq, Xmm, Xmm, Mem, Imm) // AVX AVX512_DQ{kz} + ASMJIT_INST_4x(vpinsrw, Vpinsrw, Xmm, Xmm, Gp, Imm) // AVX AVX512_BW{kz} + ASMJIT_INST_4x(vpinsrw, Vpinsrw, Xmm, Xmm, Mem, Imm) // AVX AVX512_BW{kz} + ASMJIT_INST_2x(vplzcntd, Vplzcntd, Vec, Vec) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vplzcntd, Vplzcntd, Vec, Mem) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vplzcntq, Vplzcntq, Vec, Vec) // AVX512_CD{kz|b64} + ASMJIT_INST_2x(vplzcntq, Vplzcntq, Vec, Mem) // AVX512_CD{kz|b64} + ASMJIT_INST_3x(vpmadd52huq, Vpmadd52huq, Vec, Vec, Vec) // AVX512_IFMA{kz|b64} + ASMJIT_INST_3x(vpmadd52huq, Vpmadd52huq, Vec, Vec, Mem) // AVX512_IFMA{kz|b64} + ASMJIT_INST_3x(vpmadd52luq, Vpmadd52luq, Vec, Vec, Vec) // AVX512_IFMA{kz|b64} + ASMJIT_INST_3x(vpmadd52luq, Vpmadd52luq, Vec, Vec, Mem) // AVX512_IFMA{kz|b64} + ASMJIT_INST_3x(vpmaddubsw, Vpmaddubsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaddubsw, Vpmaddubsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaddwd, Vpmaddwd, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaddwd, Vpmaddwd, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaskmovd, Vpmaskmovd, Mem, Vec, Vec) // AVX2 + ASMJIT_INST_3x(vpmaskmovd, Vpmaskmovd, Vec, Vec, Mem) // AVX2 + ASMJIT_INST_3x(vpmaskmovq, Vpmaskmovq, Mem, Vec, Vec) // AVX2 + ASMJIT_INST_3x(vpmaskmovq, Vpmaskmovq, Vec, Vec, Mem) // AVX2 + ASMJIT_INST_3x(vpmaxsb, Vpmaxsb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxsb, Vpmaxsb, Vec, Vec, Mem) // 
AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxsd, Vpmaxsd, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmaxsd, Vpmaxsd, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmaxsq, Vpmaxsq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmaxsq, Vpmaxsq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmaxsw, Vpmaxsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxsw, Vpmaxsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxub, Vpmaxub, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxub, Vpmaxub, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxud, Vpmaxud, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmaxud, Vpmaxud, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmaxuq, Vpmaxuq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmaxuq, Vpmaxuq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmaxuw, Vpmaxuw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxuw, Vpmaxuw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminsb, Vpminsb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminsb, Vpminsb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminsd, Vpminsd, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpminsd, Vpminsd, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpminsq, Vpminsq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpminsq, Vpminsq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpminsw, Vpminsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminsw, Vpminsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminub, Vpminub, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminub, Vpminub, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminud, Vpminud, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpminud, Vpminud, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpminuq, Vpminuq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpminuq, 
Vpminuq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpminuw, Vpminuw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminuw, Vpminuw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovb2m, Vpmovb2m, KReg, Vec) // AVX512_BW + ASMJIT_INST_2x(vpmovd2m, Vpmovd2m, KReg, Vec) // AVX512_DQ + ASMJIT_INST_2x(vpmovdb, Vpmovdb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovdb, Vpmovdb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovdw, Vpmovdw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovdw, Vpmovdw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovm2b, Vpmovm2b, Vec, KReg) // AVX512_BW + ASMJIT_INST_2x(vpmovm2d, Vpmovm2d, Vec, KReg) // AVX512_DQ + ASMJIT_INST_2x(vpmovm2q, Vpmovm2q, Vec, KReg) // AVX512_DQ + ASMJIT_INST_2x(vpmovm2w, Vpmovm2w, Vec, KReg) // AVX512_BW + ASMJIT_INST_2x(vpmovmskb, Vpmovmskb, Gp, Vec) // AVX+ + ASMJIT_INST_2x(vpmovq2m, Vpmovq2m, KReg, Vec) // AVX512_DQ + ASMJIT_INST_2x(vpmovqb, Vpmovqb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqb, Vpmovqb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqd, Vpmovqd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqd, Vpmovqd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqw, Vpmovqw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqw, Vpmovqw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsdb, Vpmovsdb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsdb, Vpmovsdb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsdw, Vpmovsdw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsdw, Vpmovsdw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqb, Vpmovsqb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqb, Vpmovsqb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqd, Vpmovsqd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqd, Vpmovsqd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqw, Vpmovsqw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqw, Vpmovsqw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovswb, Vpmovswb, Vec, Vec) // AVX512_BW{kz} + 
ASMJIT_INST_2x(vpmovswb, Vpmovswb, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovsxbd, Vpmovsxbd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxbd, Vpmovsxbd, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxbq, Vpmovsxbq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxbq, Vpmovsxbq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxbw, Vpmovsxbw, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovsxbw, Vpmovsxbw, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovsxdq, Vpmovsxdq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxdq, Vpmovsxdq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxwd, Vpmovsxwd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxwd, Vpmovsxwd, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxwq, Vpmovsxwq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxwq, Vpmovsxwq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovusdb, Vpmovusdb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusdb, Vpmovusdb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusdw, Vpmovusdw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusdw, Vpmovusdw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqb, Vpmovusqb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqb, Vpmovusqb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqd, Vpmovusqd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqd, Vpmovusqd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqw, Vpmovusqw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqw, Vpmovusqw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovuswb, Vpmovuswb, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovuswb, Vpmovuswb, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovw2m, Vpmovw2m, KReg, Vec) // AVX512_BW + ASMJIT_INST_2x(vpmovwb, Vpmovwb, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovwb, Vpmovwb, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovzxbd, Vpmovzxbd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxbd, Vpmovzxbd, Vec, Mem) // AVX+ 
AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxbq, Vpmovzxbq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxbq, Vpmovzxbq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxbw, Vpmovzxbw, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovzxbw, Vpmovzxbw, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovzxdq, Vpmovzxdq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxdq, Vpmovzxdq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxwd, Vpmovzxwd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxwd, Vpmovzxwd, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxwq, Vpmovzxwq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxwq, Vpmovzxwq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpmuldq, Vpmuldq, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmuldq, Vpmuldq, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmulhrsw, Vpmulhrsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhrsw, Vpmulhrsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhuw, Vpmulhuw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhuw, Vpmulhuw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhw, Vpmulhw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhw, Vpmulhw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulld, Vpmulld, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmulld, Vpmulld, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmullq, Vpmullq, Vec, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vpmullq, Vpmullq, Vec, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vpmullw, Vpmullw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmullw, Vpmullw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmultishiftqb, Vpmultishiftqb, Vec, Vec, Vec) // AVX512_VBMI{kz|b64} + ASMJIT_INST_3x(vpmultishiftqb, Vpmultishiftqb, Vec, Vec, Mem) // AVX512_VBMI{kz|b64} + ASMJIT_INST_3x(vpmuludq, Vpmuludq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + 
ASMJIT_INST_3x(vpmuludq, Vpmuludq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_2x(vpopcntb, Vpopcntb, Vec, Vec) // AVX512_BITALG{kz|b32} + ASMJIT_INST_2x(vpopcntb, Vpopcntb, Vec, Mem) // AVX512_BITALG{kz|b32} + ASMJIT_INST_2x(vpopcntd, Vpopcntd, Vec, Vec) // AVX512_VPOPCNTDQ{kz|b32} + ASMJIT_INST_2x(vpopcntd, Vpopcntd, Vec, Mem) // AVX512_VPOPCNTDQ{kz|b32} + ASMJIT_INST_2x(vpopcntq, Vpopcntq, Vec, Vec) // AVX512_VPOPCNTDQ{kz|b64} + ASMJIT_INST_2x(vpopcntq, Vpopcntq, Vec, Mem) // AVX512_VPOPCNTDQ{kz|b64} + ASMJIT_INST_2x(vpopcntw, Vpopcntw, Vec, Vec) // AVX512_BITALG{kz|b32} + ASMJIT_INST_2x(vpopcntw, Vpopcntw, Vec, Mem) // AVX512_BITALG{kz|b32} + ASMJIT_INST_3x(vpor, Vpor, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpor, Vpor, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpord, Vpord, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpord, Vpord, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vporq, Vporq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vporq, Vporq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprold, Vprold, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprold, Vprold, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprolq, Vprolq, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprolq, Vprolq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprolvd, Vprolvd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprolvd, Vprolvd, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprolvq, Vprolvq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprolvq, Vprolvq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprord, Vprord, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprord, Vprord, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprorq, Vprorq, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprorq, Vprorq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprorvd, Vprorvd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprorvd, Vprorvd, Vec, Vec, Mem) // AVX512_F{kz|b32} + 
ASMJIT_INST_3x(vprorvq, Vprorvq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprorvq, Vprorvq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsadbw, Vpsadbw, Vec, Vec, Vec) // AVX+ AVX512_BW + ASMJIT_INST_3x(vpsadbw, Vpsadbw, Vec, Vec, Mem) // AVX+ AVX512_BW + ASMJIT_INST_2x(vpscatterdd, Vpscatterdd, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vpscatterdq, Vpscatterdq, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vpscatterqd, Vpscatterqd, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vpscatterqq, Vpscatterqq, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_4x(vpshldd, Vpshldd, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldd, Vpshldd, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldq, Vpshldq, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldq, Vpshldq, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvd, Vpshldvd, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvd, Vpshldvd, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvq, Vpshldvq, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvq, Vpshldvq, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvw, Vpshldvw, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvw, Vpshldvw, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldw, Vpshldw, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldw, Vpshldw, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdd, Vpshrdd, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdd, Vpshrdd, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdq, Vpshrdq, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdq, Vpshrdq, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvd, Vpshrdvd, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvd, Vpshrdvd, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvq, Vpshrdvq, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvq, Vpshrdvq, Vec, Vec, Mem) // 
AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvw, Vpshrdvw, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvw, Vpshrdvw, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdw, Vpshrdw, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdw, Vpshrdw, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshufb, Vpshufb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshufb, Vpshufb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshufbitqmb, Vpshufbitqmb, KReg, Vec, Vec) // AVX512_BITALG{k} + ASMJIT_INST_3x(vpshufbitqmb, Vpshufbitqmb, KReg, Vec, Mem) // AVX512_BITALG{k} + ASMJIT_INST_3x(vpshufd, Vpshufd, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpshufd, Vpshufd, Vec, Mem, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpshufhw, Vpshufhw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshufhw, Vpshufhw, Vec, Mem, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshuflw, Vpshuflw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshuflw, Vpshuflw, Vec, Mem, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsignb, Vpsignb, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpsignb, Vpsignb, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpsignd, Vpsignd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpsignd, Vpsignd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpsignw, Vpsignw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpsignw, Vpsignw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpslld, Vpslld, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpslld, Vpslld, Vec, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpslld, Vpslld, Vec, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpslld, Vpslld, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpslldq, Vpslldq, Vec, Vec, Imm) // AVX+ AVX512_BW + ASMJIT_INST_3x(vpslldq, Vpslldq, Vec, Mem, Imm) // AVX512_BW + ASMJIT_INST_3x(vpsllq, Vpsllq, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsllq, Vpsllq, Vec, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsllq, Vpsllq, 
Vec, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsllq, Vpsllq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsllvd, Vpsllvd, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsllvd, Vpsllvd, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsllvq, Vpsllvq, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsllvq, Vpsllvq, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsllvw, Vpsllvw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsllvw, Vpsllvw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsllw, Vpsllw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsllw, Vpsllw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsllw, Vpsllw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsllw, Vpsllw, Vec, Mem, Imm) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsrad, Vpsrad, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrad, Vpsrad, Vec, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsrad, Vpsrad, Vec, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsrad, Vpsrad, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsraq, Vpsraq, Vec, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_3x(vpsraq, Vpsraq, Vec, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_3x(vpsraq, Vpsraq, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsraq, Vpsraq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsravd, Vpsravd, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsravd, Vpsravd, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsravq, Vpsravq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsravq, Vpsravq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsravw, Vpsravw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsravw, Vpsravw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsraw, Vpsraw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsraw, Vpsraw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsraw, Vpsraw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + 
ASMJIT_INST_3x(vpsraw, Vpsraw, Vec, Mem, Imm) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsrld, Vpsrld, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrld, Vpsrld, Vec, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsrld, Vpsrld, Vec, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsrld, Vpsrld, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrldq, Vpsrldq, Vec, Vec, Imm) // AVX+ AVX512_BW + ASMJIT_INST_3x(vpsrldq, Vpsrldq, Vec, Mem, Imm) // AVX512_BW + ASMJIT_INST_3x(vpsrlq, Vpsrlq, Vec, Vec, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsrlq, Vpsrlq, Vec, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_3x(vpsrlq, Vpsrlq, Vec, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_3x(vpsrlq, Vpsrlq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsrlvd, Vpsrlvd, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrlvd, Vpsrlvd, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrlvq, Vpsrlvq, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsrlvq, Vpsrlvq, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsrlvw, Vpsrlvw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlvw, Vpsrlvw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlw, Vpsrlw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlw, Vpsrlw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlw, Vpsrlw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlw, Vpsrlw, Vec, Mem, Imm) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsubb, Vpsubb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubb, Vpsubb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubd, Vpsubd, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsubd, Vpsubd, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsubq, Vpsubq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsubq, Vpsubq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsubsb, Vpsubsb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubsb, Vpsubsb, 
Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubsw, Vpsubsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubsw, Vpsubsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubusb, Vpsubusb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubusb, Vpsubusb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubusw, Vpsubusw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubusw, Vpsubusw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubw, Vpsubw, Vec, Vec, Vec) // AVX AVX512_BW{kz} + ASMJIT_INST_3x(vpsubw, Vpsubw, Vec, Vec, Mem) // AVX AVX512_BW{kz} + ASMJIT_INST_4x(vpternlogd, Vpternlogd, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vpternlogd, Vpternlogd, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vpternlogq, Vpternlogq, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vpternlogq, Vpternlogq, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vptest, Vptest, Vec, Vec) // AVX + ASMJIT_INST_2x(vptest, Vptest, Vec, Mem) // AVX + ASMJIT_INST_3x(vptestmb, Vptestmb, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vptestmb, Vptestmb, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vptestmd, Vptestmd, KReg, Vec, Vec) // AVX512_F{k|b32} + ASMJIT_INST_3x(vptestmd, Vptestmd, KReg, Vec, Mem) // AVX512_F{k|b32} + ASMJIT_INST_3x(vptestmq, Vptestmq, KReg, Vec, Vec) // AVX512_F{k|b64} + ASMJIT_INST_3x(vptestmq, Vptestmq, KReg, Vec, Mem) // AVX512_F{k|b64} + ASMJIT_INST_3x(vptestmw, Vptestmw, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vptestmw, Vptestmw, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vptestnmb, Vptestnmb, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vptestnmb, Vptestnmb, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vptestnmd, Vptestnmd, KReg, Vec, Vec) // AVX512_F{k|b32} + ASMJIT_INST_3x(vptestnmd, Vptestnmd, KReg, Vec, Mem) // AVX512_F{k|b32} + ASMJIT_INST_3x(vptestnmq, Vptestnmq, KReg, Vec, Vec) // AVX512_F{k|b64} + ASMJIT_INST_3x(vptestnmq, 
Vptestnmq, KReg, Vec, Mem) // AVX512_F{k|b64} + ASMJIT_INST_3x(vptestnmw, Vptestnmw, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vptestnmw, Vptestnmw, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vpunpckhbw, Vpunpckhbw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpckhbw, Vpunpckhbw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpckhdq, Vpunpckhdq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpunpckhdq, Vpunpckhdq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpunpckhqdq, Vpunpckhqdq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpunpckhqdq, Vpunpckhqdq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpunpckhwd, Vpunpckhwd, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpckhwd, Vpunpckhwd, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpcklbw, Vpunpcklbw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpcklbw, Vpunpcklbw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpckldq, Vpunpckldq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpunpckldq, Vpunpckldq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpunpcklqdq, Vpunpcklqdq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpunpcklqdq, Vpunpcklqdq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpunpcklwd, Vpunpcklwd, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpcklwd, Vpunpcklwd, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpxor, Vpxor, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpxor, Vpxor, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpxord, Vpxord, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpxord, Vpxord, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpxorq, Vpxorq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpxorq, Vpxorq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vrangepd, Vrangepd, Vec, Vec, Vec, Imm) // AVX512_DQ{kz|b64} + ASMJIT_INST_4x(vrangepd, Vrangepd, Vec, Vec, Mem, Imm) // 
AVX512_DQ{kz|b64} + ASMJIT_INST_4x(vrangeps, Vrangeps, Vec, Vec, Vec, Imm) // AVX512_DQ{kz|b32} + ASMJIT_INST_4x(vrangeps, Vrangeps, Vec, Vec, Mem, Imm) // AVX512_DQ{kz|b32} + ASMJIT_INST_4x(vrangesd, Vrangesd, Xmm, Xmm, Xmm, Imm) // AVX512_DQ{kz|sae} + ASMJIT_INST_4x(vrangesd, Vrangesd, Xmm, Xmm, Mem, Imm) // AVX512_DQ{kz|sae} + ASMJIT_INST_4x(vrangess, Vrangess, Xmm, Xmm, Xmm, Imm) // AVX512_DQ{kz|sae} + ASMJIT_INST_4x(vrangess, Vrangess, Xmm, Xmm, Mem, Imm) // AVX512_DQ{kz|sae} + ASMJIT_INST_2x(vrcp14pd, Vrcp14pd, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vrcp14pd, Vrcp14pd, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vrcp14ps, Vrcp14ps, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vrcp14ps, Vrcp14ps, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vrcp14sd, Vrcp14sd, Xmm, Xmm, Xmm) // AVX512_F{kz} + ASMJIT_INST_3x(vrcp14sd, Vrcp14sd, Xmm, Xmm, Mem) // AVX512_F{kz} + ASMJIT_INST_3x(vrcp14ss, Vrcp14ss, Xmm, Xmm, Xmm) // AVX512_F{kz} + ASMJIT_INST_3x(vrcp14ss, Vrcp14ss, Xmm, Xmm, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vrcp28pd, Vrcp28pd, Vec, Vec) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vrcp28pd, Vrcp28pd, Vec, Mem) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vrcp28ps, Vrcp28ps, Vec, Vec) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_2x(vrcp28ps, Vrcp28ps, Vec, Mem) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_3x(vrcp28sd, Vrcp28sd, Xmm, Xmm, Xmm) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrcp28sd, Vrcp28sd, Xmm, Xmm, Mem) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrcp28ss, Vrcp28ss, Xmm, Xmm, Xmm) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrcp28ss, Vrcp28ss, Xmm, Xmm, Mem) // AVX512_ER{kz|sae} + ASMJIT_INST_2x(vrcpps, Vrcpps, Vec, Vec) // AVX + ASMJIT_INST_2x(vrcpps, Vrcpps, Vec, Mem) // AVX + ASMJIT_INST_3x(vrcpss, Vrcpss, Xmm, Xmm, Xmm) // AVX + ASMJIT_INST_3x(vrcpss, Vrcpss, Xmm, Xmm, Mem) // AVX + ASMJIT_INST_3x(vreducepd, Vreducepd, Vec, Vec, Imm) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vreducepd, Vreducepd, Vec, Mem, Imm) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vreduceps, 
Vreduceps, Vec, Vec, Imm) // AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vreduceps, Vreduceps, Vec, Mem, Imm) // AVX512_DQ{kz|b32} + ASMJIT_INST_4x(vreducesd, Vreducesd, Xmm, Xmm, Xmm, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vreducesd, Vreducesd, Xmm, Xmm, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vreducess, Vreducess, Xmm, Xmm, Xmm, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vreducess, Vreducess, Xmm, Xmm, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vrndscalepd, Vrndscalepd, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vrndscalepd, Vrndscalepd, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vrndscaleps, Vrndscaleps, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vrndscaleps, Vrndscaleps, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vrndscalesd, Vrndscalesd, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vrndscalesd, Vrndscalesd, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vrndscaless, Vrndscaless, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vrndscaless, Vrndscaless, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vroundpd, Vroundpd, Vec, Vec, Imm) // AVX + ASMJIT_INST_3x(vroundpd, Vroundpd, Vec, Mem, Imm) // AVX + ASMJIT_INST_3x(vroundps, Vroundps, Vec, Vec, Imm) // AVX + ASMJIT_INST_3x(vroundps, Vroundps, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vroundsd, Vroundsd, Xmm, Xmm, Xmm, Imm) // AVX + ASMJIT_INST_4x(vroundsd, Vroundsd, Xmm, Xmm, Mem, Imm) // AVX + ASMJIT_INST_4x(vroundss, Vroundss, Xmm, Xmm, Xmm, Imm) // AVX + ASMJIT_INST_4x(vroundss, Vroundss, Xmm, Xmm, Mem, Imm) // AVX + ASMJIT_INST_2x(vrsqrt14pd, Vrsqrt14pd, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vrsqrt14pd, Vrsqrt14pd, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vrsqrt14ps, Vrsqrt14ps, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vrsqrt14ps, Vrsqrt14ps, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vrsqrt14sd, Vrsqrt14sd, Xmm, Xmm, Xmm) // AVX512_F{kz} + ASMJIT_INST_3x(vrsqrt14sd, Vrsqrt14sd, Xmm, Xmm, Mem) // AVX512_F{kz} + 
ASMJIT_INST_3x(vrsqrt14ss, Vrsqrt14ss, Xmm, Xmm, Xmm) // AVX512_F{kz} + ASMJIT_INST_3x(vrsqrt14ss, Vrsqrt14ss, Xmm, Xmm, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vrsqrt28pd, Vrsqrt28pd, Vec, Vec) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vrsqrt28pd, Vrsqrt28pd, Vec, Mem) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vrsqrt28ps, Vrsqrt28ps, Vec, Vec) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_2x(vrsqrt28ps, Vrsqrt28ps, Vec, Mem) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_3x(vrsqrt28sd, Vrsqrt28sd, Xmm, Xmm, Xmm) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrsqrt28sd, Vrsqrt28sd, Xmm, Xmm, Mem) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrsqrt28ss, Vrsqrt28ss, Xmm, Xmm, Xmm) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrsqrt28ss, Vrsqrt28ss, Xmm, Xmm, Mem) // AVX512_ER{kz|sae} + ASMJIT_INST_2x(vrsqrtps, Vrsqrtps, Vec, Vec) // AVX + ASMJIT_INST_2x(vrsqrtps, Vrsqrtps, Vec, Mem) // AVX + ASMJIT_INST_3x(vrsqrtss, Vrsqrtss, Xmm, Xmm, Xmm) // AVX + ASMJIT_INST_3x(vrsqrtss, Vrsqrtss, Xmm, Xmm, Mem) // AVX + ASMJIT_INST_3x(vscalefpd, Vscalefpd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vscalefpd, Vscalefpd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vscalefps, Vscalefps, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vscalefps, Vscalefps, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vscalefsd, Vscalefsd, Xmm, Xmm, Xmm) // AVX512_F{kz|er} + ASMJIT_INST_3x(vscalefsd, Vscalefsd, Xmm, Xmm, Mem) // AVX512_F{kz|er} + ASMJIT_INST_3x(vscalefss, Vscalefss, Xmm, Xmm, Xmm) // AVX512_F{kz|er} + ASMJIT_INST_3x(vscalefss, Vscalefss, Xmm, Xmm, Mem) // AVX512_F{kz|er} + ASMJIT_INST_2x(vscatterdpd, Vscatterdpd, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vscatterdps, Vscatterdps, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_1x(vscatterpf0dpd, Vscatterpf0dpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf0dps, Vscatterpf0dps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf0qpd, Vscatterpf0qpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf0qps, Vscatterpf0qps, Mem) // AVX512_PF{k} + 
ASMJIT_INST_1x(vscatterpf1dpd, Vscatterpf1dpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf1dps, Vscatterpf1dps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf1qpd, Vscatterpf1qpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf1qps, Vscatterpf1qps, Mem) // AVX512_PF{k} + ASMJIT_INST_2x(vscatterqpd, Vscatterqpd, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vscatterqps, Vscatterqps, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_4x(vshuff32x4, Vshuff32x4, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vshuff32x4, Vshuff32x4, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vshuff64x2, Vshuff64x2, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vshuff64x2, Vshuff64x2, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vshufi32x4, Vshufi32x4, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vshufi32x4, Vshufi32x4, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vshufi64x2, Vshufi64x2, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vshufi64x2, Vshufi64x2, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vshufpd, Vshufpd, Vec, Vec, Vec, Imm) // AVX AVX512_F{kz|b32} + ASMJIT_INST_4x(vshufpd, Vshufpd, Vec, Vec, Mem, Imm) // AVX AVX512_F{kz|b32} + ASMJIT_INST_4x(vshufps, Vshufps, Vec, Vec, Vec, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_4x(vshufps, Vshufps, Vec, Vec, Mem, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vsqrtpd, Vsqrtpd, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vsqrtpd, Vsqrtpd, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vsqrtps, Vsqrtps, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vsqrtps, Vsqrtps, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vsqrtsd, Vsqrtsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsqrtsd, Vsqrtsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsqrtss, Vsqrtss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsqrtss, Vsqrtss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_1x(vstmxcsr, Vstmxcsr, Mem) // AVX + 
ASMJIT_INST_3x(vsubpd, Vsubpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vsubpd, Vsubpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vsubps, Vsubps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vsubps, Vsubps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vsubsd, Vsubsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsubsd, Vsubsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsubss, Vsubss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsubss, Vsubss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_2x(vtestpd, Vtestpd, Vec, Vec) // AVX + ASMJIT_INST_2x(vtestpd, Vtestpd, Vec, Mem) // AVX + ASMJIT_INST_2x(vtestps, Vtestps, Vec, Vec) // AVX + ASMJIT_INST_2x(vtestps, Vtestps, Vec, Mem) // AVX + ASMJIT_INST_2x(vucomisd, Vucomisd, Xmm, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vucomisd, Vucomisd, Xmm, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vucomiss, Vucomiss, Xmm, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vucomiss, Vucomiss, Xmm, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_3x(vunpckhpd, Vunpckhpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vunpckhpd, Vunpckhpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vunpckhps, Vunpckhps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vunpckhps, Vunpckhps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vunpcklpd, Vunpcklpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vunpcklpd, Vunpcklpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vunpcklps, Vunpcklps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vunpcklps, Vunpcklps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vxorpd, Vxorpd, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vxorpd, Vxorpd, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vxorps, Vxorps, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vxorps, Vxorps, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_0x(vzeroall, 
Vzeroall) // AVX + ASMJIT_INST_0x(vzeroupper, Vzeroupper) // AVX + + //! \} + + //! \name FMA4 Instructions + //! \{ + + ASMJIT_INST_4x(vfmaddpd, Vfmaddpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddpd, Vfmaddpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddpd, Vfmaddpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddps, Vfmaddps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddps, Vfmaddps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddps, Vfmaddps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddsd, Vfmaddsd, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfmaddsd, Vfmaddsd, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfmaddsd, Vfmaddsd, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddss, Vfmaddss, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfmaddss, Vfmaddss, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfmaddss, Vfmaddss, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddsubpd, Vfmaddsubpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddsubpd, Vfmaddsubpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddsubpd, Vfmaddsubpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddsubps, Vfmaddsubps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddsubps, Vfmaddsubps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddsubps, Vfmaddsubps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubaddpd, Vfmsubaddpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubaddpd, Vfmsubaddpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubaddpd, Vfmsubaddpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubaddps, Vfmsubaddps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubaddps, Vfmsubaddps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubaddps, Vfmsubaddps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubpd, Vfmsubpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubpd, Vfmsubpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubpd, Vfmsubpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubps, Vfmsubps, Vec, Vec, Vec, 
Vec) // FMA4 + ASMJIT_INST_4x(vfmsubps, Vfmsubps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubps, Vfmsubps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubsd, Vfmsubsd, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfmsubsd, Vfmsubsd, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfmsubsd, Vfmsubsd, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubss, Vfmsubss, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfmsubss, Vfmsubss, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfmsubss, Vfmsubss, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfnmaddpd, Vfnmaddpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfnmaddpd, Vfnmaddpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfnmaddpd, Vfnmaddpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfnmaddps, Vfnmaddps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfnmaddps, Vfnmaddps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfnmaddps, Vfnmaddps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfnmaddsd, Vfnmaddsd, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmaddsd, Vfnmaddsd, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmaddsd, Vfnmaddsd, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfnmaddss, Vfnmaddss, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmaddss, Vfnmaddss, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmaddss, Vfnmaddss, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfnmsubpd, Vfnmsubpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfnmsubpd, Vfnmsubpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfnmsubpd, Vfnmsubpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfnmsubps, Vfnmsubps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfnmsubps, Vfnmsubps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfnmsubps, Vfnmsubps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfnmsubsd, Vfnmsubsd, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmsubsd, Vfnmsubsd, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmsubsd, Vfnmsubsd, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfnmsubss, Vfnmsubss, Xmm, Xmm, Xmm, Xmm) 
// FMA4 + ASMJIT_INST_4x(vfnmsubss, Vfnmsubss, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmsubss, Vfnmsubss, Xmm, Xmm, Xmm, Mem) // FMA4 + + //! \} + + //! \name XOP Instructions (Deprecated) + //! \{ + + ASMJIT_INST_2x(vfrczpd, Vfrczpd, Vec, Vec) // XOP + ASMJIT_INST_2x(vfrczpd, Vfrczpd, Vec, Mem) // XOP + ASMJIT_INST_2x(vfrczps, Vfrczps, Vec, Vec) // XOP + ASMJIT_INST_2x(vfrczps, Vfrczps, Vec, Mem) // XOP + ASMJIT_INST_2x(vfrczsd, Vfrczsd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vfrczsd, Vfrczsd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vfrczss, Vfrczss, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vfrczss, Vfrczss, Xmm, Mem) // XOP + ASMJIT_INST_4x(vpcmov, Vpcmov, Vec, Vec, Vec, Vec) // XOP + ASMJIT_INST_4x(vpcmov, Vpcmov, Vec, Vec, Mem, Vec) // XOP + ASMJIT_INST_4x(vpcmov, Vpcmov, Vec, Vec, Vec, Mem) // XOP + ASMJIT_INST_4x(vpcomb, Vpcomb, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomb, Vpcomb, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomd, Vpcomd, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomd, Vpcomd, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomq, Vpcomq, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomq, Vpcomq, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomw, Vpcomw, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomw, Vpcomw, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomub, Vpcomub, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomub, Vpcomub, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomud, Vpcomud, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomud, Vpcomud, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomuq, Vpcomuq, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomuq, Vpcomuq, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomuw, Vpcomuw, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomuw, Vpcomuw, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_5x(vpermil2pd, Vpermil2pd, Vec, Vec, Vec, Vec, Imm) // XOP + ASMJIT_INST_5x(vpermil2pd, Vpermil2pd, Vec, Vec, Mem, Vec, Imm) // XOP + ASMJIT_INST_5x(vpermil2pd, Vpermil2pd, Vec, Vec, Vec, Mem, Imm) // XOP 
+ ASMJIT_INST_5x(vpermil2ps, Vpermil2ps, Vec, Vec, Vec, Vec, Imm) // XOP + ASMJIT_INST_5x(vpermil2ps, Vpermil2ps, Vec, Vec, Mem, Vec, Imm) // XOP + ASMJIT_INST_5x(vpermil2ps, Vpermil2ps, Vec, Vec, Vec, Mem, Imm) // XOP + ASMJIT_INST_2x(vphaddbd, Vphaddbd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddbd, Vphaddbd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddbq, Vphaddbq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddbq, Vphaddbq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddbw, Vphaddbw, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddbw, Vphaddbw, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphadddq, Vphadddq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphadddq, Vphadddq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddwd, Vphaddwd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddwd, Vphaddwd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddwq, Vphaddwq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddwq, Vphaddwq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddubd, Vphaddubd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddubd, Vphaddubd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddubq, Vphaddubq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddubq, Vphaddubq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddubw, Vphaddubw, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddubw, Vphaddubw, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddudq, Vphaddudq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddudq, Vphaddudq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphadduwd, Vphadduwd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphadduwd, Vphadduwd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphadduwq, Vphadduwq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphadduwq, Vphadduwq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphsubbw, Vphsubbw, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphsubbw, Vphsubbw, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphsubdq, Vphsubdq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphsubdq, Vphsubdq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphsubwd, Vphsubwd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphsubwd, Vphsubwd, Xmm, Mem) // XOP + ASMJIT_INST_4x(vpmacsdd, Vpmacsdd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsdd, Vpmacsdd, Xmm, Xmm, Mem, Xmm) // XOP + 
ASMJIT_INST_4x(vpmacsdqh, Vpmacsdqh, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsdqh, Vpmacsdqh, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacsdql, Vpmacsdql, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsdql, Vpmacsdql, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacswd, Vpmacswd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacswd, Vpmacswd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacsww, Vpmacsww, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsww, Vpmacsww, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdd, Vpmacssdd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdd, Vpmacssdd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdqh, Vpmacssdqh, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdqh, Vpmacssdqh, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdql, Vpmacssdql, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdql, Vpmacssdql, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacsswd, Vpmacsswd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsswd, Vpmacsswd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacssww, Vpmacssww, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacssww, Vpmacssww, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmadcsswd, Vpmadcsswd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmadcsswd, Vpmadcsswd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmadcswd, Vpmadcswd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmadcswd, Vpmadcswd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpperm, Vpperm, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpperm, Vpperm, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpperm, Vpperm, Xmm, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Mem, Imm) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, 
Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, Xmm, Mem, Imm) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Mem, Imm) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Mem, Imm) // XOP + ASMJIT_INST_3x(vpshab, Vpshab, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshab, Vpshab, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshab, Vpshab, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshad, Vpshad, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshad, Vpshad, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshad, Vpshad, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshaq, Vpshaq, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshaq, Vpshaq, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshaq, Vpshaq, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshaw, Vpshaw, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshaw, Vpshaw, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshaw, Vpshaw, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshlb, Vpshlb, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshlb, Vpshlb, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshlb, Vpshlb, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshld, Vpshld, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshld, Vpshld, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshld, Vpshld, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshlq, Vpshlq, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshlq, Vpshlq, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshlq, Vpshlq, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshlw, Vpshlw, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshlw, Vpshlw, Xmm, Mem, Xmm) // 
XOP + ASMJIT_INST_3x(vpshlw, Vpshlw, Xmm, Xmm, Mem) // XOP + + //! \} + + //! \name AVX_NE_CONVERT Instructions + //! \{ + + ASMJIT_INST_2x(vbcstnebf162ps, Vbcstnebf162ps, Vec, Mem) + ASMJIT_INST_2x(vbcstnesh2ps, Vbcstnesh2ps, Vec, Mem) + ASMJIT_INST_2x(vcvtneebf162ps, Vcvtneebf162ps, Vec, Mem) + ASMJIT_INST_2x(vcvtneeph2ps, Vcvtneeph2ps, Vec, Mem) + ASMJIT_INST_2x(vcvtneobf162ps, Vcvtneobf162ps, Vec, Mem) + ASMJIT_INST_2x(vcvtneoph2ps, Vcvtneoph2ps, Vec, Mem) + + //! \} + + //! \name AVX_VNNI_INT8 Instructions + //! \{ + + ASMJIT_INST_3x(vpdpbssd, Vpdpbssd, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbssd, Vpdpbssd, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbssds, Vpdpbssds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbssds, Vpdpbssds, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbsud, Vpdpbsud, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbsud, Vpdpbsud, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbsuds, Vpdpbsuds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbsuds, Vpdpbsuds, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbuud, Vpdpbuud, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbuud, Vpdpbuud, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbuuds, Vpdpbuuds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbuuds, Vpdpbuuds, Vec, Vec, Mem) + + //! \} + + //! \name AVX_VNNI_INT16 Instructions + //! \{ + + ASMJIT_INST_3x(vpdpwsud, Vpdpwsud, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwsud, Vpdpwsud, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwsuds, Vpdpwsuds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwsuds, Vpdpwsuds, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwusd, Vpdpwusd, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwusd, Vpdpwusd, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwusds, Vpdpwusds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwusds, Vpdpwusds, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwuud, Vpdpwuud, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwuud, Vpdpwuud, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwuuds, Vpdpwuuds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwuuds, Vpdpwuuds, Vec, Vec, Mem) + + //! \} + + //! \name AVX+SHA512 Instructions + //! 
\{ + ASMJIT_INST_2x(vsha512msg1, Vsha512msg1, Vec, Vec) + ASMJIT_INST_2x(vsha512msg2, Vsha512msg2, Vec, Vec) + ASMJIT_INST_3x(vsha512rnds2, Vsha512rnds2, Vec, Vec, Vec) + //! \} + + //! \name AVX+SM3 Instructions + //! \{ + + ASMJIT_INST_3x(vsm3msg1, Vsm3msg1, Vec, Vec, Vec) + ASMJIT_INST_3x(vsm3msg1, Vsm3msg1, Vec, Vec, Mem) + ASMJIT_INST_3x(vsm3msg2, Vsm3msg2, Vec, Vec, Vec) + ASMJIT_INST_3x(vsm3msg2, Vsm3msg2, Vec, Vec, Mem) + ASMJIT_INST_4x(vsm3rnds2, Vsm3rnds2, Vec, Vec, Vec, Imm) + ASMJIT_INST_4x(vsm3rnds2, Vsm3rnds2, Vec, Vec, Mem, Imm) + + //! \} + + //! \name AVX+SM4 Instructions + //! \{ + + ASMJIT_INST_3x(vsm4key4, Vsm4key4, Vec, Vec, Vec) + ASMJIT_INST_3x(vsm4key4, Vsm4key4, Vec, Vec, Mem) + ASMJIT_INST_3x(vsm4rnds4, Vsm4rnds4, Vec, Vec, Vec) + ASMJIT_INST_3x(vsm4rnds4, Vsm4rnds4, Vec, Vec, Mem) + + //! \} + + //! \name AVX512_FP16 Instructions + //! \{ + + ASMJIT_INST_3x(vaddph, Vaddph, Vec, Vec, Vec) + ASMJIT_INST_3x(vaddph, Vaddph, Vec, Vec, Mem) + ASMJIT_INST_3x(vaddsh, Vaddsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vaddsh, Vaddsh, Vec, Vec, Mem) + ASMJIT_INST_4x(vcmpph, Vcmpph, KReg, Vec, Vec, Imm) + ASMJIT_INST_4x(vcmpph, Vcmpph, KReg, Vec, Mem, Imm) + ASMJIT_INST_4x(vcmpsh, Vcmpsh, KReg, Vec, Vec, Imm) + ASMJIT_INST_4x(vcmpsh, Vcmpsh, KReg, Vec, Mem, Imm) + ASMJIT_INST_2x(vcomish, Vcomish, Vec, Vec) + ASMJIT_INST_2x(vcomish, Vcomish, Vec, Mem) + ASMJIT_INST_2x(vcvtdq2ph, Vcvtdq2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtdq2ph, Vcvtdq2ph, Vec, Mem) + ASMJIT_INST_2x(vcvtpd2ph, Vcvtpd2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtpd2ph, Vcvtpd2ph, Vec, Mem) + ASMJIT_INST_2x(vcvtph2dq, Vcvtph2dq, Vec, Vec) + ASMJIT_INST_2x(vcvtph2dq, Vcvtph2dq, Vec, Mem) + ASMJIT_INST_2x(vcvtph2pd, Vcvtph2pd, Vec, Vec) + ASMJIT_INST_2x(vcvtph2pd, Vcvtph2pd, Vec, Mem) + ASMJIT_INST_2x(vcvtph2psx, Vcvtph2psx, Vec, Vec) + ASMJIT_INST_2x(vcvtph2psx, Vcvtph2psx, Vec, Mem) + ASMJIT_INST_2x(vcvtph2qq, Vcvtph2qq, Vec, Vec) + ASMJIT_INST_2x(vcvtph2qq, Vcvtph2qq, Vec, Mem) + ASMJIT_INST_2x(vcvtph2udq, 
Vcvtph2udq, Vec, Vec) + ASMJIT_INST_2x(vcvtph2udq, Vcvtph2udq, Vec, Mem) + ASMJIT_INST_2x(vcvtph2uqq, Vcvtph2uqq, Vec, Vec) + ASMJIT_INST_2x(vcvtph2uqq, Vcvtph2uqq, Vec, Mem) + ASMJIT_INST_2x(vcvtph2uw, Vcvtph2uw, Vec, Vec) + ASMJIT_INST_2x(vcvtph2uw, Vcvtph2uw, Vec, Mem) + ASMJIT_INST_2x(vcvtph2w, Vcvtph2w, Vec, Vec) + ASMJIT_INST_2x(vcvtph2w, Vcvtph2w, Vec, Mem) + ASMJIT_INST_2x(vcvtps2phx, Vcvtps2phx, Vec, Vec) + ASMJIT_INST_2x(vcvtps2phx, Vcvtps2phx, Vec, Mem) + ASMJIT_INST_2x(vcvtqq2ph, Vcvtqq2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtqq2ph, Vcvtqq2ph, Vec, Mem) + ASMJIT_INST_3x(vcvtsd2sh, Vcvtsd2sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vcvtsd2sh, Vcvtsd2sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vcvtsh2sd, Vcvtsh2sd, Vec, Vec, Vec) + ASMJIT_INST_3x(vcvtsh2sd, Vcvtsh2sd, Vec, Vec, Mem) + ASMJIT_INST_2x(vcvtsh2si, Vcvtsh2si, Gp, Vec) + ASMJIT_INST_2x(vcvtsh2si, Vcvtsh2si, Gp, Mem) + ASMJIT_INST_3x(vcvtsh2ss, Vcvtsh2ss, Vec, Vec, Vec) + ASMJIT_INST_3x(vcvtsh2ss, Vcvtsh2ss, Vec, Vec, Mem) + ASMJIT_INST_2x(vcvtsh2usi, Vcvtsh2usi, Gp, Vec) + ASMJIT_INST_2x(vcvtsh2usi, Vcvtsh2usi, Gp, Mem) + ASMJIT_INST_3x(vcvtsi2sh, Vcvtsi2sh, Vec, Vec, Gp) + ASMJIT_INST_3x(vcvtsi2sh, Vcvtsi2sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vcvtss2sh, Vcvtss2sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vcvtss2sh, Vcvtss2sh, Vec, Vec, Mem) + ASMJIT_INST_2x(vcvttph2dq, Vcvttph2dq, Vec, Vec) + ASMJIT_INST_2x(vcvttph2dq, Vcvttph2dq, Vec, Mem) + ASMJIT_INST_2x(vcvttph2qq, Vcvttph2qq, Vec, Vec) + ASMJIT_INST_2x(vcvttph2qq, Vcvttph2qq, Vec, Mem) + ASMJIT_INST_2x(vcvttph2udq, Vcvttph2udq, Vec, Vec) + ASMJIT_INST_2x(vcvttph2udq, Vcvttph2udq, Vec, Mem) + ASMJIT_INST_2x(vcvttph2uqq, Vcvttph2uqq, Vec, Vec) + ASMJIT_INST_2x(vcvttph2uqq, Vcvttph2uqq, Vec, Mem) + ASMJIT_INST_2x(vcvttph2uw, Vcvttph2uw, Vec, Vec) + ASMJIT_INST_2x(vcvttph2uw, Vcvttph2uw, Vec, Mem) + ASMJIT_INST_2x(vcvttph2w, Vcvttph2w, Vec, Vec) + ASMJIT_INST_2x(vcvttph2w, Vcvttph2w, Vec, Mem) + ASMJIT_INST_2x(vcvttsh2si, Vcvttsh2si, Gp, Vec) + ASMJIT_INST_2x(vcvttsh2si, 
Vcvttsh2si, Gp, Mem) + ASMJIT_INST_2x(vcvttsh2usi, Vcvttsh2usi, Gp, Vec) + ASMJIT_INST_2x(vcvttsh2usi, Vcvttsh2usi, Gp, Mem) + ASMJIT_INST_2x(vcvtudq2ph, Vcvtudq2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtudq2ph, Vcvtudq2ph, Vec, Mem) + ASMJIT_INST_2x(vcvtuqq2ph, Vcvtuqq2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtuqq2ph, Vcvtuqq2ph, Vec, Mem) + ASMJIT_INST_3x(vcvtusi2sh, Vcvtusi2sh, Vec, Vec, Gp) + ASMJIT_INST_3x(vcvtusi2sh, Vcvtusi2sh, Vec, Vec, Mem) + ASMJIT_INST_2x(vcvtuw2ph, Vcvtuw2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtuw2ph, Vcvtuw2ph, Vec, Mem) + ASMJIT_INST_2x(vcvtw2ph, Vcvtw2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtw2ph, Vcvtw2ph, Vec, Mem) + ASMJIT_INST_3x(vdivph, Vdivph, Vec, Vec, Vec) + ASMJIT_INST_3x(vdivph, Vdivph, Vec, Vec, Mem) + ASMJIT_INST_3x(vdivsh, Vdivsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vdivsh, Vdivsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfcmaddcph, Vfcmaddcph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfcmaddcph, Vfcmaddcph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfcmaddcsh, Vfcmaddcsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfcmaddcsh, Vfcmaddcsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfcmulcph, Vfcmulcph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfcmulcph, Vfcmulcph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfcmulcsh, Vfcmulcsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfcmulcsh, Vfcmulcsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd132ph, Vfmadd132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd132ph, Vfmadd132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd132sh, Vfmadd132sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd132sh, Vfmadd132sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd213ph, Vfmadd213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd213ph, Vfmadd213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd213sh, Vfmadd213sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd213sh, Vfmadd213sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd231ph, Vfmadd231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd231ph, Vfmadd231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd231sh, Vfmadd231sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd231sh, Vfmadd231sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddcph, Vfmaddcph, 
Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddcph, Vfmaddcph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddcsh, Vfmaddcsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddcsh, Vfmaddcsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddsub132ph, Vfmaddsub132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddsub132ph, Vfmaddsub132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddsub213ph, Vfmaddsub213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddsub213ph, Vfmaddsub213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddsub231ph, Vfmaddsub231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddsub231ph, Vfmaddsub231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub132ph, Vfmsub132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub132ph, Vfmsub132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub132sh, Vfmsub132sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub132sh, Vfmsub132sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub213ph, Vfmsub213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub213ph, Vfmsub213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub213sh, Vfmsub213sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub213sh, Vfmsub213sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub231ph, Vfmsub231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub231ph, Vfmsub231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub231sh, Vfmsub231sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub231sh, Vfmsub231sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsubadd132ph, Vfmsubadd132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsubadd132ph, Vfmsubadd132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsubadd213ph, Vfmsubadd213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsubadd213ph, Vfmsubadd213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsubadd231ph, Vfmsubadd231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsubadd231ph, Vfmsubadd231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmulcph, Vfmulcph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmulcph, Vfmulcph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmulcsh, Vfmulcsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmulcsh, Vfmulcsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd132ph, Vfnmadd132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd132ph, Vfnmadd132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd132sh, 
Vfnmadd132sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd132sh, Vfnmadd132sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd213ph, Vfnmadd213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd213ph, Vfnmadd213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd213sh, Vfnmadd213sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd213sh, Vfnmadd213sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd231ph, Vfnmadd231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd231ph, Vfnmadd231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd231sh, Vfnmadd231sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd231sh, Vfnmadd231sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub132ph, Vfnmsub132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub132ph, Vfnmsub132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub132sh, Vfnmsub132sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub132sh, Vfnmsub132sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub213ph, Vfnmsub213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub213ph, Vfnmsub213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub213sh, Vfnmsub213sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub213sh, Vfnmsub213sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub231ph, Vfnmsub231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub231ph, Vfnmsub231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub231sh, Vfnmsub231sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub231sh, Vfnmsub231sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfpclassph, Vfpclassph, KReg, Vec, Imm) + ASMJIT_INST_3x(vfpclassph, Vfpclassph, KReg, Mem, Imm) + ASMJIT_INST_3x(vfpclasssh, Vfpclasssh, KReg, Vec, Imm) + ASMJIT_INST_3x(vfpclasssh, Vfpclasssh, KReg, Mem, Imm) + ASMJIT_INST_2x(vgetexpph, Vgetexpph, Vec, Vec) + ASMJIT_INST_2x(vgetexpph, Vgetexpph, Vec, Mem) + ASMJIT_INST_3x(vgetexpsh, Vgetexpsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vgetexpsh, Vgetexpsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vgetmantph, Vgetmantph, Vec, Vec, Imm) + ASMJIT_INST_3x(vgetmantph, Vgetmantph, Vec, Mem, Imm) + ASMJIT_INST_4x(vgetmantsh, Vgetmantsh, Vec, Vec, Vec, Imm) + ASMJIT_INST_4x(vgetmantsh, Vgetmantsh, Vec, Vec, Mem, Imm) + ASMJIT_INST_3x(vmaxph, Vmaxph, Vec, Vec, 
Vec) + ASMJIT_INST_3x(vmaxph, Vmaxph, Vec, Vec, Mem) + ASMJIT_INST_3x(vmaxsh, Vmaxsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vmaxsh, Vmaxsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vminph, Vminph, Vec, Vec, Vec) + ASMJIT_INST_3x(vminph, Vminph, Vec, Vec, Mem) + ASMJIT_INST_3x(vminsh, Vminsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vminsh, Vminsh, Vec, Vec, Mem) + ASMJIT_INST_2x(vmovsh, Vmovsh, Mem, Xmm) + ASMJIT_INST_2x(vmovsh, Vmovsh, Xmm, Mem) + ASMJIT_INST_3x(vmovsh, Vmovsh, Xmm, Xmm, Xmm) + ASMJIT_INST_2x(vmovw, Vmovw, Gp, Xmm) + ASMJIT_INST_2x(vmovw, Vmovw, Mem, Xmm) + ASMJIT_INST_2x(vmovw, Vmovw, Xmm, Gp) + ASMJIT_INST_2x(vmovw, Vmovw, Xmm, Mem) + ASMJIT_INST_3x(vmulph, Vmulph, Vec, Vec, Vec) + ASMJIT_INST_3x(vmulph, Vmulph, Vec, Vec, Mem) + ASMJIT_INST_3x(vmulsh, Vmulsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vmulsh, Vmulsh, Vec, Vec, Mem) + ASMJIT_INST_2x(vrcpph, Vrcpph, Vec, Vec) + ASMJIT_INST_2x(vrcpph, Vrcpph, Vec, Mem) + ASMJIT_INST_3x(vrcpsh, Vrcpsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vrcpsh, Vrcpsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vreduceph, Vreduceph, Vec, Vec, Imm) + ASMJIT_INST_3x(vreduceph, Vreduceph, Vec, Mem, Imm) + ASMJIT_INST_4x(vreducesh, Vreducesh, Vec, Vec, Vec, Imm) + ASMJIT_INST_4x(vreducesh, Vreducesh, Vec, Vec, Mem, Imm) + ASMJIT_INST_3x(vrndscaleph, Vrndscaleph, Vec, Vec, Imm) + ASMJIT_INST_3x(vrndscaleph, Vrndscaleph, Vec, Mem, Imm) + ASMJIT_INST_4x(vrndscalesh, Vrndscalesh, Vec, Vec, Vec, Imm) + ASMJIT_INST_4x(vrndscalesh, Vrndscalesh, Vec, Vec, Mem, Imm) + ASMJIT_INST_2x(vrsqrtph, Vrsqrtph, Vec, Vec) + ASMJIT_INST_2x(vrsqrtph, Vrsqrtph, Vec, Mem) + ASMJIT_INST_3x(vrsqrtsh, Vrsqrtsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vrsqrtsh, Vrsqrtsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vscalefph, Vscalefph, Vec, Vec, Vec) + ASMJIT_INST_3x(vscalefph, Vscalefph, Vec, Vec, Mem) + ASMJIT_INST_3x(vscalefsh, Vscalefsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vscalefsh, Vscalefsh, Vec, Vec, Mem) + ASMJIT_INST_2x(vsqrtph, Vsqrtph, Vec, Vec) + ASMJIT_INST_2x(vsqrtph, Vsqrtph, Vec, Mem) + 
ASMJIT_INST_3x(vsqrtsh, Vsqrtsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vsqrtsh, Vsqrtsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vsubph, Vsubph, Vec, Vec, Vec) + ASMJIT_INST_3x(vsubph, Vsubph, Vec, Vec, Mem) + ASMJIT_INST_3x(vsubsh, Vsubsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vsubsh, Vsubsh, Vec, Vec, Mem) + ASMJIT_INST_2x(vucomish, Vucomish, Vec, Vec) + ASMJIT_INST_2x(vucomish, Vucomish, Vec, Mem) + + //! \} + + //! \name AMX_TILE Instructions + //! \{ + + ASMJIT_INST_1x(ldtilecfg, Ldtilecfg, Mem) + ASMJIT_INST_1x(sttilecfg, Sttilecfg, Mem) + ASMJIT_INST_2x(tileloadd, Tileloadd, Tmm, Mem) + ASMJIT_INST_2x(tileloaddt1, Tileloaddt1, Tmm, Mem) + ASMJIT_INST_0x(tilerelease, Tilerelease) + ASMJIT_INST_2x(tilestored, Tilestored, Mem, Tmm) + ASMJIT_INST_1x(tilezero, Tilezero, Tmm) + + //! \} + + //! \name AMX_BF16 Instructions + //! \{ + + ASMJIT_INST_3x(tdpbf16ps, Tdpbf16ps, Tmm, Tmm, Tmm) + + //! \} + + //! \name AMX_COMPLEX Instructions + //! \{ + + ASMJIT_INST_3x(tcmmimfp16ps, Tcmmimfp16ps, Tmm, Tmm, Tmm) + ASMJIT_INST_3x(tcmmrlfp16ps, Tcmmrlfp16ps, Tmm, Tmm, Tmm) + + //! \} + + //! \name AMX_FP16 Instructions + //! \{ + + ASMJIT_INST_3x(tdpfp16ps, Tdpfp16ps, Tmm, Tmm, Tmm) + + //! \} + + //! \name AMX_INT8 Instructions + //! \{ + + ASMJIT_INST_3x(tdpbssd, Tdpbssd, Tmm, Tmm, Tmm) + ASMJIT_INST_3x(tdpbsud, Tdpbsud, Tmm, Tmm, Tmm) + ASMJIT_INST_3x(tdpbusd, Tdpbusd, Tmm, Tmm, Tmm) + ASMJIT_INST_3x(tdpbuud, Tdpbuud, Tmm, Tmm, Tmm) + + //! \} +}; + +//! Emitter (X86 - implicit). +template +struct EmitterImplicitT : public EmitterExplicitT { + //! \cond + using EmitterExplicitT::_emitter; + //! \endcond + + //! \name Prefix Options + //! \{ + + //! Use REP/REPE prefix. + inline This& rep() noexcept { return EmitterExplicitT::_addInstOptions(InstOptions::kX86_Rep); } + //! Use REP/REPE prefix. + inline This& repe() noexcept { return rep(); } + //! Use REP/REPE prefix. + inline This& repz() noexcept { return rep(); } + + //! Use REPNE prefix. 
+ inline This& repne() noexcept { return EmitterExplicitT::_addInstOptions(InstOptions::kX86_Repne); } + //! Use REPNE prefix. + inline This& repnz() noexcept { return repne(); } + + //! \} + + //! \name Core Instructions + //! \{ + + //! \cond + using EmitterExplicitT::cbw; + using EmitterExplicitT::cdq; + using EmitterExplicitT::cdqe; + using EmitterExplicitT::cqo; + using EmitterExplicitT::cwd; + using EmitterExplicitT::cwde; + using EmitterExplicitT::cmpsd; + using EmitterExplicitT::cmpxchg; + using EmitterExplicitT::cmpxchg8b; + using EmitterExplicitT::cmpxchg16b; + using EmitterExplicitT::div; + using EmitterExplicitT::idiv; + using EmitterExplicitT::imul; + using EmitterExplicitT::jecxz; + using EmitterExplicitT::loop; + using EmitterExplicitT::loope; + using EmitterExplicitT::loopne; + using EmitterExplicitT::mul; + //! \endcond + + ASMJIT_INST_0x(cbw, Cbw) // ANY [IMPLICIT] AX <- Sign Extend AL + ASMJIT_INST_0x(cdq, Cdq) // ANY [IMPLICIT] EDX:EAX <- Sign Extend EAX + ASMJIT_INST_0x(cdqe, Cdqe) // X64 [IMPLICIT] RAX <- Sign Extend EAX + ASMJIT_INST_2x(cmpxchg, Cmpxchg, Gp, Gp) // I486 [IMPLICIT] + ASMJIT_INST_2x(cmpxchg, Cmpxchg, Mem, Gp) // I486 [IMPLICIT] + ASMJIT_INST_1x(cmpxchg16b, Cmpxchg16b, Mem) // CMPXCHG8B [IMPLICIT] m == RDX:RAX ? m <- RCX:RBX + ASMJIT_INST_1x(cmpxchg8b, Cmpxchg8b, Mem) // CMPXCHG16B[IMPLICIT] m == EDX:EAX ? 
m <- ECX:EBX + ASMJIT_INST_0x(cqo, Cqo) // X64 [IMPLICIT] RDX:RAX <- Sign Extend RAX + ASMJIT_INST_0x(cwd, Cwd) // ANY [IMPLICIT] DX:AX <- Sign Extend AX + ASMJIT_INST_0x(cwde, Cwde) // ANY [IMPLICIT] EAX <- Sign Extend AX + ASMJIT_INST_1x(div, Div, Gp) // ANY [IMPLICIT] {AH[Rem]: AL[Quot] <- AX / r8} {xDX[Rem]:xAX[Quot] <- DX:AX / r16|r32|r64} + ASMJIT_INST_1x(div, Div, Mem) // ANY [IMPLICIT] {AH[Rem]: AL[Quot] <- AX / m8} {xDX[Rem]:xAX[Quot] <- DX:AX / m16|m32|m64} + ASMJIT_INST_1x(idiv, Idiv, Gp) // ANY [IMPLICIT] {AH[Rem]: AL[Quot] <- AX / r8} {xDX[Rem]:xAX[Quot] <- DX:AX / r16|r32|r64} + ASMJIT_INST_1x(idiv, Idiv, Mem) // ANY [IMPLICIT] {AH[Rem]: AL[Quot] <- AX / m8} {xDX[Rem]:xAX[Quot] <- DX:AX / m16|m32|m64} + ASMJIT_INST_1x(imul, Imul, Gp) // ANY [IMPLICIT] {AX <- AL * r8} {xAX:xDX <- xAX * r16|r32|r64} + ASMJIT_INST_1x(imul, Imul, Mem) // ANY [IMPLICIT] {AX <- AL * m8} {xAX:xDX <- xAX * m16|m32|m64} + ASMJIT_INST_0x(iret, Iret) // ANY [IMPLICIT] + ASMJIT_INST_0x(iretd, Iretd) // ANY [IMPLICIT] + ASMJIT_INST_0x(iretq, Iretq) // X64 [IMPLICIT] + ASMJIT_INST_1x(jecxz, Jecxz, Label) // ANY [IMPLICIT] Short jump if CX/ECX/RCX is zero. + ASMJIT_INST_1x(jecxz, Jecxz, Imm) // ANY [IMPLICIT] Short jump if CX/ECX/RCX is zero. + ASMJIT_INST_1x(loop, Loop, Label) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0. + ASMJIT_INST_1x(loop, Loop, Imm) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0. + ASMJIT_INST_1x(loope, Loope, Label) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 1. + ASMJIT_INST_1x(loope, Loope, Imm) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 1. + ASMJIT_INST_1x(loopne, Loopne, Label) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 0. + ASMJIT_INST_1x(loopne, Loopne, Imm) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 0. 
+ ASMJIT_INST_1x(mul, Mul, Gp) // ANY [IMPLICIT] {AX <- AL * r8} {xDX:xAX <- xAX * r16|r32|r64} + ASMJIT_INST_1x(mul, Mul, Mem) // ANY [IMPLICIT] {AX <- AL * m8} {xDX:xAX <- xAX * m16|m32|m64} + ASMJIT_INST_0x(ret, Ret) + ASMJIT_INST_1x(ret, Ret, Imm) + ASMJIT_INST_0x(retf, Retf) + ASMJIT_INST_1x(retf, Retf, Imm) + ASMJIT_INST_0x(xlatb, Xlatb) // ANY [IMPLICIT] + + //! \} + + //! \name String Instruction Aliases + //! \{ + + //! \cond + using EmitterExplicitT::movsd; + //! \endcond + + inline Error cmpsb() { return _emitter()->emit(Inst::kIdCmps, EmitterExplicitT::ptr_zsi(0, 1), EmitterExplicitT::ptr_zdi(0, 1)); } + inline Error cmpsd() { return _emitter()->emit(Inst::kIdCmps, EmitterExplicitT::ptr_zsi(0, 4), EmitterExplicitT::ptr_zdi(0, 4)); } + inline Error cmpsq() { return _emitter()->emit(Inst::kIdCmps, EmitterExplicitT::ptr_zsi(0, 8), EmitterExplicitT::ptr_zdi(0, 8)); } + inline Error cmpsw() { return _emitter()->emit(Inst::kIdCmps, EmitterExplicitT::ptr_zsi(0, 2), EmitterExplicitT::ptr_zdi(0, 2)); } + + inline Error lodsb() { return _emitter()->emit(Inst::kIdLods, al , EmitterExplicitT::ptr_zsi(0, 1)); } + inline Error lodsd() { return _emitter()->emit(Inst::kIdLods, eax, EmitterExplicitT::ptr_zsi(0, 4)); } + inline Error lodsq() { return _emitter()->emit(Inst::kIdLods, rax, EmitterExplicitT::ptr_zsi(0, 8)); } + inline Error lodsw() { return _emitter()->emit(Inst::kIdLods, ax , EmitterExplicitT::ptr_zsi(0, 2)); } + + inline Error movsb() { return _emitter()->emit(Inst::kIdMovs, EmitterExplicitT::ptr_zdi(0, 1), EmitterExplicitT::ptr_zsi(0, 1)); } + inline Error movsd() { return _emitter()->emit(Inst::kIdMovs, EmitterExplicitT::ptr_zdi(0, 4), EmitterExplicitT::ptr_zsi(0, 4)); } + inline Error movsq() { return _emitter()->emit(Inst::kIdMovs, EmitterExplicitT::ptr_zdi(0, 8), EmitterExplicitT::ptr_zsi(0, 8)); } + inline Error movsw() { return _emitter()->emit(Inst::kIdMovs, EmitterExplicitT::ptr_zdi(0, 2), EmitterExplicitT::ptr_zsi(0, 2)); } + + inline Error 
scasb() { return _emitter()->emit(Inst::kIdScas, al , EmitterExplicitT::ptr_zdi(0, 1)); } + inline Error scasd() { return _emitter()->emit(Inst::kIdScas, eax, EmitterExplicitT::ptr_zdi(0, 4)); } + inline Error scasq() { return _emitter()->emit(Inst::kIdScas, rax, EmitterExplicitT::ptr_zdi(0, 8)); } + inline Error scasw() { return _emitter()->emit(Inst::kIdScas, ax , EmitterExplicitT::ptr_zdi(0, 2)); } + + inline Error stosb() { return _emitter()->emit(Inst::kIdStos, EmitterExplicitT::ptr_zdi(0, 1), al ); } + inline Error stosd() { return _emitter()->emit(Inst::kIdStos, EmitterExplicitT::ptr_zdi(0, 4), eax); } + inline Error stosq() { return _emitter()->emit(Inst::kIdStos, EmitterExplicitT::ptr_zdi(0, 8), rax); } + inline Error stosw() { return _emitter()->emit(Inst::kIdStos, EmitterExplicitT::ptr_zdi(0, 2), ax ); } + + //! \} + + //! \name Deprecated 32-bit Instructions + //! \{ + + //! \cond + using EmitterExplicitT::aaa; + using EmitterExplicitT::aad; + using EmitterExplicitT::aam; + using EmitterExplicitT::aas; + using EmitterExplicitT::daa; + using EmitterExplicitT::das; + //! \endcond + + ASMJIT_INST_0x(aaa, Aaa) // X86 [IMPLICIT] + ASMJIT_INST_1x(aad, Aad, Imm) // X86 [IMPLICIT] + ASMJIT_INST_1x(aam, Aam, Imm) // X86 [IMPLICIT] + ASMJIT_INST_0x(aas, Aas) // X86 [IMPLICIT] + ASMJIT_INST_0x(daa, Daa) // X86 [IMPLICIT] + ASMJIT_INST_0x(das, Das) // X86 [IMPLICIT] + + //! \} + + //! \name LAHF/SAHF Instructions + //! \{ + + //! \cond + using EmitterExplicitT::lahf; + using EmitterExplicitT::sahf; + //! \endcond + + ASMJIT_INST_0x(lahf, Lahf) // LAHFSAHF [IMPLICIT] AH <- EFL + ASMJIT_INST_0x(sahf, Sahf) // LAHFSAHF [IMPLICIT] EFL <- AH + + //! \} + + //! \name CPUID Instruction + //! \{ + + //! \cond + using EmitterExplicitT::cpuid; + //! \endcond + + ASMJIT_INST_0x(cpuid, Cpuid) // I486 [IMPLICIT] EAX:EBX:ECX:EDX <- CPUID[EAX:ECX] + + //! \} + + //! \name CacheLine Instructions + //! \{ + + //! \cond + using EmitterExplicitT::clzero; + //! 
\endcond + + ASMJIT_INST_0x(clzero, Clzero) // CLZERO [IMPLICIT] + + //! \} + + //! \name RDPRU/RDPKRU Instructions + //! \{ + + //! \cond + using EmitterExplicitT::rdpru; + using EmitterExplicitT::rdpkru; + //! \endcond + + ASMJIT_INST_0x(rdpru, Rdpru) // RDPRU [IMPLICIT] EDX:EAX <- PRU[ECX] + ASMJIT_INST_0x(rdpkru, Rdpkru) // RDPKRU [IMPLICIT] EDX:EAX <- PKRU[ECX] + + //! \} + + //! \name RDTSC/RDTSCP Instructions + //! \{ + + //! \cond + using EmitterExplicitT::rdtsc; + using EmitterExplicitT::rdtscp; + //! \endcond + + ASMJIT_INST_0x(rdtsc, Rdtsc) // RDTSC [IMPLICIT] EDX:EAX <- CNT + ASMJIT_INST_0x(rdtscp, Rdtscp) // RDTSCP [IMPLICIT] EDX:EAX:EXC <- CNT + + //! \} + + //! \name BMI2 Instructions + //! \{ + + //! \cond + using EmitterExplicitT::mulx; + //! \endcond + + ASMJIT_INST_3x(mulx, Mulx, Gp, Gp, Gp) // BMI2 [IMPLICIT] + ASMJIT_INST_3x(mulx, Mulx, Gp, Gp, Mem) // BMI2 [IMPLICIT] + + //! \} + + //! \name XSAVE Instructions + //! \{ + + //! \cond + using EmitterExplicitT::xgetbv; + using EmitterExplicitT::xrstor; + using EmitterExplicitT::xrstor64; + using EmitterExplicitT::xrstors; + using EmitterExplicitT::xrstors64; + using EmitterExplicitT::xsave; + using EmitterExplicitT::xsave64; + using EmitterExplicitT::xsavec; + using EmitterExplicitT::xsavec64; + using EmitterExplicitT::xsaveopt; + using EmitterExplicitT::xsaveopt64; + using EmitterExplicitT::xsaves; + using EmitterExplicitT::xsaves64; + //! 
\endcond + + ASMJIT_INST_0x(xgetbv, Xgetbv) // XSAVE [IMPLICIT] EDX:EAX <- XCR[ECX] + ASMJIT_INST_1x(xrstor, Xrstor, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xrstor64, Xrstor64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xrstors, Xrstors, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xrstors64, Xrstors64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xsave, Xsave, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xsave64, Xsave64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xsavec, Xsavec, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xsavec64, Xsavec64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xsaveopt, Xsaveopt, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xsaveopt64, Xsaveopt64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xsaves, Xsaves, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xsaves64, Xsaves64, Mem) // XSAVE+X64 [IMPLICIT] + + //! \} + + //! \name SYSCALL/SYSENTER Instructions + //! \{ + + ASMJIT_INST_0x(syscall, Syscall) // X64 [IMPLICIT] + ASMJIT_INST_0x(sysenter, Sysenter) // X64 [IMPLICIT] + + //! \} + + //! \name HRESET Instructions + //! \{ + + //! \cond + using EmitterExplicitT::hreset; + //! \endcond + + ASMJIT_INST_1x(hreset, Hreset, Imm) // HRESET [IMPLICIT] + + //! \} + + //! \name SEAM Instructions + //! \{ + + ASMJIT_INST_0x(seamcall, Seamcall) + ASMJIT_INST_0x(seamops, Seamops) + ASMJIT_INST_0x(seamret, Seamret) + ASMJIT_INST_0x(tdcall, Tdcall) + + //! \} + + //! \name Privileged Instructions + //! \{ + + //! \cond + using EmitterExplicitT::rdmsr; + using EmitterExplicitT::rdpmc; + using EmitterExplicitT::wrmsr; + using EmitterExplicitT::xsetbv; + //! 
\endcond + + ASMJIT_INST_0x(pconfig, Pconfig) // PCONFIG [IMPLICIT] + ASMJIT_INST_0x(rdmsr, Rdmsr) // ANY [IMPLICIT] + ASMJIT_INST_0x(rdpmc, Rdpmc) // ANY [IMPLICIT] + ASMJIT_INST_0x(sysexit, Sysexit) // X64 [IMPLICIT] + ASMJIT_INST_0x(sysexitq, Sysexitq) // X64 [IMPLICIT] + ASMJIT_INST_0x(sysret, Sysret) // X64 [IMPLICIT] + ASMJIT_INST_0x(sysretq, Sysretq) // X64 [IMPLICIT] + ASMJIT_INST_0x(wrmsr, Wrmsr) // ANY [IMPLICIT] + ASMJIT_INST_0x(xsetbv, Xsetbv) // XSAVE [IMPLICIT] XCR[ECX] <- EDX:EAX + + //! \} + + //! \name Monitor & MWait Instructions + //! \{ + + //! \cond + using EmitterExplicitT::monitor; + using EmitterExplicitT::monitorx; + using EmitterExplicitT::mwait; + using EmitterExplicitT::mwaitx; + //! \endcond + + ASMJIT_INST_0x(monitor, Monitor) + ASMJIT_INST_0x(monitorx, Monitorx) + ASMJIT_INST_0x(mwait, Mwait) + ASMJIT_INST_0x(mwaitx, Mwaitx) + + //! \} + + //! \name WAITPKG Instructions + //! \{ + + //! \cond + using EmitterExplicitT::tpause; + using EmitterExplicitT::umwait; + //! \endcond + + ASMJIT_INST_1x(tpause, Tpause, Gp) + ASMJIT_INST_1x(umwait, Umwait, Gp) + + //! \} + + //! \name MMX & SSE Instructions + //! \{ + + //! \cond + using EmitterExplicitT::blendvpd; + using EmitterExplicitT::blendvps; + using EmitterExplicitT::maskmovq; + using EmitterExplicitT::maskmovdqu; + using EmitterExplicitT::pblendvb; + using EmitterExplicitT::pcmpestri; + using EmitterExplicitT::pcmpestrm; + using EmitterExplicitT::pcmpistri; + using EmitterExplicitT::pcmpistrm; + //! 
\endcond + + ASMJIT_INST_2x(blendvpd, Blendvpd, Xmm, Xmm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(blendvpd, Blendvpd, Xmm, Mem) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(blendvps, Blendvps, Xmm, Xmm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(blendvps, Blendvps, Xmm, Mem) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(pblendvb, Pblendvb, Xmm, Xmm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(pblendvb, Pblendvb, Xmm, Mem) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(maskmovq, Maskmovq, Mm, Mm) // SSE [IMPLICIT] + ASMJIT_INST_2x(maskmovdqu, Maskmovdqu, Xmm, Xmm) // SSE2 [IMPLICIT] + ASMJIT_INST_3x(pcmpestri, Pcmpestri, Xmm, Xmm, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpestri, Pcmpestri, Xmm, Mem, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpestrm, Pcmpestrm, Xmm, Xmm, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpestrm, Pcmpestrm, Xmm, Mem, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpistri, Pcmpistri, Xmm, Xmm, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpistri, Pcmpistri, Xmm, Mem, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpistrm, Pcmpistrm, Xmm, Xmm, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpistrm, Pcmpistrm, Xmm, Mem, Imm) // SSE4_1 [IMPLICIT] + + //! \} + + //! \name SHA Instructions + //! \{ + + //! \cond + using EmitterExplicitT::sha256rnds2; + //! \endcond + + ASMJIT_INST_2x(sha256rnds2, Sha256rnds2, Xmm, Xmm) // SHA [IMPLICIT] + ASMJIT_INST_2x(sha256rnds2, Sha256rnds2, Xmm, Mem) // SHA [IMPLICIT] + + //! \} + + //! \name AVX, FMA, and AVX512 Instructions + //! \{ + + //! \cond + using EmitterExplicitT::vmaskmovdqu; + using EmitterExplicitT::vpcmpestri; + using EmitterExplicitT::vpcmpestrm; + using EmitterExplicitT::vpcmpistri; + using EmitterExplicitT::vpcmpistrm; + //! 
\endcond + + ASMJIT_INST_2x(vmaskmovdqu, Vmaskmovdqu, Xmm, Xmm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpestri, Vpcmpestri, Xmm, Xmm, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpestri, Vpcmpestri, Xmm, Mem, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpestrm, Vpcmpestrm, Xmm, Xmm, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpestrm, Vpcmpestrm, Xmm, Mem, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpistri, Vpcmpistri, Xmm, Xmm, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpistri, Vpcmpistri, Xmm, Mem, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpistrm, Vpcmpistrm, Xmm, Xmm, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpistrm, Vpcmpistrm, Xmm, Mem, Imm) // AVX [IMPLICIT] + + //! \} +}; + +//! Emitter (X86). +//! +//! \note This class cannot be instantiated, you can only cast to it and use it as emitter that emits to either +//! `x86::Assembler`, `x86::Builder`, or `x86::Compiler` (use with caution with `x86::Compiler` as it requires +//! virtual registers). +class Emitter : public BaseEmitter, public EmitterImplicitT { + ASMJIT_NONCONSTRUCTIBLE(Emitter) +}; + +//! 
\} + +#undef ASMJIT_INST_0x +#undef ASMJIT_INST_1x +#undef ASMJIT_INST_1c +#undef ASMJIT_INST_2x +#undef ASMJIT_INST_2c +#undef ASMJIT_INST_3x +#undef ASMJIT_INST_4x +#undef ASMJIT_INST_5x +#undef ASMJIT_INST_6x + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_X86_X86EMITTER_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86globals.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86globals.h new file mode 100644 index 0000000000000000000000000000000000000000..21ed41e87d13aeb0a20367463ed6e97cd53352ae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86globals.h @@ -0,0 +1,2234 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86GLOBALS_H_INCLUDED +#define ASMJIT_X86_X86GLOBALS_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/inst.h" + +//! \namespace asmjit::x86 +//! \ingroup asmjit_x86 +//! +//! X86/X64 API. + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! Condition code. 
+enum class CondCode : uint8_t { + kO = 0x00u, //!< OF==1 + kNO = 0x01u, //!< OF==0 + kC = 0x02u, //!< CF==1 + kB = 0x02u, //!< CF==1 (unsigned < ) + kNAE = 0x02u, //!< CF==1 (unsigned < ) + kNC = 0x03u, //!< CF==0 + kAE = 0x03u, //!< CF==0 (unsigned >=) + kNB = 0x03u, //!< CF==0 (unsigned >=) + kE = 0x04u, //!< ZF==1 (any_sign ==) + kZ = 0x04u, //!< ZF==1 (any_sign ==) + kNE = 0x05u, //!< ZF==0 (any_sign !=) + kNZ = 0x05u, //!< ZF==0 (any_sign !=) + kBE = 0x06u, //!< CF==1 | ZF==1 (unsigned <=) + kNA = 0x06u, //!< CF==1 | ZF==1 (unsigned <=) + kA = 0x07u, //!< CF==0 & ZF==0 (unsigned > ) + kNBE = 0x07u, //!< CF==0 & ZF==0 (unsigned > ) + kS = 0x08u, //!< SF==1 (is negative) + kNS = 0x09u, //!< SF==0 (is positive or zero) + kP = 0x0Au, //!< PF==1 + kPE = 0x0Au, //!< PF==1 + kPO = 0x0Bu, //!< PF==0 + kNP = 0x0Bu, //!< PF==0 + kL = 0x0Cu, //!< SF!=OF (signed < ) + kNGE = 0x0Cu, //!< SF!=OF (signed < ) + kGE = 0x0Du, //!< SF==OF (signed >=) + kNL = 0x0Du, //!< SF==OF (signed >=) + kLE = 0x0Eu, //!< ZF==1 | SF!=OF (signed <=) + kNG = 0x0Eu, //!< ZF==1 | SF!=OF (signed <=) + kG = 0x0Fu, //!< ZF==0 & SF==OF (signed > ) + kNLE = 0x0Fu, //!< ZF==0 & SF==OF (signed > ) + + kZero = kZ, //!< Zero flag. + kNotZero = kNZ, //!< Not zero. + + kEqual = kE, //!< `a == b` (equal). + kNotEqual = kNE, //!< `a != b` (not equal). + + kCarry = kC, //!< Carry flag. + kNotCarry = kNC, //!< Not carry. + + kSign = kS, //!< Sign flag. + kNotSign = kNS, //!< Not sign. + + kNegative = kS, //!< Sign flag. + kPositive = kNS, //!< Not sign. + + kOverflow = kO, //!< Overflow (signed). + kNotOverflow = kNO, //!< Not overflow (signed). + + kSignedLT = kL, //!< `a < b` (signed). + kSignedLE = kLE, //!< `a <= b` (signed). + kSignedGT = kG, //!< `a > b` (signed). + kSignedGE = kGE, //!< `a >= b` (signed). + + kUnsignedLT = kB, //!< `a < b` (unsigned). + kUnsignedLE = kBE, //!< `a <= b` (unsigned). + kUnsignedGT = kA, //!< `a > b` (unsigned). + kUnsignedGE = kAE, //!< `a >= b` (unsigned). 
+ + kBTZero = kNC, //!< Tested bit is zero. + kBTNotZero = kC, //!< Tested bit is non-zero. + + kParityEven = kP, //!< Even parity flag. + kParityOdd = kPO, //!< Odd parity flag. + + kMaxValue = 0x0Fu +}; + +//! \cond +static constexpr CondCode _reverseCondTable[] = { + CondCode::kO, // O <- O + CondCode::kNO, // NO <- NO + CondCode::kA , // A <- B + CondCode::kBE, // BE <- AE + CondCode::kE, // E <- E + CondCode::kNE, // NE <- NE + CondCode::kAE, // AE <- BE + CondCode::kB , // B <- A + CondCode::kS, // S <- S + CondCode::kNS, // NS <- NS + CondCode::kPE, // PE <- PE + CondCode::kPO, // PO <- PO + CondCode::kG, // G <- L + CondCode::kLE, // LE <- GE + CondCode::kGE, // GE <- LE + CondCode::kL // L <- G +}; +//! \endcond + +//! Reverses a condition code (reverses the corresponding operands of a comparison). +static ASMJIT_INLINE_NODEBUG constexpr CondCode reverseCond(CondCode cond) noexcept { return _reverseCondTable[uint8_t(cond)]; } +//! Negates a condition code. +static ASMJIT_INLINE_NODEBUG constexpr CondCode negateCond(CondCode cond) noexcept { return CondCode(uint8_t(cond) ^ 1u); } + +//! Instruction. +//! +//! \note Only used to hold x86-specific instruction identifiers and some additional helper functions. +namespace Inst { + //! Instruction id. + enum Id : uint32_t { + // ${InstId:Begin} + kIdNone = 0, //!< Invalid instruction id. + kIdAaa, //!< Instruction 'aaa' (X86). + kIdAad, //!< Instruction 'aad' (X86). + kIdAadd, //!< Instruction 'aadd' {RAO_INT}. + kIdAam, //!< Instruction 'aam' (X86). + kIdAand, //!< Instruction 'aand' {RAO_INT}. + kIdAas, //!< Instruction 'aas' (X86). + kIdAdc, //!< Instruction 'adc'. + kIdAdcx, //!< Instruction 'adcx' {ADX}. + kIdAdd, //!< Instruction 'add'. + kIdAddpd, //!< Instruction 'addpd' {SSE2}. + kIdAddps, //!< Instruction 'addps' {SSE}. + kIdAddsd, //!< Instruction 'addsd' {SSE2}. + kIdAddss, //!< Instruction 'addss' {SSE}. + kIdAddsubpd, //!< Instruction 'addsubpd' {SSE3}. 
+ kIdAddsubps, //!< Instruction 'addsubps' {SSE3}. + kIdAdox, //!< Instruction 'adox' {ADX}. + kIdAesdec, //!< Instruction 'aesdec' {AESNI}. + kIdAesdeclast, //!< Instruction 'aesdeclast' {AESNI}. + kIdAesenc, //!< Instruction 'aesenc' {AESNI}. + kIdAesenclast, //!< Instruction 'aesenclast' {AESNI}. + kIdAesimc, //!< Instruction 'aesimc' {AESNI}. + kIdAeskeygenassist, //!< Instruction 'aeskeygenassist' {AESNI}. + kIdAnd, //!< Instruction 'and'. + kIdAndn, //!< Instruction 'andn' {BMI}. + kIdAndnpd, //!< Instruction 'andnpd' {SSE2}. + kIdAndnps, //!< Instruction 'andnps' {SSE}. + kIdAndpd, //!< Instruction 'andpd' {SSE2}. + kIdAndps, //!< Instruction 'andps' {SSE}. + kIdAor, //!< Instruction 'aor' {RAO_INT}. + kIdArpl, //!< Instruction 'arpl' (X86). + kIdAxor, //!< Instruction 'axor' {RAO_INT}. + kIdBextr, //!< Instruction 'bextr' {BMI}. + kIdBlcfill, //!< Instruction 'blcfill' {TBM}. + kIdBlci, //!< Instruction 'blci' {TBM}. + kIdBlcic, //!< Instruction 'blcic' {TBM}. + kIdBlcmsk, //!< Instruction 'blcmsk' {TBM}. + kIdBlcs, //!< Instruction 'blcs' {TBM}. + kIdBlendpd, //!< Instruction 'blendpd' {SSE4_1}. + kIdBlendps, //!< Instruction 'blendps' {SSE4_1}. + kIdBlendvpd, //!< Instruction 'blendvpd' {SSE4_1}. + kIdBlendvps, //!< Instruction 'blendvps' {SSE4_1}. + kIdBlsfill, //!< Instruction 'blsfill' {TBM}. + kIdBlsi, //!< Instruction 'blsi' {BMI}. + kIdBlsic, //!< Instruction 'blsic' {TBM}. + kIdBlsmsk, //!< Instruction 'blsmsk' {BMI}. + kIdBlsr, //!< Instruction 'blsr' {BMI}. + kIdBndcl, //!< Instruction 'bndcl' {MPX}. + kIdBndcn, //!< Instruction 'bndcn' {MPX}. + kIdBndcu, //!< Instruction 'bndcu' {MPX}. + kIdBndldx, //!< Instruction 'bndldx' {MPX}. + kIdBndmk, //!< Instruction 'bndmk' {MPX}. + kIdBndmov, //!< Instruction 'bndmov' {MPX}. + kIdBndstx, //!< Instruction 'bndstx' {MPX}. + kIdBound, //!< Instruction 'bound' (X86). + kIdBsf, //!< Instruction 'bsf'. + kIdBsr, //!< Instruction 'bsr'. + kIdBswap, //!< Instruction 'bswap'. + kIdBt, //!< Instruction 'bt'. 
+ kIdBtc, //!< Instruction 'btc'. + kIdBtr, //!< Instruction 'btr'. + kIdBts, //!< Instruction 'bts'. + kIdBzhi, //!< Instruction 'bzhi' {BMI2}. + kIdCall, //!< Instruction 'call'. + kIdCbw, //!< Instruction 'cbw'. + kIdCdq, //!< Instruction 'cdq'. + kIdCdqe, //!< Instruction 'cdqe' (X64). + kIdClac, //!< Instruction 'clac' {SMAP}. + kIdClc, //!< Instruction 'clc'. + kIdCld, //!< Instruction 'cld'. + kIdCldemote, //!< Instruction 'cldemote' {CLDEMOTE}. + kIdClflush, //!< Instruction 'clflush' {CLFLUSH}. + kIdClflushopt, //!< Instruction 'clflushopt' {CLFLUSHOPT}. + kIdClgi, //!< Instruction 'clgi' {SVM}. + kIdCli, //!< Instruction 'cli'. + kIdClrssbsy, //!< Instruction 'clrssbsy' {CET_SS}. + kIdClts, //!< Instruction 'clts'. + kIdClui, //!< Instruction 'clui' {UINTR} (X64). + kIdClwb, //!< Instruction 'clwb' {CLWB}. + kIdClzero, //!< Instruction 'clzero' {CLZERO}. + kIdCmc, //!< Instruction 'cmc'. + kIdCmova, //!< Instruction 'cmova' {CMOV}. + kIdCmovae, //!< Instruction 'cmovae' {CMOV}. + kIdCmovb, //!< Instruction 'cmovb' {CMOV}. + kIdCmovbe, //!< Instruction 'cmovbe' {CMOV}. + kIdCmovc, //!< Instruction 'cmovc' {CMOV}. + kIdCmove, //!< Instruction 'cmove' {CMOV}. + kIdCmovg, //!< Instruction 'cmovg' {CMOV}. + kIdCmovge, //!< Instruction 'cmovge' {CMOV}. + kIdCmovl, //!< Instruction 'cmovl' {CMOV}. + kIdCmovle, //!< Instruction 'cmovle' {CMOV}. + kIdCmovna, //!< Instruction 'cmovna' {CMOV}. + kIdCmovnae, //!< Instruction 'cmovnae' {CMOV}. + kIdCmovnb, //!< Instruction 'cmovnb' {CMOV}. + kIdCmovnbe, //!< Instruction 'cmovnbe' {CMOV}. + kIdCmovnc, //!< Instruction 'cmovnc' {CMOV}. + kIdCmovne, //!< Instruction 'cmovne' {CMOV}. + kIdCmovng, //!< Instruction 'cmovng' {CMOV}. + kIdCmovnge, //!< Instruction 'cmovnge' {CMOV}. + kIdCmovnl, //!< Instruction 'cmovnl' {CMOV}. + kIdCmovnle, //!< Instruction 'cmovnle' {CMOV}. + kIdCmovno, //!< Instruction 'cmovno' {CMOV}. + kIdCmovnp, //!< Instruction 'cmovnp' {CMOV}. + kIdCmovns, //!< Instruction 'cmovns' {CMOV}. 
+ kIdCmovnz, //!< Instruction 'cmovnz' {CMOV}. + kIdCmovo, //!< Instruction 'cmovo' {CMOV}. + kIdCmovp, //!< Instruction 'cmovp' {CMOV}. + kIdCmovpe, //!< Instruction 'cmovpe' {CMOV}. + kIdCmovpo, //!< Instruction 'cmovpo' {CMOV}. + kIdCmovs, //!< Instruction 'cmovs' {CMOV}. + kIdCmovz, //!< Instruction 'cmovz' {CMOV}. + kIdCmp, //!< Instruction 'cmp'. + kIdCmpbexadd, //!< Instruction 'cmpbexadd' {CMPCCXADD}. + kIdCmpbxadd, //!< Instruction 'cmpbxadd' {CMPCCXADD}. + kIdCmplexadd, //!< Instruction 'cmplexadd' {CMPCCXADD}. + kIdCmplxadd, //!< Instruction 'cmplxadd' {CMPCCXADD}. + kIdCmpnbexadd, //!< Instruction 'cmpnbexadd' {CMPCCXADD}. + kIdCmpnbxadd, //!< Instruction 'cmpnbxadd' {CMPCCXADD}. + kIdCmpnlexadd, //!< Instruction 'cmpnlexadd' {CMPCCXADD}. + kIdCmpnlxadd, //!< Instruction 'cmpnlxadd' {CMPCCXADD}. + kIdCmpnoxadd, //!< Instruction 'cmpnoxadd' {CMPCCXADD}. + kIdCmpnpxadd, //!< Instruction 'cmpnpxadd' {CMPCCXADD}. + kIdCmpnsxadd, //!< Instruction 'cmpnsxadd' {CMPCCXADD}. + kIdCmpnzxadd, //!< Instruction 'cmpnzxadd' {CMPCCXADD}. + kIdCmpoxadd, //!< Instruction 'cmpoxadd' {CMPCCXADD}. + kIdCmppd, //!< Instruction 'cmppd' {SSE2}. + kIdCmpps, //!< Instruction 'cmpps' {SSE}. + kIdCmppxadd, //!< Instruction 'cmppxadd' {CMPCCXADD}. + kIdCmps, //!< Instruction 'cmps'. + kIdCmpsd, //!< Instruction 'cmpsd' {SSE2}. + kIdCmpss, //!< Instruction 'cmpss' {SSE}. + kIdCmpsxadd, //!< Instruction 'cmpsxadd' {CMPCCXADD}. + kIdCmpxchg, //!< Instruction 'cmpxchg' {I486}. + kIdCmpxchg16b, //!< Instruction 'cmpxchg16b' {CMPXCHG16B} (X64). + kIdCmpxchg8b, //!< Instruction 'cmpxchg8b' {CMPXCHG8B}. + kIdCmpzxadd, //!< Instruction 'cmpzxadd' {CMPCCXADD}. + kIdComisd, //!< Instruction 'comisd' {SSE2}. + kIdComiss, //!< Instruction 'comiss' {SSE}. + kIdCpuid, //!< Instruction 'cpuid' {I486}. + kIdCqo, //!< Instruction 'cqo' (X64). + kIdCrc32, //!< Instruction 'crc32' {SSE4_2}. + kIdCvtdq2pd, //!< Instruction 'cvtdq2pd' {SSE2}. + kIdCvtdq2ps, //!< Instruction 'cvtdq2ps' {SSE2}. 
+ kIdCvtpd2dq, //!< Instruction 'cvtpd2dq' {SSE2}. + kIdCvtpd2pi, //!< Instruction 'cvtpd2pi' {SSE2}. + kIdCvtpd2ps, //!< Instruction 'cvtpd2ps' {SSE2}. + kIdCvtpi2pd, //!< Instruction 'cvtpi2pd' {SSE2}. + kIdCvtpi2ps, //!< Instruction 'cvtpi2ps' {SSE}. + kIdCvtps2dq, //!< Instruction 'cvtps2dq' {SSE2}. + kIdCvtps2pd, //!< Instruction 'cvtps2pd' {SSE2}. + kIdCvtps2pi, //!< Instruction 'cvtps2pi' {SSE}. + kIdCvtsd2si, //!< Instruction 'cvtsd2si' {SSE2}. + kIdCvtsd2ss, //!< Instruction 'cvtsd2ss' {SSE2}. + kIdCvtsi2sd, //!< Instruction 'cvtsi2sd' {SSE2}. + kIdCvtsi2ss, //!< Instruction 'cvtsi2ss' {SSE}. + kIdCvtss2sd, //!< Instruction 'cvtss2sd' {SSE2}. + kIdCvtss2si, //!< Instruction 'cvtss2si' {SSE}. + kIdCvttpd2dq, //!< Instruction 'cvttpd2dq' {SSE2}. + kIdCvttpd2pi, //!< Instruction 'cvttpd2pi' {SSE2}. + kIdCvttps2dq, //!< Instruction 'cvttps2dq' {SSE2}. + kIdCvttps2pi, //!< Instruction 'cvttps2pi' {SSE}. + kIdCvttsd2si, //!< Instruction 'cvttsd2si' {SSE2}. + kIdCvttss2si, //!< Instruction 'cvttss2si' {SSE}. + kIdCwd, //!< Instruction 'cwd'. + kIdCwde, //!< Instruction 'cwde'. + kIdDaa, //!< Instruction 'daa' (X86). + kIdDas, //!< Instruction 'das' (X86). + kIdDec, //!< Instruction 'dec'. + kIdDiv, //!< Instruction 'div'. + kIdDivpd, //!< Instruction 'divpd' {SSE2}. + kIdDivps, //!< Instruction 'divps' {SSE}. + kIdDivsd, //!< Instruction 'divsd' {SSE2}. + kIdDivss, //!< Instruction 'divss' {SSE}. + kIdDppd, //!< Instruction 'dppd' {SSE4_1}. + kIdDpps, //!< Instruction 'dpps' {SSE4_1}. + kIdEmms, //!< Instruction 'emms' {MMX}. + kIdEndbr32, //!< Instruction 'endbr32' {CET_IBT}. + kIdEndbr64, //!< Instruction 'endbr64' {CET_IBT}. + kIdEnqcmd, //!< Instruction 'enqcmd' {ENQCMD}. + kIdEnqcmds, //!< Instruction 'enqcmds' {ENQCMD}. + kIdEnter, //!< Instruction 'enter'. + kIdExtractps, //!< Instruction 'extractps' {SSE4_1}. + kIdExtrq, //!< Instruction 'extrq' {SSE4A}. + kIdF2xm1, //!< Instruction 'f2xm1' {FPU}. + kIdFabs, //!< Instruction 'fabs' {FPU}. 
+ kIdFadd, //!< Instruction 'fadd' {FPU}. + kIdFaddp, //!< Instruction 'faddp' {FPU}. + kIdFbld, //!< Instruction 'fbld' {FPU}. + kIdFbstp, //!< Instruction 'fbstp' {FPU}. + kIdFchs, //!< Instruction 'fchs' {FPU}. + kIdFclex, //!< Instruction 'fclex' {FPU}. + kIdFcmovb, //!< Instruction 'fcmovb' {CMOV|FPU}. + kIdFcmovbe, //!< Instruction 'fcmovbe' {CMOV|FPU}. + kIdFcmove, //!< Instruction 'fcmove' {CMOV|FPU}. + kIdFcmovnb, //!< Instruction 'fcmovnb' {CMOV|FPU}. + kIdFcmovnbe, //!< Instruction 'fcmovnbe' {CMOV|FPU}. + kIdFcmovne, //!< Instruction 'fcmovne' {CMOV|FPU}. + kIdFcmovnu, //!< Instruction 'fcmovnu' {CMOV|FPU}. + kIdFcmovu, //!< Instruction 'fcmovu' {CMOV|FPU}. + kIdFcom, //!< Instruction 'fcom' {FPU}. + kIdFcomi, //!< Instruction 'fcomi' {FPU}. + kIdFcomip, //!< Instruction 'fcomip' {FPU}. + kIdFcomp, //!< Instruction 'fcomp' {FPU}. + kIdFcompp, //!< Instruction 'fcompp' {FPU}. + kIdFcos, //!< Instruction 'fcos' {FPU}. + kIdFdecstp, //!< Instruction 'fdecstp' {FPU}. + kIdFdiv, //!< Instruction 'fdiv' {FPU}. + kIdFdivp, //!< Instruction 'fdivp' {FPU}. + kIdFdivr, //!< Instruction 'fdivr' {FPU}. + kIdFdivrp, //!< Instruction 'fdivrp' {FPU}. + kIdFemms, //!< Instruction 'femms' {3DNOW}. + kIdFfree, //!< Instruction 'ffree' {FPU}. + kIdFiadd, //!< Instruction 'fiadd' {FPU}. + kIdFicom, //!< Instruction 'ficom' {FPU}. + kIdFicomp, //!< Instruction 'ficomp' {FPU}. + kIdFidiv, //!< Instruction 'fidiv' {FPU}. + kIdFidivr, //!< Instruction 'fidivr' {FPU}. + kIdFild, //!< Instruction 'fild' {FPU}. + kIdFimul, //!< Instruction 'fimul' {FPU}. + kIdFincstp, //!< Instruction 'fincstp' {FPU}. + kIdFinit, //!< Instruction 'finit' {FPU}. + kIdFist, //!< Instruction 'fist' {FPU}. + kIdFistp, //!< Instruction 'fistp' {FPU}. + kIdFisttp, //!< Instruction 'fisttp' {SSE3|FPU}. + kIdFisub, //!< Instruction 'fisub' {FPU}. + kIdFisubr, //!< Instruction 'fisubr' {FPU}. + kIdFld, //!< Instruction 'fld' {FPU}. + kIdFld1, //!< Instruction 'fld1' {FPU}. 
+ kIdFldcw, //!< Instruction 'fldcw' {FPU}. + kIdFldenv, //!< Instruction 'fldenv' {FPU}. + kIdFldl2e, //!< Instruction 'fldl2e' {FPU}. + kIdFldl2t, //!< Instruction 'fldl2t' {FPU}. + kIdFldlg2, //!< Instruction 'fldlg2' {FPU}. + kIdFldln2, //!< Instruction 'fldln2' {FPU}. + kIdFldpi, //!< Instruction 'fldpi' {FPU}. + kIdFldz, //!< Instruction 'fldz' {FPU}. + kIdFmul, //!< Instruction 'fmul' {FPU}. + kIdFmulp, //!< Instruction 'fmulp' {FPU}. + kIdFnclex, //!< Instruction 'fnclex' {FPU}. + kIdFninit, //!< Instruction 'fninit' {FPU}. + kIdFnop, //!< Instruction 'fnop' {FPU}. + kIdFnsave, //!< Instruction 'fnsave' {FPU}. + kIdFnstcw, //!< Instruction 'fnstcw' {FPU}. + kIdFnstenv, //!< Instruction 'fnstenv' {FPU}. + kIdFnstsw, //!< Instruction 'fnstsw' {FPU}. + kIdFpatan, //!< Instruction 'fpatan' {FPU}. + kIdFprem, //!< Instruction 'fprem' {FPU}. + kIdFprem1, //!< Instruction 'fprem1' {FPU}. + kIdFptan, //!< Instruction 'fptan' {FPU}. + kIdFrndint, //!< Instruction 'frndint' {FPU}. + kIdFrstor, //!< Instruction 'frstor' {FPU}. + kIdFsave, //!< Instruction 'fsave' {FPU}. + kIdFscale, //!< Instruction 'fscale' {FPU}. + kIdFsin, //!< Instruction 'fsin' {FPU}. + kIdFsincos, //!< Instruction 'fsincos' {FPU}. + kIdFsqrt, //!< Instruction 'fsqrt' {FPU}. + kIdFst, //!< Instruction 'fst' {FPU}. + kIdFstcw, //!< Instruction 'fstcw' {FPU}. + kIdFstenv, //!< Instruction 'fstenv' {FPU}. + kIdFstp, //!< Instruction 'fstp' {FPU}. + kIdFstsw, //!< Instruction 'fstsw' {FPU}. + kIdFsub, //!< Instruction 'fsub' {FPU}. + kIdFsubp, //!< Instruction 'fsubp' {FPU}. + kIdFsubr, //!< Instruction 'fsubr' {FPU}. + kIdFsubrp, //!< Instruction 'fsubrp' {FPU}. + kIdFtst, //!< Instruction 'ftst' {FPU}. + kIdFucom, //!< Instruction 'fucom' {FPU}. + kIdFucomi, //!< Instruction 'fucomi' {FPU}. + kIdFucomip, //!< Instruction 'fucomip' {FPU}. + kIdFucomp, //!< Instruction 'fucomp' {FPU}. + kIdFucompp, //!< Instruction 'fucompp' {FPU}. + kIdFwait, //!< Instruction 'fwait' {FPU}. 
+ kIdFxam, //!< Instruction 'fxam' {FPU}. + kIdFxch, //!< Instruction 'fxch' {FPU}. + kIdFxrstor, //!< Instruction 'fxrstor' {FXSR}. + kIdFxrstor64, //!< Instruction 'fxrstor64' {FXSR} (X64). + kIdFxsave, //!< Instruction 'fxsave' {FXSR}. + kIdFxsave64, //!< Instruction 'fxsave64' {FXSR} (X64). + kIdFxtract, //!< Instruction 'fxtract' {FPU}. + kIdFyl2x, //!< Instruction 'fyl2x' {FPU}. + kIdFyl2xp1, //!< Instruction 'fyl2xp1' {FPU}. + kIdGetsec, //!< Instruction 'getsec' {SMX}. + kIdGf2p8affineinvqb, //!< Instruction 'gf2p8affineinvqb' {GFNI}. + kIdGf2p8affineqb, //!< Instruction 'gf2p8affineqb' {GFNI}. + kIdGf2p8mulb, //!< Instruction 'gf2p8mulb' {GFNI}. + kIdHaddpd, //!< Instruction 'haddpd' {SSE3}. + kIdHaddps, //!< Instruction 'haddps' {SSE3}. + kIdHlt, //!< Instruction 'hlt'. + kIdHreset, //!< Instruction 'hreset' {HRESET}. + kIdHsubpd, //!< Instruction 'hsubpd' {SSE3}. + kIdHsubps, //!< Instruction 'hsubps' {SSE3}. + kIdIdiv, //!< Instruction 'idiv'. + kIdImul, //!< Instruction 'imul'. + kIdIn, //!< Instruction 'in'. + kIdInc, //!< Instruction 'inc'. + kIdIncsspd, //!< Instruction 'incsspd' {CET_SS}. + kIdIncsspq, //!< Instruction 'incsspq' {CET_SS} (X64). + kIdIns, //!< Instruction 'ins'. + kIdInsertps, //!< Instruction 'insertps' {SSE4_1}. + kIdInsertq, //!< Instruction 'insertq' {SSE4A}. + kIdInt, //!< Instruction 'int'. + kIdInt3, //!< Instruction 'int3'. + kIdInto, //!< Instruction 'into' (X86). + kIdInvd, //!< Instruction 'invd' {I486}. + kIdInvept, //!< Instruction 'invept' {VMX}. + kIdInvlpg, //!< Instruction 'invlpg' {I486}. + kIdInvlpga, //!< Instruction 'invlpga' {SVM}. + kIdInvlpgb, //!< Instruction 'invlpgb' {INVLPGB}. + kIdInvpcid, //!< Instruction 'invpcid' {I486}. + kIdInvvpid, //!< Instruction 'invvpid' {VMX}. + kIdIret, //!< Instruction 'iret'. + kIdIretd, //!< Instruction 'iretd'. + kIdIretq, //!< Instruction 'iretq' (X64). + kIdJa, //!< Instruction 'ja'. + kIdJae, //!< Instruction 'jae'. + kIdJb, //!< Instruction 'jb'. 
+ kIdJbe, //!< Instruction 'jbe'. + kIdJc, //!< Instruction 'jc'. + kIdJe, //!< Instruction 'je'. + kIdJecxz, //!< Instruction 'jecxz'. + kIdJg, //!< Instruction 'jg'. + kIdJge, //!< Instruction 'jge'. + kIdJl, //!< Instruction 'jl'. + kIdJle, //!< Instruction 'jle'. + kIdJmp, //!< Instruction 'jmp'. + kIdJna, //!< Instruction 'jna'. + kIdJnae, //!< Instruction 'jnae'. + kIdJnb, //!< Instruction 'jnb'. + kIdJnbe, //!< Instruction 'jnbe'. + kIdJnc, //!< Instruction 'jnc'. + kIdJne, //!< Instruction 'jne'. + kIdJng, //!< Instruction 'jng'. + kIdJnge, //!< Instruction 'jnge'. + kIdJnl, //!< Instruction 'jnl'. + kIdJnle, //!< Instruction 'jnle'. + kIdJno, //!< Instruction 'jno'. + kIdJnp, //!< Instruction 'jnp'. + kIdJns, //!< Instruction 'jns'. + kIdJnz, //!< Instruction 'jnz'. + kIdJo, //!< Instruction 'jo'. + kIdJp, //!< Instruction 'jp'. + kIdJpe, //!< Instruction 'jpe'. + kIdJpo, //!< Instruction 'jpo'. + kIdJs, //!< Instruction 'js'. + kIdJz, //!< Instruction 'jz'. + kIdKaddb, //!< Instruction 'kaddb' {AVX512_DQ}. + kIdKaddd, //!< Instruction 'kaddd' {AVX512_BW}. + kIdKaddq, //!< Instruction 'kaddq' {AVX512_BW}. + kIdKaddw, //!< Instruction 'kaddw' {AVX512_DQ}. + kIdKandb, //!< Instruction 'kandb' {AVX512_DQ}. + kIdKandd, //!< Instruction 'kandd' {AVX512_BW}. + kIdKandnb, //!< Instruction 'kandnb' {AVX512_DQ}. + kIdKandnd, //!< Instruction 'kandnd' {AVX512_BW}. + kIdKandnq, //!< Instruction 'kandnq' {AVX512_BW}. + kIdKandnw, //!< Instruction 'kandnw' {AVX512_F}. + kIdKandq, //!< Instruction 'kandq' {AVX512_BW}. + kIdKandw, //!< Instruction 'kandw' {AVX512_F}. + kIdKmovb, //!< Instruction 'kmovb' {AVX512_DQ}. + kIdKmovd, //!< Instruction 'kmovd' {AVX512_BW}. + kIdKmovq, //!< Instruction 'kmovq' {AVX512_BW}. + kIdKmovw, //!< Instruction 'kmovw' {AVX512_F}. + kIdKnotb, //!< Instruction 'knotb' {AVX512_DQ}. + kIdKnotd, //!< Instruction 'knotd' {AVX512_BW}. + kIdKnotq, //!< Instruction 'knotq' {AVX512_BW}. + kIdKnotw, //!< Instruction 'knotw' {AVX512_F}. 
+ kIdKorb, //!< Instruction 'korb' {AVX512_DQ}. + kIdKord, //!< Instruction 'kord' {AVX512_BW}. + kIdKorq, //!< Instruction 'korq' {AVX512_BW}. + kIdKortestb, //!< Instruction 'kortestb' {AVX512_DQ}. + kIdKortestd, //!< Instruction 'kortestd' {AVX512_BW}. + kIdKortestq, //!< Instruction 'kortestq' {AVX512_BW}. + kIdKortestw, //!< Instruction 'kortestw' {AVX512_F}. + kIdKorw, //!< Instruction 'korw' {AVX512_F}. + kIdKshiftlb, //!< Instruction 'kshiftlb' {AVX512_DQ}. + kIdKshiftld, //!< Instruction 'kshiftld' {AVX512_BW}. + kIdKshiftlq, //!< Instruction 'kshiftlq' {AVX512_BW}. + kIdKshiftlw, //!< Instruction 'kshiftlw' {AVX512_F}. + kIdKshiftrb, //!< Instruction 'kshiftrb' {AVX512_DQ}. + kIdKshiftrd, //!< Instruction 'kshiftrd' {AVX512_BW}. + kIdKshiftrq, //!< Instruction 'kshiftrq' {AVX512_BW}. + kIdKshiftrw, //!< Instruction 'kshiftrw' {AVX512_F}. + kIdKtestb, //!< Instruction 'ktestb' {AVX512_DQ}. + kIdKtestd, //!< Instruction 'ktestd' {AVX512_BW}. + kIdKtestq, //!< Instruction 'ktestq' {AVX512_BW}. + kIdKtestw, //!< Instruction 'ktestw' {AVX512_DQ}. + kIdKunpckbw, //!< Instruction 'kunpckbw' {AVX512_F}. + kIdKunpckdq, //!< Instruction 'kunpckdq' {AVX512_BW}. + kIdKunpckwd, //!< Instruction 'kunpckwd' {AVX512_BW}. + kIdKxnorb, //!< Instruction 'kxnorb' {AVX512_DQ}. + kIdKxnord, //!< Instruction 'kxnord' {AVX512_BW}. + kIdKxnorq, //!< Instruction 'kxnorq' {AVX512_BW}. + kIdKxnorw, //!< Instruction 'kxnorw' {AVX512_F}. + kIdKxorb, //!< Instruction 'kxorb' {AVX512_DQ}. + kIdKxord, //!< Instruction 'kxord' {AVX512_BW}. + kIdKxorq, //!< Instruction 'kxorq' {AVX512_BW}. + kIdKxorw, //!< Instruction 'kxorw' {AVX512_F}. + kIdLahf, //!< Instruction 'lahf' {LAHFSAHF}. + kIdLar, //!< Instruction 'lar'. + kIdLcall, //!< Instruction 'lcall'. + kIdLddqu, //!< Instruction 'lddqu' {SSE3}. + kIdLdmxcsr, //!< Instruction 'ldmxcsr' {SSE}. + kIdLds, //!< Instruction 'lds' (X86). + kIdLdtilecfg, //!< Instruction 'ldtilecfg' {AMX_TILE} (X64). + kIdLea, //!< Instruction 'lea'. 
+ kIdLeave, //!< Instruction 'leave'. + kIdLes, //!< Instruction 'les' (X86). + kIdLfence, //!< Instruction 'lfence' {SSE2}. + kIdLfs, //!< Instruction 'lfs'. + kIdLgdt, //!< Instruction 'lgdt'. + kIdLgs, //!< Instruction 'lgs'. + kIdLidt, //!< Instruction 'lidt'. + kIdLjmp, //!< Instruction 'ljmp'. + kIdLldt, //!< Instruction 'lldt'. + kIdLlwpcb, //!< Instruction 'llwpcb' {LWP}. + kIdLmsw, //!< Instruction 'lmsw'. + kIdLods, //!< Instruction 'lods'. + kIdLoop, //!< Instruction 'loop'. + kIdLoope, //!< Instruction 'loope'. + kIdLoopne, //!< Instruction 'loopne'. + kIdLsl, //!< Instruction 'lsl'. + kIdLss, //!< Instruction 'lss'. + kIdLtr, //!< Instruction 'ltr'. + kIdLwpins, //!< Instruction 'lwpins' {LWP}. + kIdLwpval, //!< Instruction 'lwpval' {LWP}. + kIdLzcnt, //!< Instruction 'lzcnt' {LZCNT}. + kIdMaskmovdqu, //!< Instruction 'maskmovdqu' {SSE2}. + kIdMaskmovq, //!< Instruction 'maskmovq' {MMX2}. + kIdMaxpd, //!< Instruction 'maxpd' {SSE2}. + kIdMaxps, //!< Instruction 'maxps' {SSE}. + kIdMaxsd, //!< Instruction 'maxsd' {SSE2}. + kIdMaxss, //!< Instruction 'maxss' {SSE}. + kIdMcommit, //!< Instruction 'mcommit' {MCOMMIT}. + kIdMfence, //!< Instruction 'mfence' {SSE2}. + kIdMinpd, //!< Instruction 'minpd' {SSE2}. + kIdMinps, //!< Instruction 'minps' {SSE}. + kIdMinsd, //!< Instruction 'minsd' {SSE2}. + kIdMinss, //!< Instruction 'minss' {SSE}. + kIdMonitor, //!< Instruction 'monitor' {MONITOR}. + kIdMonitorx, //!< Instruction 'monitorx' {MONITORX}. + kIdMov, //!< Instruction 'mov'. + kIdMovabs, //!< Instruction 'movabs'. + kIdMovapd, //!< Instruction 'movapd' {SSE2}. + kIdMovaps, //!< Instruction 'movaps' {SSE}. + kIdMovbe, //!< Instruction 'movbe' {MOVBE}. + kIdMovd, //!< Instruction 'movd' {MMX|SSE2}. + kIdMovddup, //!< Instruction 'movddup' {SSE3}. + kIdMovdir64b, //!< Instruction 'movdir64b' {MOVDIR64B}. + kIdMovdiri, //!< Instruction 'movdiri' {MOVDIRI}. + kIdMovdq2q, //!< Instruction 'movdq2q' {SSE2}. + kIdMovdqa, //!< Instruction 'movdqa' {SSE2}. 
+ kIdMovdqu, //!< Instruction 'movdqu' {SSE2}. + kIdMovhlps, //!< Instruction 'movhlps' {SSE}. + kIdMovhpd, //!< Instruction 'movhpd' {SSE2}. + kIdMovhps, //!< Instruction 'movhps' {SSE}. + kIdMovlhps, //!< Instruction 'movlhps' {SSE}. + kIdMovlpd, //!< Instruction 'movlpd' {SSE2}. + kIdMovlps, //!< Instruction 'movlps' {SSE}. + kIdMovmskpd, //!< Instruction 'movmskpd' {SSE2}. + kIdMovmskps, //!< Instruction 'movmskps' {SSE}. + kIdMovntdq, //!< Instruction 'movntdq' {SSE2}. + kIdMovntdqa, //!< Instruction 'movntdqa' {SSE4_1}. + kIdMovnti, //!< Instruction 'movnti' {SSE2}. + kIdMovntpd, //!< Instruction 'movntpd' {SSE2}. + kIdMovntps, //!< Instruction 'movntps' {SSE}. + kIdMovntq, //!< Instruction 'movntq' {MMX2}. + kIdMovntsd, //!< Instruction 'movntsd' {SSE4A}. + kIdMovntss, //!< Instruction 'movntss' {SSE4A}. + kIdMovq, //!< Instruction 'movq' {MMX|SSE2}. + kIdMovq2dq, //!< Instruction 'movq2dq' {SSE2}. + kIdMovs, //!< Instruction 'movs'. + kIdMovsd, //!< Instruction 'movsd' {SSE2}. + kIdMovshdup, //!< Instruction 'movshdup' {SSE3}. + kIdMovsldup, //!< Instruction 'movsldup' {SSE3}. + kIdMovss, //!< Instruction 'movss' {SSE}. + kIdMovsx, //!< Instruction 'movsx'. + kIdMovsxd, //!< Instruction 'movsxd' (X64). + kIdMovupd, //!< Instruction 'movupd' {SSE2}. + kIdMovups, //!< Instruction 'movups' {SSE}. + kIdMovzx, //!< Instruction 'movzx'. + kIdMpsadbw, //!< Instruction 'mpsadbw' {SSE4_1}. + kIdMul, //!< Instruction 'mul'. + kIdMulpd, //!< Instruction 'mulpd' {SSE2}. + kIdMulps, //!< Instruction 'mulps' {SSE}. + kIdMulsd, //!< Instruction 'mulsd' {SSE2}. + kIdMulss, //!< Instruction 'mulss' {SSE}. + kIdMulx, //!< Instruction 'mulx' {BMI2}. + kIdMwait, //!< Instruction 'mwait' {MONITOR}. + kIdMwaitx, //!< Instruction 'mwaitx' {MONITORX}. + kIdNeg, //!< Instruction 'neg'. + kIdNop, //!< Instruction 'nop'. + kIdNot, //!< Instruction 'not'. + kIdOr, //!< Instruction 'or'. + kIdOrpd, //!< Instruction 'orpd' {SSE2}. + kIdOrps, //!< Instruction 'orps' {SSE}. 
+ kIdOut, //!< Instruction 'out'. + kIdOuts, //!< Instruction 'outs'. + kIdPabsb, //!< Instruction 'pabsb' {SSSE3}. + kIdPabsd, //!< Instruction 'pabsd' {SSSE3}. + kIdPabsw, //!< Instruction 'pabsw' {SSSE3}. + kIdPackssdw, //!< Instruction 'packssdw' {MMX|SSE2}. + kIdPacksswb, //!< Instruction 'packsswb' {MMX|SSE2}. + kIdPackusdw, //!< Instruction 'packusdw' {SSE4_1}. + kIdPackuswb, //!< Instruction 'packuswb' {MMX|SSE2}. + kIdPaddb, //!< Instruction 'paddb' {MMX|SSE2}. + kIdPaddd, //!< Instruction 'paddd' {MMX|SSE2}. + kIdPaddq, //!< Instruction 'paddq' {SSE2}. + kIdPaddsb, //!< Instruction 'paddsb' {MMX|SSE2}. + kIdPaddsw, //!< Instruction 'paddsw' {MMX|SSE2}. + kIdPaddusb, //!< Instruction 'paddusb' {MMX|SSE2}. + kIdPaddusw, //!< Instruction 'paddusw' {MMX|SSE2}. + kIdPaddw, //!< Instruction 'paddw' {MMX|SSE2}. + kIdPalignr, //!< Instruction 'palignr' {SSSE3}. + kIdPand, //!< Instruction 'pand' {MMX|SSE2}. + kIdPandn, //!< Instruction 'pandn' {MMX|SSE2}. + kIdPause, //!< Instruction 'pause'. + kIdPavgb, //!< Instruction 'pavgb' {MMX2|SSE2}. + kIdPavgusb, //!< Instruction 'pavgusb' {3DNOW}. + kIdPavgw, //!< Instruction 'pavgw' {MMX2|SSE2}. + kIdPblendvb, //!< Instruction 'pblendvb' {SSE4_1}. + kIdPblendw, //!< Instruction 'pblendw' {SSE4_1}. + kIdPclmulqdq, //!< Instruction 'pclmulqdq' {PCLMULQDQ}. + kIdPcmpeqb, //!< Instruction 'pcmpeqb' {MMX|SSE2}. + kIdPcmpeqd, //!< Instruction 'pcmpeqd' {MMX|SSE2}. + kIdPcmpeqq, //!< Instruction 'pcmpeqq' {SSE4_1}. + kIdPcmpeqw, //!< Instruction 'pcmpeqw' {MMX|SSE2}. + kIdPcmpestri, //!< Instruction 'pcmpestri' {SSE4_2}. + kIdPcmpestrm, //!< Instruction 'pcmpestrm' {SSE4_2}. + kIdPcmpgtb, //!< Instruction 'pcmpgtb' {MMX|SSE2}. + kIdPcmpgtd, //!< Instruction 'pcmpgtd' {MMX|SSE2}. + kIdPcmpgtq, //!< Instruction 'pcmpgtq' {SSE4_2}. + kIdPcmpgtw, //!< Instruction 'pcmpgtw' {MMX|SSE2}. + kIdPcmpistri, //!< Instruction 'pcmpistri' {SSE4_2}. + kIdPcmpistrm, //!< Instruction 'pcmpistrm' {SSE4_2}. 
+ kIdPconfig, //!< Instruction 'pconfig' {PCONFIG}. + kIdPdep, //!< Instruction 'pdep' {BMI2}. + kIdPext, //!< Instruction 'pext' {BMI2}. + kIdPextrb, //!< Instruction 'pextrb' {SSE4_1}. + kIdPextrd, //!< Instruction 'pextrd' {SSE4_1}. + kIdPextrq, //!< Instruction 'pextrq' {SSE4_1} (X64). + kIdPextrw, //!< Instruction 'pextrw' {MMX2|SSE2|SSE4_1}. + kIdPf2id, //!< Instruction 'pf2id' {3DNOW}. + kIdPf2iw, //!< Instruction 'pf2iw' {3DNOW2}. + kIdPfacc, //!< Instruction 'pfacc' {3DNOW}. + kIdPfadd, //!< Instruction 'pfadd' {3DNOW}. + kIdPfcmpeq, //!< Instruction 'pfcmpeq' {3DNOW}. + kIdPfcmpge, //!< Instruction 'pfcmpge' {3DNOW}. + kIdPfcmpgt, //!< Instruction 'pfcmpgt' {3DNOW}. + kIdPfmax, //!< Instruction 'pfmax' {3DNOW}. + kIdPfmin, //!< Instruction 'pfmin' {3DNOW}. + kIdPfmul, //!< Instruction 'pfmul' {3DNOW}. + kIdPfnacc, //!< Instruction 'pfnacc' {3DNOW2}. + kIdPfpnacc, //!< Instruction 'pfpnacc' {3DNOW2}. + kIdPfrcp, //!< Instruction 'pfrcp' {3DNOW}. + kIdPfrcpit1, //!< Instruction 'pfrcpit1' {3DNOW}. + kIdPfrcpit2, //!< Instruction 'pfrcpit2' {3DNOW}. + kIdPfrcpv, //!< Instruction 'pfrcpv' {GEODE}. + kIdPfrsqit1, //!< Instruction 'pfrsqit1' {3DNOW}. + kIdPfrsqrt, //!< Instruction 'pfrsqrt' {3DNOW}. + kIdPfrsqrtv, //!< Instruction 'pfrsqrtv' {GEODE}. + kIdPfsub, //!< Instruction 'pfsub' {3DNOW}. + kIdPfsubr, //!< Instruction 'pfsubr' {3DNOW}. + kIdPhaddd, //!< Instruction 'phaddd' {SSSE3}. + kIdPhaddsw, //!< Instruction 'phaddsw' {SSSE3}. + kIdPhaddw, //!< Instruction 'phaddw' {SSSE3}. + kIdPhminposuw, //!< Instruction 'phminposuw' {SSE4_1}. + kIdPhsubd, //!< Instruction 'phsubd' {SSSE3}. + kIdPhsubsw, //!< Instruction 'phsubsw' {SSSE3}. + kIdPhsubw, //!< Instruction 'phsubw' {SSSE3}. + kIdPi2fd, //!< Instruction 'pi2fd' {3DNOW}. + kIdPi2fw, //!< Instruction 'pi2fw' {3DNOW2}. + kIdPinsrb, //!< Instruction 'pinsrb' {SSE4_1}. + kIdPinsrd, //!< Instruction 'pinsrd' {SSE4_1}. + kIdPinsrq, //!< Instruction 'pinsrq' {SSE4_1} (X64). 
+ kIdPinsrw, //!< Instruction 'pinsrw' {MMX2|SSE2}. + kIdPmaddubsw, //!< Instruction 'pmaddubsw' {SSSE3}. + kIdPmaddwd, //!< Instruction 'pmaddwd' {MMX|SSE2}. + kIdPmaxsb, //!< Instruction 'pmaxsb' {SSE4_1}. + kIdPmaxsd, //!< Instruction 'pmaxsd' {SSE4_1}. + kIdPmaxsw, //!< Instruction 'pmaxsw' {MMX2|SSE2}. + kIdPmaxub, //!< Instruction 'pmaxub' {MMX2|SSE2}. + kIdPmaxud, //!< Instruction 'pmaxud' {SSE4_1}. + kIdPmaxuw, //!< Instruction 'pmaxuw' {SSE4_1}. + kIdPminsb, //!< Instruction 'pminsb' {SSE4_1}. + kIdPminsd, //!< Instruction 'pminsd' {SSE4_1}. + kIdPminsw, //!< Instruction 'pminsw' {MMX2|SSE2}. + kIdPminub, //!< Instruction 'pminub' {MMX2|SSE2}. + kIdPminud, //!< Instruction 'pminud' {SSE4_1}. + kIdPminuw, //!< Instruction 'pminuw' {SSE4_1}. + kIdPmovmskb, //!< Instruction 'pmovmskb' {MMX2|SSE2}. + kIdPmovsxbd, //!< Instruction 'pmovsxbd' {SSE4_1}. + kIdPmovsxbq, //!< Instruction 'pmovsxbq' {SSE4_1}. + kIdPmovsxbw, //!< Instruction 'pmovsxbw' {SSE4_1}. + kIdPmovsxdq, //!< Instruction 'pmovsxdq' {SSE4_1}. + kIdPmovsxwd, //!< Instruction 'pmovsxwd' {SSE4_1}. + kIdPmovsxwq, //!< Instruction 'pmovsxwq' {SSE4_1}. + kIdPmovzxbd, //!< Instruction 'pmovzxbd' {SSE4_1}. + kIdPmovzxbq, //!< Instruction 'pmovzxbq' {SSE4_1}. + kIdPmovzxbw, //!< Instruction 'pmovzxbw' {SSE4_1}. + kIdPmovzxdq, //!< Instruction 'pmovzxdq' {SSE4_1}. + kIdPmovzxwd, //!< Instruction 'pmovzxwd' {SSE4_1}. + kIdPmovzxwq, //!< Instruction 'pmovzxwq' {SSE4_1}. + kIdPmuldq, //!< Instruction 'pmuldq' {SSE4_1}. + kIdPmulhrsw, //!< Instruction 'pmulhrsw' {SSSE3}. + kIdPmulhrw, //!< Instruction 'pmulhrw' {3DNOW}. + kIdPmulhuw, //!< Instruction 'pmulhuw' {MMX2|SSE2}. + kIdPmulhw, //!< Instruction 'pmulhw' {MMX|SSE2}. + kIdPmulld, //!< Instruction 'pmulld' {SSE4_1}. + kIdPmullw, //!< Instruction 'pmullw' {MMX|SSE2}. + kIdPmuludq, //!< Instruction 'pmuludq' {SSE2}. + kIdPop, //!< Instruction 'pop'. + kIdPopa, //!< Instruction 'popa' (X86). + kIdPopad, //!< Instruction 'popad' (X86). 
+ kIdPopcnt, //!< Instruction 'popcnt' {POPCNT}. + kIdPopf, //!< Instruction 'popf'. + kIdPopfd, //!< Instruction 'popfd' (X86). + kIdPopfq, //!< Instruction 'popfq' (X64). + kIdPor, //!< Instruction 'por' {MMX|SSE2}. + kIdPrefetch, //!< Instruction 'prefetch' {3DNOW}. + kIdPrefetchit0, //!< Instruction 'prefetchit0' {PREFETCHI} (X64). + kIdPrefetchit1, //!< Instruction 'prefetchit1' {PREFETCHI} (X64). + kIdPrefetchnta, //!< Instruction 'prefetchnta' {SSE}. + kIdPrefetcht0, //!< Instruction 'prefetcht0' {SSE}. + kIdPrefetcht1, //!< Instruction 'prefetcht1' {SSE}. + kIdPrefetcht2, //!< Instruction 'prefetcht2' {SSE}. + kIdPrefetchw, //!< Instruction 'prefetchw' {PREFETCHW}. + kIdPrefetchwt1, //!< Instruction 'prefetchwt1' {PREFETCHWT1}. + kIdPsadbw, //!< Instruction 'psadbw' {MMX2|SSE2}. + kIdPshufb, //!< Instruction 'pshufb' {SSSE3}. + kIdPshufd, //!< Instruction 'pshufd' {SSE2}. + kIdPshufhw, //!< Instruction 'pshufhw' {SSE2}. + kIdPshuflw, //!< Instruction 'pshuflw' {SSE2}. + kIdPshufw, //!< Instruction 'pshufw' {MMX2}. + kIdPsignb, //!< Instruction 'psignb' {SSSE3}. + kIdPsignd, //!< Instruction 'psignd' {SSSE3}. + kIdPsignw, //!< Instruction 'psignw' {SSSE3}. + kIdPslld, //!< Instruction 'pslld' {MMX|SSE2}. + kIdPslldq, //!< Instruction 'pslldq' {SSE2}. + kIdPsllq, //!< Instruction 'psllq' {MMX|SSE2}. + kIdPsllw, //!< Instruction 'psllw' {MMX|SSE2}. + kIdPsmash, //!< Instruction 'psmash' {SEV_SNP} (X64). + kIdPsrad, //!< Instruction 'psrad' {MMX|SSE2}. + kIdPsraw, //!< Instruction 'psraw' {MMX|SSE2}. + kIdPsrld, //!< Instruction 'psrld' {MMX|SSE2}. + kIdPsrldq, //!< Instruction 'psrldq' {SSE2}. + kIdPsrlq, //!< Instruction 'psrlq' {MMX|SSE2}. + kIdPsrlw, //!< Instruction 'psrlw' {MMX|SSE2}. + kIdPsubb, //!< Instruction 'psubb' {MMX|SSE2}. + kIdPsubd, //!< Instruction 'psubd' {MMX|SSE2}. + kIdPsubq, //!< Instruction 'psubq' {SSE2}. + kIdPsubsb, //!< Instruction 'psubsb' {MMX|SSE2}. + kIdPsubsw, //!< Instruction 'psubsw' {MMX|SSE2}. 
+ kIdPsubusb, //!< Instruction 'psubusb' {MMX|SSE2}. + kIdPsubusw, //!< Instruction 'psubusw' {MMX|SSE2}. + kIdPsubw, //!< Instruction 'psubw' {MMX|SSE2}. + kIdPswapd, //!< Instruction 'pswapd' {3DNOW2}. + kIdPtest, //!< Instruction 'ptest' {SSE4_1}. + kIdPtwrite, //!< Instruction 'ptwrite' {PTWRITE}. + kIdPunpckhbw, //!< Instruction 'punpckhbw' {MMX|SSE2}. + kIdPunpckhdq, //!< Instruction 'punpckhdq' {MMX|SSE2}. + kIdPunpckhqdq, //!< Instruction 'punpckhqdq' {SSE2}. + kIdPunpckhwd, //!< Instruction 'punpckhwd' {MMX|SSE2}. + kIdPunpcklbw, //!< Instruction 'punpcklbw' {MMX|SSE2}. + kIdPunpckldq, //!< Instruction 'punpckldq' {MMX|SSE2}. + kIdPunpcklqdq, //!< Instruction 'punpcklqdq' {SSE2}. + kIdPunpcklwd, //!< Instruction 'punpcklwd' {MMX|SSE2}. + kIdPush, //!< Instruction 'push'. + kIdPusha, //!< Instruction 'pusha' (X86). + kIdPushad, //!< Instruction 'pushad' (X86). + kIdPushf, //!< Instruction 'pushf'. + kIdPushfd, //!< Instruction 'pushfd' (X86). + kIdPushfq, //!< Instruction 'pushfq' (X64). + kIdPvalidate, //!< Instruction 'pvalidate' {SEV_SNP}. + kIdPxor, //!< Instruction 'pxor' {MMX|SSE2}. + kIdRcl, //!< Instruction 'rcl'. + kIdRcpps, //!< Instruction 'rcpps' {SSE}. + kIdRcpss, //!< Instruction 'rcpss' {SSE}. + kIdRcr, //!< Instruction 'rcr'. + kIdRdfsbase, //!< Instruction 'rdfsbase' {FSGSBASE} (X64). + kIdRdgsbase, //!< Instruction 'rdgsbase' {FSGSBASE} (X64). + kIdRdmsr, //!< Instruction 'rdmsr' {MSR}. + kIdRdpid, //!< Instruction 'rdpid' {RDPID}. + kIdRdpkru, //!< Instruction 'rdpkru' {OSPKE}. + kIdRdpmc, //!< Instruction 'rdpmc'. + kIdRdpru, //!< Instruction 'rdpru' {RDPRU}. + kIdRdrand, //!< Instruction 'rdrand' {RDRAND}. + kIdRdseed, //!< Instruction 'rdseed' {RDSEED}. + kIdRdsspd, //!< Instruction 'rdsspd' {CET_SS}. + kIdRdsspq, //!< Instruction 'rdsspq' {CET_SS} (X64). + kIdRdtsc, //!< Instruction 'rdtsc' {RDTSC}. + kIdRdtscp, //!< Instruction 'rdtscp' {RDTSCP}. + kIdRet, //!< Instruction 'ret'. + kIdRetf, //!< Instruction 'retf'. 
+ kIdRmpadjust, //!< Instruction 'rmpadjust' {SEV_SNP} (X64). + kIdRmpupdate, //!< Instruction 'rmpupdate' {SEV_SNP} (X64). + kIdRol, //!< Instruction 'rol'. + kIdRor, //!< Instruction 'ror'. + kIdRorx, //!< Instruction 'rorx' {BMI2}. + kIdRoundpd, //!< Instruction 'roundpd' {SSE4_1}. + kIdRoundps, //!< Instruction 'roundps' {SSE4_1}. + kIdRoundsd, //!< Instruction 'roundsd' {SSE4_1}. + kIdRoundss, //!< Instruction 'roundss' {SSE4_1}. + kIdRsm, //!< Instruction 'rsm' (X86). + kIdRsqrtps, //!< Instruction 'rsqrtps' {SSE}. + kIdRsqrtss, //!< Instruction 'rsqrtss' {SSE}. + kIdRstorssp, //!< Instruction 'rstorssp' {CET_SS}. + kIdSahf, //!< Instruction 'sahf' {LAHFSAHF}. + kIdSal, //!< Instruction 'sal'. + kIdSar, //!< Instruction 'sar'. + kIdSarx, //!< Instruction 'sarx' {BMI2}. + kIdSaveprevssp, //!< Instruction 'saveprevssp' {CET_SS}. + kIdSbb, //!< Instruction 'sbb'. + kIdScas, //!< Instruction 'scas'. + kIdSeamcall, //!< Instruction 'seamcall' {SEAM}. + kIdSeamops, //!< Instruction 'seamops' {SEAM}. + kIdSeamret, //!< Instruction 'seamret' {SEAM}. + kIdSenduipi, //!< Instruction 'senduipi' {UINTR} (X64). + kIdSerialize, //!< Instruction 'serialize' {SERIALIZE}. + kIdSeta, //!< Instruction 'seta'. + kIdSetae, //!< Instruction 'setae'. + kIdSetb, //!< Instruction 'setb'. + kIdSetbe, //!< Instruction 'setbe'. + kIdSetc, //!< Instruction 'setc'. + kIdSete, //!< Instruction 'sete'. + kIdSetg, //!< Instruction 'setg'. + kIdSetge, //!< Instruction 'setge'. + kIdSetl, //!< Instruction 'setl'. + kIdSetle, //!< Instruction 'setle'. + kIdSetna, //!< Instruction 'setna'. + kIdSetnae, //!< Instruction 'setnae'. + kIdSetnb, //!< Instruction 'setnb'. + kIdSetnbe, //!< Instruction 'setnbe'. + kIdSetnc, //!< Instruction 'setnc'. + kIdSetne, //!< Instruction 'setne'. + kIdSetng, //!< Instruction 'setng'. + kIdSetnge, //!< Instruction 'setnge'. + kIdSetnl, //!< Instruction 'setnl'. + kIdSetnle, //!< Instruction 'setnle'. + kIdSetno, //!< Instruction 'setno'. 
+ kIdSetnp, //!< Instruction 'setnp'. + kIdSetns, //!< Instruction 'setns'. + kIdSetnz, //!< Instruction 'setnz'. + kIdSeto, //!< Instruction 'seto'. + kIdSetp, //!< Instruction 'setp'. + kIdSetpe, //!< Instruction 'setpe'. + kIdSetpo, //!< Instruction 'setpo'. + kIdSets, //!< Instruction 'sets'. + kIdSetssbsy, //!< Instruction 'setssbsy' {CET_SS}. + kIdSetz, //!< Instruction 'setz'. + kIdSfence, //!< Instruction 'sfence' {SSE}. + kIdSgdt, //!< Instruction 'sgdt'. + kIdSha1msg1, //!< Instruction 'sha1msg1' {SHA}. + kIdSha1msg2, //!< Instruction 'sha1msg2' {SHA}. + kIdSha1nexte, //!< Instruction 'sha1nexte' {SHA}. + kIdSha1rnds4, //!< Instruction 'sha1rnds4' {SHA}. + kIdSha256msg1, //!< Instruction 'sha256msg1' {SHA}. + kIdSha256msg2, //!< Instruction 'sha256msg2' {SHA}. + kIdSha256rnds2, //!< Instruction 'sha256rnds2' {SHA}. + kIdShl, //!< Instruction 'shl'. + kIdShld, //!< Instruction 'shld'. + kIdShlx, //!< Instruction 'shlx' {BMI2}. + kIdShr, //!< Instruction 'shr'. + kIdShrd, //!< Instruction 'shrd'. + kIdShrx, //!< Instruction 'shrx' {BMI2}. + kIdShufpd, //!< Instruction 'shufpd' {SSE2}. + kIdShufps, //!< Instruction 'shufps' {SSE}. + kIdSidt, //!< Instruction 'sidt'. + kIdSkinit, //!< Instruction 'skinit' {SKINIT}. + kIdSldt, //!< Instruction 'sldt'. + kIdSlwpcb, //!< Instruction 'slwpcb' {LWP}. + kIdSmsw, //!< Instruction 'smsw'. + kIdSqrtpd, //!< Instruction 'sqrtpd' {SSE2}. + kIdSqrtps, //!< Instruction 'sqrtps' {SSE}. + kIdSqrtsd, //!< Instruction 'sqrtsd' {SSE2}. + kIdSqrtss, //!< Instruction 'sqrtss' {SSE}. + kIdStac, //!< Instruction 'stac' {SMAP}. + kIdStc, //!< Instruction 'stc'. + kIdStd, //!< Instruction 'std'. + kIdStgi, //!< Instruction 'stgi' {SKINIT}. + kIdSti, //!< Instruction 'sti'. + kIdStmxcsr, //!< Instruction 'stmxcsr' {SSE}. + kIdStos, //!< Instruction 'stos'. + kIdStr, //!< Instruction 'str'. + kIdSttilecfg, //!< Instruction 'sttilecfg' {AMX_TILE} (X64). + kIdStui, //!< Instruction 'stui' {UINTR} (X64). + kIdSub, //!< Instruction 'sub'. 
+ kIdSubpd, //!< Instruction 'subpd' {SSE2}. + kIdSubps, //!< Instruction 'subps' {SSE}. + kIdSubsd, //!< Instruction 'subsd' {SSE2}. + kIdSubss, //!< Instruction 'subss' {SSE}. + kIdSwapgs, //!< Instruction 'swapgs' (X64). + kIdSyscall, //!< Instruction 'syscall' (X64). + kIdSysenter, //!< Instruction 'sysenter'. + kIdSysexit, //!< Instruction 'sysexit'. + kIdSysexitq, //!< Instruction 'sysexitq' (X64). + kIdSysret, //!< Instruction 'sysret' (X64). + kIdSysretq, //!< Instruction 'sysretq' (X64). + kIdT1mskc, //!< Instruction 't1mskc' {TBM}. + kIdTcmmimfp16ps, //!< Instruction 'tcmmimfp16ps' {AMX_COMPLEX} (X64). + kIdTcmmrlfp16ps, //!< Instruction 'tcmmrlfp16ps' {AMX_COMPLEX} (X64). + kIdTdcall, //!< Instruction 'tdcall' {SEAM}. + kIdTdpbf16ps, //!< Instruction 'tdpbf16ps' {AMX_BF16} (X64). + kIdTdpbssd, //!< Instruction 'tdpbssd' {AMX_INT8} (X64). + kIdTdpbsud, //!< Instruction 'tdpbsud' {AMX_INT8} (X64). + kIdTdpbusd, //!< Instruction 'tdpbusd' {AMX_INT8} (X64). + kIdTdpbuud, //!< Instruction 'tdpbuud' {AMX_INT8} (X64). + kIdTdpfp16ps, //!< Instruction 'tdpfp16ps' {AMX_FP16} (X64). + kIdTest, //!< Instruction 'test'. + kIdTestui, //!< Instruction 'testui' {UINTR} (X64). + kIdTileloadd, //!< Instruction 'tileloadd' {AMX_TILE} (X64). + kIdTileloaddt1, //!< Instruction 'tileloaddt1' {AMX_TILE} (X64). + kIdTilerelease, //!< Instruction 'tilerelease' {AMX_TILE} (X64). + kIdTilestored, //!< Instruction 'tilestored' {AMX_TILE} (X64). + kIdTilezero, //!< Instruction 'tilezero' {AMX_TILE} (X64). + kIdTlbsync, //!< Instruction 'tlbsync' {INVLPGB}. + kIdTpause, //!< Instruction 'tpause' {WAITPKG}. + kIdTzcnt, //!< Instruction 'tzcnt' {BMI}. + kIdTzmsk, //!< Instruction 'tzmsk' {TBM}. + kIdUcomisd, //!< Instruction 'ucomisd' {SSE2}. + kIdUcomiss, //!< Instruction 'ucomiss' {SSE}. + kIdUd0, //!< Instruction 'ud0'. + kIdUd1, //!< Instruction 'ud1'. + kIdUd2, //!< Instruction 'ud2'. + kIdUiret, //!< Instruction 'uiret' {UINTR} (X64). 
+ kIdUmonitor, //!< Instruction 'umonitor' {WAITPKG}. + kIdUmwait, //!< Instruction 'umwait' {WAITPKG}. + kIdUnpckhpd, //!< Instruction 'unpckhpd' {SSE2}. + kIdUnpckhps, //!< Instruction 'unpckhps' {SSE}. + kIdUnpcklpd, //!< Instruction 'unpcklpd' {SSE2}. + kIdUnpcklps, //!< Instruction 'unpcklps' {SSE}. + kIdV4fmaddps, //!< Instruction 'v4fmaddps' {AVX512_4FMAPS}. + kIdV4fmaddss, //!< Instruction 'v4fmaddss' {AVX512_4FMAPS}. + kIdV4fnmaddps, //!< Instruction 'v4fnmaddps' {AVX512_4FMAPS}. + kIdV4fnmaddss, //!< Instruction 'v4fnmaddss' {AVX512_4FMAPS}. + kIdVaddpd, //!< Instruction 'vaddpd' {AVX|AVX512_F+VL}. + kIdVaddph, //!< Instruction 'vaddph' {AVX512_FP16+VL}. + kIdVaddps, //!< Instruction 'vaddps' {AVX|AVX512_F+VL}. + kIdVaddsd, //!< Instruction 'vaddsd' {AVX|AVX512_F}. + kIdVaddsh, //!< Instruction 'vaddsh' {AVX512_FP16}. + kIdVaddss, //!< Instruction 'vaddss' {AVX|AVX512_F}. + kIdVaddsubpd, //!< Instruction 'vaddsubpd' {AVX}. + kIdVaddsubps, //!< Instruction 'vaddsubps' {AVX}. + kIdVaesdec, //!< Instruction 'vaesdec' {AVX|AVX512_F+VL & AESNI|VAES}. + kIdVaesdeclast, //!< Instruction 'vaesdeclast' {AVX|AVX512_F+VL & AESNI|VAES}. + kIdVaesenc, //!< Instruction 'vaesenc' {AVX|AVX512_F+VL & AESNI|VAES}. + kIdVaesenclast, //!< Instruction 'vaesenclast' {AVX|AVX512_F+VL & AESNI|VAES}. + kIdVaesimc, //!< Instruction 'vaesimc' {AVX & AESNI}. + kIdVaeskeygenassist, //!< Instruction 'vaeskeygenassist' {AVX & AESNI}. + kIdValignd, //!< Instruction 'valignd' {AVX512_F+VL}. + kIdValignq, //!< Instruction 'valignq' {AVX512_F+VL}. + kIdVandnpd, //!< Instruction 'vandnpd' {AVX|AVX512_DQ+VL}. + kIdVandnps, //!< Instruction 'vandnps' {AVX|AVX512_DQ+VL}. + kIdVandpd, //!< Instruction 'vandpd' {AVX|AVX512_DQ+VL}. + kIdVandps, //!< Instruction 'vandps' {AVX|AVX512_DQ+VL}. + kIdVbcstnebf162ps, //!< Instruction 'vbcstnebf162ps' {AVX_NE_CONVERT}. + kIdVbcstnesh2ps, //!< Instruction 'vbcstnesh2ps' {AVX_NE_CONVERT}. + kIdVblendmpd, //!< Instruction 'vblendmpd' {AVX512_F+VL}. 
+ kIdVblendmps, //!< Instruction 'vblendmps' {AVX512_F+VL}. + kIdVblendpd, //!< Instruction 'vblendpd' {AVX}. + kIdVblendps, //!< Instruction 'vblendps' {AVX}. + kIdVblendvpd, //!< Instruction 'vblendvpd' {AVX}. + kIdVblendvps, //!< Instruction 'vblendvps' {AVX}. + kIdVbroadcastf128, //!< Instruction 'vbroadcastf128' {AVX}. + kIdVbroadcastf32x2, //!< Instruction 'vbroadcastf32x2' {AVX512_DQ+VL}. + kIdVbroadcastf32x4, //!< Instruction 'vbroadcastf32x4' {AVX512_F}. + kIdVbroadcastf32x8, //!< Instruction 'vbroadcastf32x8' {AVX512_DQ}. + kIdVbroadcastf64x2, //!< Instruction 'vbroadcastf64x2' {AVX512_DQ+VL}. + kIdVbroadcastf64x4, //!< Instruction 'vbroadcastf64x4' {AVX512_F}. + kIdVbroadcasti128, //!< Instruction 'vbroadcasti128' {AVX2}. + kIdVbroadcasti32x2, //!< Instruction 'vbroadcasti32x2' {AVX512_DQ+VL}. + kIdVbroadcasti32x4, //!< Instruction 'vbroadcasti32x4' {AVX512_F+VL}. + kIdVbroadcasti32x8, //!< Instruction 'vbroadcasti32x8' {AVX512_DQ}. + kIdVbroadcasti64x2, //!< Instruction 'vbroadcasti64x2' {AVX512_DQ+VL}. + kIdVbroadcasti64x4, //!< Instruction 'vbroadcasti64x4' {AVX512_F}. + kIdVbroadcastsd, //!< Instruction 'vbroadcastsd' {AVX|AVX2|AVX512_F+VL}. + kIdVbroadcastss, //!< Instruction 'vbroadcastss' {AVX|AVX2|AVX512_F+VL}. + kIdVcmppd, //!< Instruction 'vcmppd' {AVX|AVX512_F+VL}. + kIdVcmpph, //!< Instruction 'vcmpph' {AVX512_FP16+VL}. + kIdVcmpps, //!< Instruction 'vcmpps' {AVX|AVX512_F+VL}. + kIdVcmpsd, //!< Instruction 'vcmpsd' {AVX|AVX512_F}. + kIdVcmpsh, //!< Instruction 'vcmpsh' {AVX512_FP16}. + kIdVcmpss, //!< Instruction 'vcmpss' {AVX|AVX512_F}. + kIdVcomisd, //!< Instruction 'vcomisd' {AVX|AVX512_F}. + kIdVcomish, //!< Instruction 'vcomish' {AVX512_FP16}. + kIdVcomiss, //!< Instruction 'vcomiss' {AVX|AVX512_F}. + kIdVcompresspd, //!< Instruction 'vcompresspd' {AVX512_F+VL}. + kIdVcompressps, //!< Instruction 'vcompressps' {AVX512_F+VL}. + kIdVcvtdq2pd, //!< Instruction 'vcvtdq2pd' {AVX|AVX512_F+VL}. 
+ kIdVcvtdq2ph, //!< Instruction 'vcvtdq2ph' {AVX512_FP16+VL}. + kIdVcvtdq2ps, //!< Instruction 'vcvtdq2ps' {AVX|AVX512_F+VL}. + kIdVcvtne2ps2bf16, //!< Instruction 'vcvtne2ps2bf16' {AVX512_BF16+VL}. + kIdVcvtneebf162ps, //!< Instruction 'vcvtneebf162ps' {AVX_NE_CONVERT}. + kIdVcvtneeph2ps, //!< Instruction 'vcvtneeph2ps' {AVX_NE_CONVERT}. + kIdVcvtneobf162ps, //!< Instruction 'vcvtneobf162ps' {AVX_NE_CONVERT}. + kIdVcvtneoph2ps, //!< Instruction 'vcvtneoph2ps' {AVX_NE_CONVERT}. + kIdVcvtneps2bf16, //!< Instruction 'vcvtneps2bf16' {AVX_NE_CONVERT|AVX512_BF16+VL}. + kIdVcvtpd2dq, //!< Instruction 'vcvtpd2dq' {AVX|AVX512_F+VL}. + kIdVcvtpd2ph, //!< Instruction 'vcvtpd2ph' {AVX512_FP16+VL}. + kIdVcvtpd2ps, //!< Instruction 'vcvtpd2ps' {AVX|AVX512_F+VL}. + kIdVcvtpd2qq, //!< Instruction 'vcvtpd2qq' {AVX512_DQ+VL}. + kIdVcvtpd2udq, //!< Instruction 'vcvtpd2udq' {AVX512_F+VL}. + kIdVcvtpd2uqq, //!< Instruction 'vcvtpd2uqq' {AVX512_DQ+VL}. + kIdVcvtph2dq, //!< Instruction 'vcvtph2dq' {AVX512_FP16+VL}. + kIdVcvtph2pd, //!< Instruction 'vcvtph2pd' {AVX512_FP16+VL}. + kIdVcvtph2ps, //!< Instruction 'vcvtph2ps' {AVX512_F+VL & F16C}. + kIdVcvtph2psx, //!< Instruction 'vcvtph2psx' {AVX512_FP16+VL}. + kIdVcvtph2qq, //!< Instruction 'vcvtph2qq' {AVX512_FP16+VL}. + kIdVcvtph2udq, //!< Instruction 'vcvtph2udq' {AVX512_FP16+VL}. + kIdVcvtph2uqq, //!< Instruction 'vcvtph2uqq' {AVX512_FP16+VL}. + kIdVcvtph2uw, //!< Instruction 'vcvtph2uw' {AVX512_FP16+VL}. + kIdVcvtph2w, //!< Instruction 'vcvtph2w' {AVX512_FP16+VL}. + kIdVcvtps2dq, //!< Instruction 'vcvtps2dq' {AVX|AVX512_F+VL}. + kIdVcvtps2pd, //!< Instruction 'vcvtps2pd' {AVX|AVX512_F+VL}. + kIdVcvtps2ph, //!< Instruction 'vcvtps2ph' {AVX512_F+VL & F16C}. + kIdVcvtps2phx, //!< Instruction 'vcvtps2phx' {AVX512_FP16+VL}. + kIdVcvtps2qq, //!< Instruction 'vcvtps2qq' {AVX512_DQ+VL}. + kIdVcvtps2udq, //!< Instruction 'vcvtps2udq' {AVX512_F+VL}. + kIdVcvtps2uqq, //!< Instruction 'vcvtps2uqq' {AVX512_DQ+VL}. 
+ kIdVcvtqq2pd, //!< Instruction 'vcvtqq2pd' {AVX512_DQ+VL}. + kIdVcvtqq2ph, //!< Instruction 'vcvtqq2ph' {AVX512_FP16+VL}. + kIdVcvtqq2ps, //!< Instruction 'vcvtqq2ps' {AVX512_DQ+VL}. + kIdVcvtsd2sh, //!< Instruction 'vcvtsd2sh' {AVX512_FP16}. + kIdVcvtsd2si, //!< Instruction 'vcvtsd2si' {AVX|AVX512_F}. + kIdVcvtsd2ss, //!< Instruction 'vcvtsd2ss' {AVX|AVX512_F}. + kIdVcvtsd2usi, //!< Instruction 'vcvtsd2usi' {AVX512_F}. + kIdVcvtsh2sd, //!< Instruction 'vcvtsh2sd' {AVX512_FP16}. + kIdVcvtsh2si, //!< Instruction 'vcvtsh2si' {AVX512_FP16}. + kIdVcvtsh2ss, //!< Instruction 'vcvtsh2ss' {AVX512_FP16}. + kIdVcvtsh2usi, //!< Instruction 'vcvtsh2usi' {AVX512_FP16}. + kIdVcvtsi2sd, //!< Instruction 'vcvtsi2sd' {AVX|AVX512_F}. + kIdVcvtsi2sh, //!< Instruction 'vcvtsi2sh' {AVX512_FP16}. + kIdVcvtsi2ss, //!< Instruction 'vcvtsi2ss' {AVX|AVX512_F}. + kIdVcvtss2sd, //!< Instruction 'vcvtss2sd' {AVX|AVX512_F}. + kIdVcvtss2sh, //!< Instruction 'vcvtss2sh' {AVX512_FP16}. + kIdVcvtss2si, //!< Instruction 'vcvtss2si' {AVX|AVX512_F}. + kIdVcvtss2usi, //!< Instruction 'vcvtss2usi' {AVX512_F}. + kIdVcvttpd2dq, //!< Instruction 'vcvttpd2dq' {AVX|AVX512_F+VL}. + kIdVcvttpd2qq, //!< Instruction 'vcvttpd2qq' {AVX512_F+VL}. + kIdVcvttpd2udq, //!< Instruction 'vcvttpd2udq' {AVX512_F+VL}. + kIdVcvttpd2uqq, //!< Instruction 'vcvttpd2uqq' {AVX512_DQ+VL}. + kIdVcvttph2dq, //!< Instruction 'vcvttph2dq' {AVX512_FP16+VL}. + kIdVcvttph2qq, //!< Instruction 'vcvttph2qq' {AVX512_FP16+VL}. + kIdVcvttph2udq, //!< Instruction 'vcvttph2udq' {AVX512_FP16+VL}. + kIdVcvttph2uqq, //!< Instruction 'vcvttph2uqq' {AVX512_FP16+VL}. + kIdVcvttph2uw, //!< Instruction 'vcvttph2uw' {AVX512_FP16+VL}. + kIdVcvttph2w, //!< Instruction 'vcvttph2w' {AVX512_FP16+VL}. + kIdVcvttps2dq, //!< Instruction 'vcvttps2dq' {AVX|AVX512_F+VL}. + kIdVcvttps2qq, //!< Instruction 'vcvttps2qq' {AVX512_DQ+VL}. + kIdVcvttps2udq, //!< Instruction 'vcvttps2udq' {AVX512_F+VL}. + kIdVcvttps2uqq, //!< Instruction 'vcvttps2uqq' {AVX512_DQ+VL}. 
+ kIdVcvttsd2si, //!< Instruction 'vcvttsd2si' {AVX|AVX512_F}. + kIdVcvttsd2usi, //!< Instruction 'vcvttsd2usi' {AVX512_F}. + kIdVcvttsh2si, //!< Instruction 'vcvttsh2si' {AVX512_FP16}. + kIdVcvttsh2usi, //!< Instruction 'vcvttsh2usi' {AVX512_FP16}. + kIdVcvttss2si, //!< Instruction 'vcvttss2si' {AVX|AVX512_F}. + kIdVcvttss2usi, //!< Instruction 'vcvttss2usi' {AVX512_F}. + kIdVcvtudq2pd, //!< Instruction 'vcvtudq2pd' {AVX512_F+VL}. + kIdVcvtudq2ph, //!< Instruction 'vcvtudq2ph' {AVX512_FP16+VL}. + kIdVcvtudq2ps, //!< Instruction 'vcvtudq2ps' {AVX512_F+VL}. + kIdVcvtuqq2pd, //!< Instruction 'vcvtuqq2pd' {AVX512_DQ+VL}. + kIdVcvtuqq2ph, //!< Instruction 'vcvtuqq2ph' {AVX512_FP16+VL}. + kIdVcvtuqq2ps, //!< Instruction 'vcvtuqq2ps' {AVX512_DQ+VL}. + kIdVcvtusi2sd, //!< Instruction 'vcvtusi2sd' {AVX512_F}. + kIdVcvtusi2sh, //!< Instruction 'vcvtusi2sh' {AVX512_FP16}. + kIdVcvtusi2ss, //!< Instruction 'vcvtusi2ss' {AVX512_F}. + kIdVcvtuw2ph, //!< Instruction 'vcvtuw2ph' {AVX512_FP16+VL}. + kIdVcvtw2ph, //!< Instruction 'vcvtw2ph' {AVX512_FP16+VL}. + kIdVdbpsadbw, //!< Instruction 'vdbpsadbw' {AVX512_BW+VL}. + kIdVdivpd, //!< Instruction 'vdivpd' {AVX|AVX512_F+VL}. + kIdVdivph, //!< Instruction 'vdivph' {AVX512_FP16+VL}. + kIdVdivps, //!< Instruction 'vdivps' {AVX|AVX512_F+VL}. + kIdVdivsd, //!< Instruction 'vdivsd' {AVX|AVX512_F}. + kIdVdivsh, //!< Instruction 'vdivsh' {AVX512_FP16}. + kIdVdivss, //!< Instruction 'vdivss' {AVX|AVX512_F}. + kIdVdpbf16ps, //!< Instruction 'vdpbf16ps' {AVX512_BF16+VL}. + kIdVdppd, //!< Instruction 'vdppd' {AVX}. + kIdVdpps, //!< Instruction 'vdpps' {AVX}. + kIdVerr, //!< Instruction 'verr'. + kIdVerw, //!< Instruction 'verw'. + kIdVexp2pd, //!< Instruction 'vexp2pd' {AVX512_ER}. + kIdVexp2ps, //!< Instruction 'vexp2ps' {AVX512_ER}. + kIdVexpandpd, //!< Instruction 'vexpandpd' {AVX512_F+VL}. + kIdVexpandps, //!< Instruction 'vexpandps' {AVX512_F+VL}. + kIdVextractf128, //!< Instruction 'vextractf128' {AVX}. 
+ kIdVextractf32x4, //!< Instruction 'vextractf32x4' {AVX512_F+VL}. + kIdVextractf32x8, //!< Instruction 'vextractf32x8' {AVX512_DQ}. + kIdVextractf64x2, //!< Instruction 'vextractf64x2' {AVX512_DQ+VL}. + kIdVextractf64x4, //!< Instruction 'vextractf64x4' {AVX512_F}. + kIdVextracti128, //!< Instruction 'vextracti128' {AVX2}. + kIdVextracti32x4, //!< Instruction 'vextracti32x4' {AVX512_F+VL}. + kIdVextracti32x8, //!< Instruction 'vextracti32x8' {AVX512_DQ}. + kIdVextracti64x2, //!< Instruction 'vextracti64x2' {AVX512_DQ+VL}. + kIdVextracti64x4, //!< Instruction 'vextracti64x4' {AVX512_F}. + kIdVextractps, //!< Instruction 'vextractps' {AVX|AVX512_F}. + kIdVfcmaddcph, //!< Instruction 'vfcmaddcph' {AVX512_FP16+VL}. + kIdVfcmaddcsh, //!< Instruction 'vfcmaddcsh' {AVX512_FP16}. + kIdVfcmulcph, //!< Instruction 'vfcmulcph' {AVX512_FP16+VL}. + kIdVfcmulcsh, //!< Instruction 'vfcmulcsh' {AVX512_FP16}. + kIdVfixupimmpd, //!< Instruction 'vfixupimmpd' {AVX512_F+VL}. + kIdVfixupimmps, //!< Instruction 'vfixupimmps' {AVX512_F+VL}. + kIdVfixupimmsd, //!< Instruction 'vfixupimmsd' {AVX512_F}. + kIdVfixupimmss, //!< Instruction 'vfixupimmss' {AVX512_F}. + kIdVfmadd132pd, //!< Instruction 'vfmadd132pd' {FMA|AVX512_F+VL}. + kIdVfmadd132ph, //!< Instruction 'vfmadd132ph' {AVX512_FP16+VL}. + kIdVfmadd132ps, //!< Instruction 'vfmadd132ps' {FMA|AVX512_F+VL}. + kIdVfmadd132sd, //!< Instruction 'vfmadd132sd' {FMA|AVX512_F}. + kIdVfmadd132sh, //!< Instruction 'vfmadd132sh' {AVX512_FP16}. + kIdVfmadd132ss, //!< Instruction 'vfmadd132ss' {FMA|AVX512_F}. + kIdVfmadd213pd, //!< Instruction 'vfmadd213pd' {FMA|AVX512_F+VL}. + kIdVfmadd213ph, //!< Instruction 'vfmadd213ph' {AVX512_FP16+VL}. + kIdVfmadd213ps, //!< Instruction 'vfmadd213ps' {FMA|AVX512_F+VL}. + kIdVfmadd213sd, //!< Instruction 'vfmadd213sd' {FMA|AVX512_F}. + kIdVfmadd213sh, //!< Instruction 'vfmadd213sh' {AVX512_FP16}. + kIdVfmadd213ss, //!< Instruction 'vfmadd213ss' {FMA|AVX512_F}. 
+ kIdVfmadd231pd, //!< Instruction 'vfmadd231pd' {FMA|AVX512_F+VL}. + kIdVfmadd231ph, //!< Instruction 'vfmadd231ph' {AVX512_FP16+VL}. + kIdVfmadd231ps, //!< Instruction 'vfmadd231ps' {FMA|AVX512_F+VL}. + kIdVfmadd231sd, //!< Instruction 'vfmadd231sd' {FMA|AVX512_F}. + kIdVfmadd231sh, //!< Instruction 'vfmadd231sh' {AVX512_FP16}. + kIdVfmadd231ss, //!< Instruction 'vfmadd231ss' {FMA|AVX512_F}. + kIdVfmaddcph, //!< Instruction 'vfmaddcph' {AVX512_FP16+VL}. + kIdVfmaddcsh, //!< Instruction 'vfmaddcsh' {AVX512_FP16}. + kIdVfmaddpd, //!< Instruction 'vfmaddpd' {FMA4}. + kIdVfmaddps, //!< Instruction 'vfmaddps' {FMA4}. + kIdVfmaddsd, //!< Instruction 'vfmaddsd' {FMA4}. + kIdVfmaddss, //!< Instruction 'vfmaddss' {FMA4}. + kIdVfmaddsub132pd, //!< Instruction 'vfmaddsub132pd' {FMA|AVX512_F+VL}. + kIdVfmaddsub132ph, //!< Instruction 'vfmaddsub132ph' {AVX512_FP16+VL}. + kIdVfmaddsub132ps, //!< Instruction 'vfmaddsub132ps' {FMA|AVX512_F+VL}. + kIdVfmaddsub213pd, //!< Instruction 'vfmaddsub213pd' {FMA|AVX512_F+VL}. + kIdVfmaddsub213ph, //!< Instruction 'vfmaddsub213ph' {AVX512_FP16+VL}. + kIdVfmaddsub213ps, //!< Instruction 'vfmaddsub213ps' {FMA|AVX512_F+VL}. + kIdVfmaddsub231pd, //!< Instruction 'vfmaddsub231pd' {FMA|AVX512_F+VL}. + kIdVfmaddsub231ph, //!< Instruction 'vfmaddsub231ph' {AVX512_FP16+VL}. + kIdVfmaddsub231ps, //!< Instruction 'vfmaddsub231ps' {FMA|AVX512_F+VL}. + kIdVfmaddsubpd, //!< Instruction 'vfmaddsubpd' {FMA4}. + kIdVfmaddsubps, //!< Instruction 'vfmaddsubps' {FMA4}. + kIdVfmsub132pd, //!< Instruction 'vfmsub132pd' {FMA|AVX512_F+VL}. + kIdVfmsub132ph, //!< Instruction 'vfmsub132ph' {AVX512_FP16+VL}. + kIdVfmsub132ps, //!< Instruction 'vfmsub132ps' {FMA|AVX512_F+VL}. + kIdVfmsub132sd, //!< Instruction 'vfmsub132sd' {FMA|AVX512_F}. + kIdVfmsub132sh, //!< Instruction 'vfmsub132sh' {AVX512_FP16}. + kIdVfmsub132ss, //!< Instruction 'vfmsub132ss' {FMA|AVX512_F}. + kIdVfmsub213pd, //!< Instruction 'vfmsub213pd' {FMA|AVX512_F+VL}. 
+ kIdVfmsub213ph, //!< Instruction 'vfmsub213ph' {AVX512_FP16+VL}. + kIdVfmsub213ps, //!< Instruction 'vfmsub213ps' {FMA|AVX512_F+VL}. + kIdVfmsub213sd, //!< Instruction 'vfmsub213sd' {FMA|AVX512_F}. + kIdVfmsub213sh, //!< Instruction 'vfmsub213sh' {AVX512_FP16}. + kIdVfmsub213ss, //!< Instruction 'vfmsub213ss' {FMA|AVX512_F}. + kIdVfmsub231pd, //!< Instruction 'vfmsub231pd' {FMA|AVX512_F+VL}. + kIdVfmsub231ph, //!< Instruction 'vfmsub231ph' {AVX512_FP16+VL}. + kIdVfmsub231ps, //!< Instruction 'vfmsub231ps' {FMA|AVX512_F+VL}. + kIdVfmsub231sd, //!< Instruction 'vfmsub231sd' {FMA|AVX512_F}. + kIdVfmsub231sh, //!< Instruction 'vfmsub231sh' {AVX512_FP16}. + kIdVfmsub231ss, //!< Instruction 'vfmsub231ss' {FMA|AVX512_F}. + kIdVfmsubadd132pd, //!< Instruction 'vfmsubadd132pd' {FMA|AVX512_F+VL}. + kIdVfmsubadd132ph, //!< Instruction 'vfmsubadd132ph' {AVX512_FP16+VL}. + kIdVfmsubadd132ps, //!< Instruction 'vfmsubadd132ps' {FMA|AVX512_F+VL}. + kIdVfmsubadd213pd, //!< Instruction 'vfmsubadd213pd' {FMA|AVX512_F+VL}. + kIdVfmsubadd213ph, //!< Instruction 'vfmsubadd213ph' {AVX512_FP16+VL}. + kIdVfmsubadd213ps, //!< Instruction 'vfmsubadd213ps' {FMA|AVX512_F+VL}. + kIdVfmsubadd231pd, //!< Instruction 'vfmsubadd231pd' {FMA|AVX512_F+VL}. + kIdVfmsubadd231ph, //!< Instruction 'vfmsubadd231ph' {AVX512_FP16+VL}. + kIdVfmsubadd231ps, //!< Instruction 'vfmsubadd231ps' {FMA|AVX512_F+VL}. + kIdVfmsubaddpd, //!< Instruction 'vfmsubaddpd' {FMA4}. + kIdVfmsubaddps, //!< Instruction 'vfmsubaddps' {FMA4}. + kIdVfmsubpd, //!< Instruction 'vfmsubpd' {FMA4}. + kIdVfmsubps, //!< Instruction 'vfmsubps' {FMA4}. + kIdVfmsubsd, //!< Instruction 'vfmsubsd' {FMA4}. + kIdVfmsubss, //!< Instruction 'vfmsubss' {FMA4}. + kIdVfmulcph, //!< Instruction 'vfmulcph' {AVX512_FP16+VL}. + kIdVfmulcsh, //!< Instruction 'vfmulcsh' {AVX512_FP16+VL}. + kIdVfnmadd132pd, //!< Instruction 'vfnmadd132pd' {FMA|AVX512_F+VL}. + kIdVfnmadd132ph, //!< Instruction 'vfnmadd132ph' {AVX512_FP16+VL}. 
+ kIdVfnmadd132ps, //!< Instruction 'vfnmadd132ps' {FMA|AVX512_F+VL}. + kIdVfnmadd132sd, //!< Instruction 'vfnmadd132sd' {FMA|AVX512_F}. + kIdVfnmadd132sh, //!< Instruction 'vfnmadd132sh' {AVX512_FP16}. + kIdVfnmadd132ss, //!< Instruction 'vfnmadd132ss' {FMA|AVX512_F}. + kIdVfnmadd213pd, //!< Instruction 'vfnmadd213pd' {FMA|AVX512_F+VL}. + kIdVfnmadd213ph, //!< Instruction 'vfnmadd213ph' {AVX512_FP16+VL}. + kIdVfnmadd213ps, //!< Instruction 'vfnmadd213ps' {FMA|AVX512_F+VL}. + kIdVfnmadd213sd, //!< Instruction 'vfnmadd213sd' {FMA|AVX512_F}. + kIdVfnmadd213sh, //!< Instruction 'vfnmadd213sh' {AVX512_FP16}. + kIdVfnmadd213ss, //!< Instruction 'vfnmadd213ss' {FMA|AVX512_F}. + kIdVfnmadd231pd, //!< Instruction 'vfnmadd231pd' {FMA|AVX512_F+VL}. + kIdVfnmadd231ph, //!< Instruction 'vfnmadd231ph' {AVX512_FP16+VL}. + kIdVfnmadd231ps, //!< Instruction 'vfnmadd231ps' {FMA|AVX512_F+VL}. + kIdVfnmadd231sd, //!< Instruction 'vfnmadd231sd' {FMA|AVX512_F}. + kIdVfnmadd231sh, //!< Instruction 'vfnmadd231sh' {AVX512_FP16}. + kIdVfnmadd231ss, //!< Instruction 'vfnmadd231ss' {FMA|AVX512_F}. + kIdVfnmaddpd, //!< Instruction 'vfnmaddpd' {FMA4}. + kIdVfnmaddps, //!< Instruction 'vfnmaddps' {FMA4}. + kIdVfnmaddsd, //!< Instruction 'vfnmaddsd' {FMA4}. + kIdVfnmaddss, //!< Instruction 'vfnmaddss' {FMA4}. + kIdVfnmsub132pd, //!< Instruction 'vfnmsub132pd' {FMA|AVX512_F+VL}. + kIdVfnmsub132ph, //!< Instruction 'vfnmsub132ph' {AVX512_FP16+VL}. + kIdVfnmsub132ps, //!< Instruction 'vfnmsub132ps' {FMA|AVX512_F+VL}. + kIdVfnmsub132sd, //!< Instruction 'vfnmsub132sd' {FMA|AVX512_F}. + kIdVfnmsub132sh, //!< Instruction 'vfnmsub132sh' {AVX512_FP16}. + kIdVfnmsub132ss, //!< Instruction 'vfnmsub132ss' {FMA|AVX512_F}. + kIdVfnmsub213pd, //!< Instruction 'vfnmsub213pd' {FMA|AVX512_F+VL}. + kIdVfnmsub213ph, //!< Instruction 'vfnmsub213ph' {AVX512_FP16+VL}. + kIdVfnmsub213ps, //!< Instruction 'vfnmsub213ps' {FMA|AVX512_F+VL}. + kIdVfnmsub213sd, //!< Instruction 'vfnmsub213sd' {FMA|AVX512_F}. 
+ kIdVfnmsub213sh, //!< Instruction 'vfnmsub213sh' {AVX512_FP16}. + kIdVfnmsub213ss, //!< Instruction 'vfnmsub213ss' {FMA|AVX512_F}. + kIdVfnmsub231pd, //!< Instruction 'vfnmsub231pd' {FMA|AVX512_F+VL}. + kIdVfnmsub231ph, //!< Instruction 'vfnmsub231ph' {AVX512_FP16+VL}. + kIdVfnmsub231ps, //!< Instruction 'vfnmsub231ps' {FMA|AVX512_F+VL}. + kIdVfnmsub231sd, //!< Instruction 'vfnmsub231sd' {FMA|AVX512_F}. + kIdVfnmsub231sh, //!< Instruction 'vfnmsub231sh' {AVX512_FP16}. + kIdVfnmsub231ss, //!< Instruction 'vfnmsub231ss' {FMA|AVX512_F}. + kIdVfnmsubpd, //!< Instruction 'vfnmsubpd' {FMA4}. + kIdVfnmsubps, //!< Instruction 'vfnmsubps' {FMA4}. + kIdVfnmsubsd, //!< Instruction 'vfnmsubsd' {FMA4}. + kIdVfnmsubss, //!< Instruction 'vfnmsubss' {FMA4}. + kIdVfpclasspd, //!< Instruction 'vfpclasspd' {AVX512_DQ+VL}. + kIdVfpclassph, //!< Instruction 'vfpclassph' {AVX512_FP16+VL}. + kIdVfpclassps, //!< Instruction 'vfpclassps' {AVX512_DQ+VL}. + kIdVfpclasssd, //!< Instruction 'vfpclasssd' {AVX512_DQ}. + kIdVfpclasssh, //!< Instruction 'vfpclasssh' {AVX512_FP16}. + kIdVfpclassss, //!< Instruction 'vfpclassss' {AVX512_DQ}. + kIdVfrczpd, //!< Instruction 'vfrczpd' {XOP}. + kIdVfrczps, //!< Instruction 'vfrczps' {XOP}. + kIdVfrczsd, //!< Instruction 'vfrczsd' {XOP}. + kIdVfrczss, //!< Instruction 'vfrczss' {XOP}. + kIdVgatherdpd, //!< Instruction 'vgatherdpd' {AVX2|AVX512_F+VL}. + kIdVgatherdps, //!< Instruction 'vgatherdps' {AVX2|AVX512_F+VL}. + kIdVgatherpf0dpd, //!< Instruction 'vgatherpf0dpd' {AVX512_PF}. + kIdVgatherpf0dps, //!< Instruction 'vgatherpf0dps' {AVX512_PF}. + kIdVgatherpf0qpd, //!< Instruction 'vgatherpf0qpd' {AVX512_PF}. + kIdVgatherpf0qps, //!< Instruction 'vgatherpf0qps' {AVX512_PF}. + kIdVgatherpf1dpd, //!< Instruction 'vgatherpf1dpd' {AVX512_PF}. + kIdVgatherpf1dps, //!< Instruction 'vgatherpf1dps' {AVX512_PF}. + kIdVgatherpf1qpd, //!< Instruction 'vgatherpf1qpd' {AVX512_PF}. + kIdVgatherpf1qps, //!< Instruction 'vgatherpf1qps' {AVX512_PF}. 
+ kIdVgatherqpd, //!< Instruction 'vgatherqpd' {AVX2|AVX512_F+VL}. + kIdVgatherqps, //!< Instruction 'vgatherqps' {AVX2|AVX512_F+VL}. + kIdVgetexppd, //!< Instruction 'vgetexppd' {AVX512_F+VL}. + kIdVgetexpph, //!< Instruction 'vgetexpph' {AVX512_FP16+VL}. + kIdVgetexpps, //!< Instruction 'vgetexpps' {AVX512_F+VL}. + kIdVgetexpsd, //!< Instruction 'vgetexpsd' {AVX512_F}. + kIdVgetexpsh, //!< Instruction 'vgetexpsh' {AVX512_FP16}. + kIdVgetexpss, //!< Instruction 'vgetexpss' {AVX512_F}. + kIdVgetmantpd, //!< Instruction 'vgetmantpd' {AVX512_F+VL}. + kIdVgetmantph, //!< Instruction 'vgetmantph' {AVX512_FP16+VL}. + kIdVgetmantps, //!< Instruction 'vgetmantps' {AVX512_F+VL}. + kIdVgetmantsd, //!< Instruction 'vgetmantsd' {AVX512_F}. + kIdVgetmantsh, //!< Instruction 'vgetmantsh' {AVX512_FP16}. + kIdVgetmantss, //!< Instruction 'vgetmantss' {AVX512_F}. + kIdVgf2p8affineinvqb, //!< Instruction 'vgf2p8affineinvqb' {AVX|AVX512_F+VL & GFNI}. + kIdVgf2p8affineqb, //!< Instruction 'vgf2p8affineqb' {AVX|AVX512_F+VL & GFNI}. + kIdVgf2p8mulb, //!< Instruction 'vgf2p8mulb' {AVX|AVX512_F+VL & GFNI}. + kIdVhaddpd, //!< Instruction 'vhaddpd' {AVX}. + kIdVhaddps, //!< Instruction 'vhaddps' {AVX}. + kIdVhsubpd, //!< Instruction 'vhsubpd' {AVX}. + kIdVhsubps, //!< Instruction 'vhsubps' {AVX}. + kIdVinsertf128, //!< Instruction 'vinsertf128' {AVX}. + kIdVinsertf32x4, //!< Instruction 'vinsertf32x4' {AVX512_F+VL}. + kIdVinsertf32x8, //!< Instruction 'vinsertf32x8' {AVX512_DQ}. + kIdVinsertf64x2, //!< Instruction 'vinsertf64x2' {AVX512_DQ+VL}. + kIdVinsertf64x4, //!< Instruction 'vinsertf64x4' {AVX512_F}. + kIdVinserti128, //!< Instruction 'vinserti128' {AVX2}. + kIdVinserti32x4, //!< Instruction 'vinserti32x4' {AVX512_F+VL}. + kIdVinserti32x8, //!< Instruction 'vinserti32x8' {AVX512_DQ}. + kIdVinserti64x2, //!< Instruction 'vinserti64x2' {AVX512_DQ+VL}. + kIdVinserti64x4, //!< Instruction 'vinserti64x4' {AVX512_F}. + kIdVinsertps, //!< Instruction 'vinsertps' {AVX|AVX512_F}. 
+ kIdVlddqu, //!< Instruction 'vlddqu' {AVX}. + kIdVldmxcsr, //!< Instruction 'vldmxcsr' {AVX}. + kIdVmaskmovdqu, //!< Instruction 'vmaskmovdqu' {AVX}. + kIdVmaskmovpd, //!< Instruction 'vmaskmovpd' {AVX}. + kIdVmaskmovps, //!< Instruction 'vmaskmovps' {AVX}. + kIdVmaxpd, //!< Instruction 'vmaxpd' {AVX|AVX512_F+VL}. + kIdVmaxph, //!< Instruction 'vmaxph' {AVX512_FP16+VL}. + kIdVmaxps, //!< Instruction 'vmaxps' {AVX|AVX512_F+VL}. + kIdVmaxsd, //!< Instruction 'vmaxsd' {AVX|AVX512_F}. + kIdVmaxsh, //!< Instruction 'vmaxsh' {AVX512_FP16}. + kIdVmaxss, //!< Instruction 'vmaxss' {AVX|AVX512_F}. + kIdVmcall, //!< Instruction 'vmcall' {VMX}. + kIdVmclear, //!< Instruction 'vmclear' {VMX}. + kIdVmfunc, //!< Instruction 'vmfunc' {VMX}. + kIdVmgexit, //!< Instruction 'vmgexit' {SEV_ES}. + kIdVminpd, //!< Instruction 'vminpd' {AVX|AVX512_F+VL}. + kIdVminph, //!< Instruction 'vminph' {AVX512_FP16+VL}. + kIdVminps, //!< Instruction 'vminps' {AVX|AVX512_F+VL}. + kIdVminsd, //!< Instruction 'vminsd' {AVX|AVX512_F}. + kIdVminsh, //!< Instruction 'vminsh' {AVX512_FP16}. + kIdVminss, //!< Instruction 'vminss' {AVX|AVX512_F}. + kIdVmlaunch, //!< Instruction 'vmlaunch' {VMX}. + kIdVmload, //!< Instruction 'vmload' {SVM}. + kIdVmmcall, //!< Instruction 'vmmcall' {SVM}. + kIdVmovapd, //!< Instruction 'vmovapd' {AVX|AVX512_F+VL}. + kIdVmovaps, //!< Instruction 'vmovaps' {AVX|AVX512_F+VL}. + kIdVmovd, //!< Instruction 'vmovd' {AVX|AVX512_F}. + kIdVmovddup, //!< Instruction 'vmovddup' {AVX|AVX512_F+VL}. + kIdVmovdqa, //!< Instruction 'vmovdqa' {AVX}. + kIdVmovdqa32, //!< Instruction 'vmovdqa32' {AVX512_F+VL}. + kIdVmovdqa64, //!< Instruction 'vmovdqa64' {AVX512_F+VL}. + kIdVmovdqu, //!< Instruction 'vmovdqu' {AVX}. + kIdVmovdqu16, //!< Instruction 'vmovdqu16' {AVX512_BW+VL}. + kIdVmovdqu32, //!< Instruction 'vmovdqu32' {AVX512_F+VL}. + kIdVmovdqu64, //!< Instruction 'vmovdqu64' {AVX512_F+VL}. + kIdVmovdqu8, //!< Instruction 'vmovdqu8' {AVX512_BW+VL}. 
+ kIdVmovhlps, //!< Instruction 'vmovhlps' {AVX|AVX512_F}. + kIdVmovhpd, //!< Instruction 'vmovhpd' {AVX|AVX512_F}. + kIdVmovhps, //!< Instruction 'vmovhps' {AVX|AVX512_F}. + kIdVmovlhps, //!< Instruction 'vmovlhps' {AVX|AVX512_F}. + kIdVmovlpd, //!< Instruction 'vmovlpd' {AVX|AVX512_F}. + kIdVmovlps, //!< Instruction 'vmovlps' {AVX|AVX512_F}. + kIdVmovmskpd, //!< Instruction 'vmovmskpd' {AVX}. + kIdVmovmskps, //!< Instruction 'vmovmskps' {AVX}. + kIdVmovntdq, //!< Instruction 'vmovntdq' {AVX|AVX512_F+VL}. + kIdVmovntdqa, //!< Instruction 'vmovntdqa' {AVX|AVX2|AVX512_F+VL}. + kIdVmovntpd, //!< Instruction 'vmovntpd' {AVX|AVX512_F+VL}. + kIdVmovntps, //!< Instruction 'vmovntps' {AVX|AVX512_F+VL}. + kIdVmovq, //!< Instruction 'vmovq' {AVX|AVX512_F}. + kIdVmovsd, //!< Instruction 'vmovsd' {AVX|AVX512_F}. + kIdVmovsh, //!< Instruction 'vmovsh' {AVX512_FP16}. + kIdVmovshdup, //!< Instruction 'vmovshdup' {AVX|AVX512_F+VL}. + kIdVmovsldup, //!< Instruction 'vmovsldup' {AVX|AVX512_F+VL}. + kIdVmovss, //!< Instruction 'vmovss' {AVX|AVX512_F}. + kIdVmovupd, //!< Instruction 'vmovupd' {AVX|AVX512_F+VL}. + kIdVmovups, //!< Instruction 'vmovups' {AVX|AVX512_F+VL}. + kIdVmovw, //!< Instruction 'vmovw' {AVX512_FP16}. + kIdVmpsadbw, //!< Instruction 'vmpsadbw' {AVX|AVX2}. + kIdVmptrld, //!< Instruction 'vmptrld' {VMX}. + kIdVmptrst, //!< Instruction 'vmptrst' {VMX}. + kIdVmread, //!< Instruction 'vmread' {VMX}. + kIdVmresume, //!< Instruction 'vmresume' {VMX}. + kIdVmrun, //!< Instruction 'vmrun' {SVM}. + kIdVmsave, //!< Instruction 'vmsave' {SVM}. + kIdVmulpd, //!< Instruction 'vmulpd' {AVX|AVX512_F+VL}. + kIdVmulph, //!< Instruction 'vmulph' {AVX512_FP16+VL}. + kIdVmulps, //!< Instruction 'vmulps' {AVX|AVX512_F+VL}. + kIdVmulsd, //!< Instruction 'vmulsd' {AVX|AVX512_F}. + kIdVmulsh, //!< Instruction 'vmulsh' {AVX512_FP16}. + kIdVmulss, //!< Instruction 'vmulss' {AVX|AVX512_F}. + kIdVmwrite, //!< Instruction 'vmwrite' {VMX}. + kIdVmxoff, //!< Instruction 'vmxoff' {VMX}. 
+ kIdVmxon, //!< Instruction 'vmxon' {VMX}. + kIdVorpd, //!< Instruction 'vorpd' {AVX|AVX512_DQ+VL}. + kIdVorps, //!< Instruction 'vorps' {AVX|AVX512_DQ+VL}. + kIdVp2intersectd, //!< Instruction 'vp2intersectd' {AVX512_VP2INTERSECT+VL}. + kIdVp2intersectq, //!< Instruction 'vp2intersectq' {AVX512_VP2INTERSECT+VL}. + kIdVp4dpwssd, //!< Instruction 'vp4dpwssd' {AVX512_4VNNIW}. + kIdVp4dpwssds, //!< Instruction 'vp4dpwssds' {AVX512_4VNNIW}. + kIdVpabsb, //!< Instruction 'vpabsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpabsd, //!< Instruction 'vpabsd' {AVX|AVX2|AVX512_F+VL}. + kIdVpabsq, //!< Instruction 'vpabsq' {AVX512_F+VL}. + kIdVpabsw, //!< Instruction 'vpabsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpackssdw, //!< Instruction 'vpackssdw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpacksswb, //!< Instruction 'vpacksswb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpackusdw, //!< Instruction 'vpackusdw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpackuswb, //!< Instruction 'vpackuswb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddb, //!< Instruction 'vpaddb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddd, //!< Instruction 'vpaddd' {AVX|AVX2|AVX512_F+VL}. + kIdVpaddq, //!< Instruction 'vpaddq' {AVX|AVX2|AVX512_F+VL}. + kIdVpaddsb, //!< Instruction 'vpaddsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddsw, //!< Instruction 'vpaddsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddusb, //!< Instruction 'vpaddusb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddusw, //!< Instruction 'vpaddusw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddw, //!< Instruction 'vpaddw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpalignr, //!< Instruction 'vpalignr' {AVX|AVX2|AVX512_BW+VL}. + kIdVpand, //!< Instruction 'vpand' {AVX|AVX2}. + kIdVpandd, //!< Instruction 'vpandd' {AVX512_F+VL}. + kIdVpandn, //!< Instruction 'vpandn' {AVX|AVX2}. + kIdVpandnd, //!< Instruction 'vpandnd' {AVX512_F+VL}. + kIdVpandnq, //!< Instruction 'vpandnq' {AVX512_F+VL}. + kIdVpandq, //!< Instruction 'vpandq' {AVX512_F+VL}. + kIdVpavgb, //!< Instruction 'vpavgb' {AVX|AVX2|AVX512_BW+VL}. 
+ kIdVpavgw, //!< Instruction 'vpavgw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpblendd, //!< Instruction 'vpblendd' {AVX2}. + kIdVpblendmb, //!< Instruction 'vpblendmb' {AVX512_BW+VL}. + kIdVpblendmd, //!< Instruction 'vpblendmd' {AVX512_F+VL}. + kIdVpblendmq, //!< Instruction 'vpblendmq' {AVX512_F+VL}. + kIdVpblendmw, //!< Instruction 'vpblendmw' {AVX512_BW+VL}. + kIdVpblendvb, //!< Instruction 'vpblendvb' {AVX|AVX2}. + kIdVpblendw, //!< Instruction 'vpblendw' {AVX|AVX2}. + kIdVpbroadcastb, //!< Instruction 'vpbroadcastb' {AVX2|AVX512_BW+VL}. + kIdVpbroadcastd, //!< Instruction 'vpbroadcastd' {AVX2|AVX512_F+VL}. + kIdVpbroadcastmb2q, //!< Instruction 'vpbroadcastmb2q' {AVX512_CD+VL}. + kIdVpbroadcastmw2d, //!< Instruction 'vpbroadcastmw2d' {AVX512_CD+VL}. + kIdVpbroadcastq, //!< Instruction 'vpbroadcastq' {AVX2|AVX512_F+VL}. + kIdVpbroadcastw, //!< Instruction 'vpbroadcastw' {AVX2|AVX512_BW+VL}. + kIdVpclmulqdq, //!< Instruction 'vpclmulqdq' {AVX|AVX512_F+VL & PCLMULQDQ|VPCLMULQDQ}. + kIdVpcmov, //!< Instruction 'vpcmov' {XOP}. + kIdVpcmpb, //!< Instruction 'vpcmpb' {AVX512_BW+VL}. + kIdVpcmpd, //!< Instruction 'vpcmpd' {AVX512_F+VL}. + kIdVpcmpeqb, //!< Instruction 'vpcmpeqb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpcmpeqd, //!< Instruction 'vpcmpeqd' {AVX|AVX2|AVX512_F+VL}. + kIdVpcmpeqq, //!< Instruction 'vpcmpeqq' {AVX|AVX2|AVX512_F+VL}. + kIdVpcmpeqw, //!< Instruction 'vpcmpeqw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpcmpestri, //!< Instruction 'vpcmpestri' {AVX}. + kIdVpcmpestrm, //!< Instruction 'vpcmpestrm' {AVX}. + kIdVpcmpgtb, //!< Instruction 'vpcmpgtb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpcmpgtd, //!< Instruction 'vpcmpgtd' {AVX|AVX2|AVX512_F+VL}. + kIdVpcmpgtq, //!< Instruction 'vpcmpgtq' {AVX|AVX2|AVX512_F+VL}. + kIdVpcmpgtw, //!< Instruction 'vpcmpgtw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpcmpistri, //!< Instruction 'vpcmpistri' {AVX}. + kIdVpcmpistrm, //!< Instruction 'vpcmpistrm' {AVX}. + kIdVpcmpq, //!< Instruction 'vpcmpq' {AVX512_F+VL}. 
+ kIdVpcmpub, //!< Instruction 'vpcmpub' {AVX512_BW+VL}. + kIdVpcmpud, //!< Instruction 'vpcmpud' {AVX512_F+VL}. + kIdVpcmpuq, //!< Instruction 'vpcmpuq' {AVX512_F+VL}. + kIdVpcmpuw, //!< Instruction 'vpcmpuw' {AVX512_BW+VL}. + kIdVpcmpw, //!< Instruction 'vpcmpw' {AVX512_BW+VL}. + kIdVpcomb, //!< Instruction 'vpcomb' {XOP}. + kIdVpcomd, //!< Instruction 'vpcomd' {XOP}. + kIdVpcompressb, //!< Instruction 'vpcompressb' {AVX512_VBMI2+VL}. + kIdVpcompressd, //!< Instruction 'vpcompressd' {AVX512_F+VL}. + kIdVpcompressq, //!< Instruction 'vpcompressq' {AVX512_F+VL}. + kIdVpcompressw, //!< Instruction 'vpcompressw' {AVX512_VBMI2+VL}. + kIdVpcomq, //!< Instruction 'vpcomq' {XOP}. + kIdVpcomub, //!< Instruction 'vpcomub' {XOP}. + kIdVpcomud, //!< Instruction 'vpcomud' {XOP}. + kIdVpcomuq, //!< Instruction 'vpcomuq' {XOP}. + kIdVpcomuw, //!< Instruction 'vpcomuw' {XOP}. + kIdVpcomw, //!< Instruction 'vpcomw' {XOP}. + kIdVpconflictd, //!< Instruction 'vpconflictd' {AVX512_CD+VL}. + kIdVpconflictq, //!< Instruction 'vpconflictq' {AVX512_CD+VL}. + kIdVpdpbssd, //!< Instruction 'vpdpbssd' {AVX_VNNI_INT8}. + kIdVpdpbssds, //!< Instruction 'vpdpbssds' {AVX_VNNI_INT8}. + kIdVpdpbsud, //!< Instruction 'vpdpbsud' {AVX_VNNI_INT8}. + kIdVpdpbsuds, //!< Instruction 'vpdpbsuds' {AVX_VNNI_INT8}. + kIdVpdpbusd, //!< Instruction 'vpdpbusd' {AVX_VNNI|AVX512_VNNI+VL}. + kIdVpdpbusds, //!< Instruction 'vpdpbusds' {AVX_VNNI|AVX512_VNNI+VL}. + kIdVpdpbuud, //!< Instruction 'vpdpbuud' {AVX_VNNI_INT8}. + kIdVpdpbuuds, //!< Instruction 'vpdpbuuds' {AVX_VNNI_INT8}. + kIdVpdpwssd, //!< Instruction 'vpdpwssd' {AVX_VNNI|AVX512_VNNI+VL}. + kIdVpdpwssds, //!< Instruction 'vpdpwssds' {AVX_VNNI|AVX512_VNNI+VL}. + kIdVpdpwsud, //!< Instruction 'vpdpwsud' {AVX_VNNI_INT16}. + kIdVpdpwsuds, //!< Instruction 'vpdpwsuds' {AVX_VNNI_INT16}. + kIdVpdpwusd, //!< Instruction 'vpdpwusd' {AVX_VNNI_INT16}. + kIdVpdpwusds, //!< Instruction 'vpdpwusds' {AVX_VNNI_INT16}. 
+ kIdVpdpwuud, //!< Instruction 'vpdpwuud' {AVX_VNNI_INT16}. + kIdVpdpwuuds, //!< Instruction 'vpdpwuuds' {AVX_VNNI_INT16}. + kIdVperm2f128, //!< Instruction 'vperm2f128' {AVX}. + kIdVperm2i128, //!< Instruction 'vperm2i128' {AVX2}. + kIdVpermb, //!< Instruction 'vpermb' {AVX512_VBMI+VL}. + kIdVpermd, //!< Instruction 'vpermd' {AVX2|AVX512_F+VL}. + kIdVpermi2b, //!< Instruction 'vpermi2b' {AVX512_VBMI+VL}. + kIdVpermi2d, //!< Instruction 'vpermi2d' {AVX512_F+VL}. + kIdVpermi2pd, //!< Instruction 'vpermi2pd' {AVX512_F+VL}. + kIdVpermi2ps, //!< Instruction 'vpermi2ps' {AVX512_F+VL}. + kIdVpermi2q, //!< Instruction 'vpermi2q' {AVX512_F+VL}. + kIdVpermi2w, //!< Instruction 'vpermi2w' {AVX512_BW+VL}. + kIdVpermil2pd, //!< Instruction 'vpermil2pd' {XOP}. + kIdVpermil2ps, //!< Instruction 'vpermil2ps' {XOP}. + kIdVpermilpd, //!< Instruction 'vpermilpd' {AVX|AVX512_F+VL}. + kIdVpermilps, //!< Instruction 'vpermilps' {AVX|AVX512_F+VL}. + kIdVpermpd, //!< Instruction 'vpermpd' {AVX2|AVX512_F+VL}. + kIdVpermps, //!< Instruction 'vpermps' {AVX2|AVX512_F+VL}. + kIdVpermq, //!< Instruction 'vpermq' {AVX2|AVX512_F+VL}. + kIdVpermt2b, //!< Instruction 'vpermt2b' {AVX512_VBMI+VL}. + kIdVpermt2d, //!< Instruction 'vpermt2d' {AVX512_F+VL}. + kIdVpermt2pd, //!< Instruction 'vpermt2pd' {AVX512_F+VL}. + kIdVpermt2ps, //!< Instruction 'vpermt2ps' {AVX512_F+VL}. + kIdVpermt2q, //!< Instruction 'vpermt2q' {AVX512_F+VL}. + kIdVpermt2w, //!< Instruction 'vpermt2w' {AVX512_BW+VL}. + kIdVpermw, //!< Instruction 'vpermw' {AVX512_BW+VL}. + kIdVpexpandb, //!< Instruction 'vpexpandb' {AVX512_VBMI2+VL}. + kIdVpexpandd, //!< Instruction 'vpexpandd' {AVX512_F+VL}. + kIdVpexpandq, //!< Instruction 'vpexpandq' {AVX512_F+VL}. + kIdVpexpandw, //!< Instruction 'vpexpandw' {AVX512_VBMI2+VL}. + kIdVpextrb, //!< Instruction 'vpextrb' {AVX|AVX512_BW}. + kIdVpextrd, //!< Instruction 'vpextrd' {AVX|AVX512_DQ}. + kIdVpextrq, //!< Instruction 'vpextrq' {AVX|AVX512_DQ} (X64). 
+ kIdVpextrw, //!< Instruction 'vpextrw' {AVX|AVX512_BW}. + kIdVpgatherdd, //!< Instruction 'vpgatherdd' {AVX2|AVX512_F+VL}. + kIdVpgatherdq, //!< Instruction 'vpgatherdq' {AVX2|AVX512_F+VL}. + kIdVpgatherqd, //!< Instruction 'vpgatherqd' {AVX2|AVX512_F+VL}. + kIdVpgatherqq, //!< Instruction 'vpgatherqq' {AVX2|AVX512_F+VL}. + kIdVphaddbd, //!< Instruction 'vphaddbd' {XOP}. + kIdVphaddbq, //!< Instruction 'vphaddbq' {XOP}. + kIdVphaddbw, //!< Instruction 'vphaddbw' {XOP}. + kIdVphaddd, //!< Instruction 'vphaddd' {AVX|AVX2}. + kIdVphadddq, //!< Instruction 'vphadddq' {XOP}. + kIdVphaddsw, //!< Instruction 'vphaddsw' {AVX|AVX2}. + kIdVphaddubd, //!< Instruction 'vphaddubd' {XOP}. + kIdVphaddubq, //!< Instruction 'vphaddubq' {XOP}. + kIdVphaddubw, //!< Instruction 'vphaddubw' {XOP}. + kIdVphaddudq, //!< Instruction 'vphaddudq' {XOP}. + kIdVphadduwd, //!< Instruction 'vphadduwd' {XOP}. + kIdVphadduwq, //!< Instruction 'vphadduwq' {XOP}. + kIdVphaddw, //!< Instruction 'vphaddw' {AVX|AVX2}. + kIdVphaddwd, //!< Instruction 'vphaddwd' {XOP}. + kIdVphaddwq, //!< Instruction 'vphaddwq' {XOP}. + kIdVphminposuw, //!< Instruction 'vphminposuw' {AVX}. + kIdVphsubbw, //!< Instruction 'vphsubbw' {XOP}. + kIdVphsubd, //!< Instruction 'vphsubd' {AVX|AVX2}. + kIdVphsubdq, //!< Instruction 'vphsubdq' {XOP}. + kIdVphsubsw, //!< Instruction 'vphsubsw' {AVX|AVX2}. + kIdVphsubw, //!< Instruction 'vphsubw' {AVX|AVX2}. + kIdVphsubwd, //!< Instruction 'vphsubwd' {XOP}. + kIdVpinsrb, //!< Instruction 'vpinsrb' {AVX|AVX512_BW}. + kIdVpinsrd, //!< Instruction 'vpinsrd' {AVX|AVX512_DQ}. + kIdVpinsrq, //!< Instruction 'vpinsrq' {AVX|AVX512_DQ} (X64). + kIdVpinsrw, //!< Instruction 'vpinsrw' {AVX|AVX512_BW}. + kIdVplzcntd, //!< Instruction 'vplzcntd' {AVX512_CD+VL}. + kIdVplzcntq, //!< Instruction 'vplzcntq' {AVX512_CD+VL}. + kIdVpmacsdd, //!< Instruction 'vpmacsdd' {XOP}. + kIdVpmacsdqh, //!< Instruction 'vpmacsdqh' {XOP}. + kIdVpmacsdql, //!< Instruction 'vpmacsdql' {XOP}. 
+ kIdVpmacssdd, //!< Instruction 'vpmacssdd' {XOP}. + kIdVpmacssdqh, //!< Instruction 'vpmacssdqh' {XOP}. + kIdVpmacssdql, //!< Instruction 'vpmacssdql' {XOP}. + kIdVpmacsswd, //!< Instruction 'vpmacsswd' {XOP}. + kIdVpmacssww, //!< Instruction 'vpmacssww' {XOP}. + kIdVpmacswd, //!< Instruction 'vpmacswd' {XOP}. + kIdVpmacsww, //!< Instruction 'vpmacsww' {XOP}. + kIdVpmadcsswd, //!< Instruction 'vpmadcsswd' {XOP}. + kIdVpmadcswd, //!< Instruction 'vpmadcswd' {XOP}. + kIdVpmadd52huq, //!< Instruction 'vpmadd52huq' {AVX_IFMA|AVX512_IFMA+VL}. + kIdVpmadd52luq, //!< Instruction 'vpmadd52luq' {AVX_IFMA|AVX512_IFMA+VL}. + kIdVpmaddubsw, //!< Instruction 'vpmaddubsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaddwd, //!< Instruction 'vpmaddwd' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaskmovd, //!< Instruction 'vpmaskmovd' {AVX2}. + kIdVpmaskmovq, //!< Instruction 'vpmaskmovq' {AVX2}. + kIdVpmaxsb, //!< Instruction 'vpmaxsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaxsd, //!< Instruction 'vpmaxsd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmaxsq, //!< Instruction 'vpmaxsq' {AVX512_F+VL}. + kIdVpmaxsw, //!< Instruction 'vpmaxsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaxub, //!< Instruction 'vpmaxub' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaxud, //!< Instruction 'vpmaxud' {AVX|AVX2|AVX512_F+VL}. + kIdVpmaxuq, //!< Instruction 'vpmaxuq' {AVX512_F+VL}. + kIdVpmaxuw, //!< Instruction 'vpmaxuw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpminsb, //!< Instruction 'vpminsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpminsd, //!< Instruction 'vpminsd' {AVX|AVX2|AVX512_F+VL}. + kIdVpminsq, //!< Instruction 'vpminsq' {AVX512_F+VL}. + kIdVpminsw, //!< Instruction 'vpminsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpminub, //!< Instruction 'vpminub' {AVX|AVX2|AVX512_BW+VL}. + kIdVpminud, //!< Instruction 'vpminud' {AVX|AVX2|AVX512_F+VL}. + kIdVpminuq, //!< Instruction 'vpminuq' {AVX512_F+VL}. + kIdVpminuw, //!< Instruction 'vpminuw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmovb2m, //!< Instruction 'vpmovb2m' {AVX512_BW+VL}. 
+ kIdVpmovd2m, //!< Instruction 'vpmovd2m' {AVX512_DQ+VL}. + kIdVpmovdb, //!< Instruction 'vpmovdb' {AVX512_F+VL}. + kIdVpmovdw, //!< Instruction 'vpmovdw' {AVX512_F+VL}. + kIdVpmovm2b, //!< Instruction 'vpmovm2b' {AVX512_BW+VL}. + kIdVpmovm2d, //!< Instruction 'vpmovm2d' {AVX512_DQ+VL}. + kIdVpmovm2q, //!< Instruction 'vpmovm2q' {AVX512_DQ+VL}. + kIdVpmovm2w, //!< Instruction 'vpmovm2w' {AVX512_BW+VL}. + kIdVpmovmskb, //!< Instruction 'vpmovmskb' {AVX|AVX2}. + kIdVpmovq2m, //!< Instruction 'vpmovq2m' {AVX512_DQ+VL}. + kIdVpmovqb, //!< Instruction 'vpmovqb' {AVX512_F+VL}. + kIdVpmovqd, //!< Instruction 'vpmovqd' {AVX512_F+VL}. + kIdVpmovqw, //!< Instruction 'vpmovqw' {AVX512_F+VL}. + kIdVpmovsdb, //!< Instruction 'vpmovsdb' {AVX512_F+VL}. + kIdVpmovsdw, //!< Instruction 'vpmovsdw' {AVX512_F+VL}. + kIdVpmovsqb, //!< Instruction 'vpmovsqb' {AVX512_F+VL}. + kIdVpmovsqd, //!< Instruction 'vpmovsqd' {AVX512_F+VL}. + kIdVpmovsqw, //!< Instruction 'vpmovsqw' {AVX512_F+VL}. + kIdVpmovswb, //!< Instruction 'vpmovswb' {AVX512_BW+VL}. + kIdVpmovsxbd, //!< Instruction 'vpmovsxbd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovsxbq, //!< Instruction 'vpmovsxbq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovsxbw, //!< Instruction 'vpmovsxbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmovsxdq, //!< Instruction 'vpmovsxdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovsxwd, //!< Instruction 'vpmovsxwd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovsxwq, //!< Instruction 'vpmovsxwq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovusdb, //!< Instruction 'vpmovusdb' {AVX512_F+VL}. + kIdVpmovusdw, //!< Instruction 'vpmovusdw' {AVX512_F+VL}. + kIdVpmovusqb, //!< Instruction 'vpmovusqb' {AVX512_F+VL}. + kIdVpmovusqd, //!< Instruction 'vpmovusqd' {AVX512_F+VL}. + kIdVpmovusqw, //!< Instruction 'vpmovusqw' {AVX512_F+VL}. + kIdVpmovuswb, //!< Instruction 'vpmovuswb' {AVX512_BW+VL}. + kIdVpmovw2m, //!< Instruction 'vpmovw2m' {AVX512_BW+VL}. + kIdVpmovwb, //!< Instruction 'vpmovwb' {AVX512_BW+VL}. 
+ kIdVpmovzxbd, //!< Instruction 'vpmovzxbd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovzxbq, //!< Instruction 'vpmovzxbq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovzxbw, //!< Instruction 'vpmovzxbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmovzxdq, //!< Instruction 'vpmovzxdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovzxwd, //!< Instruction 'vpmovzxwd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovzxwq, //!< Instruction 'vpmovzxwq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmuldq, //!< Instruction 'vpmuldq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmulhrsw, //!< Instruction 'vpmulhrsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmulhuw, //!< Instruction 'vpmulhuw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmulhw, //!< Instruction 'vpmulhw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmulld, //!< Instruction 'vpmulld' {AVX|AVX2|AVX512_F+VL}. + kIdVpmullq, //!< Instruction 'vpmullq' {AVX512_DQ+VL}. + kIdVpmullw, //!< Instruction 'vpmullw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmultishiftqb, //!< Instruction 'vpmultishiftqb' {AVX512_VBMI+VL}. + kIdVpmuludq, //!< Instruction 'vpmuludq' {AVX|AVX2|AVX512_F+VL}. + kIdVpopcntb, //!< Instruction 'vpopcntb' {AVX512_BITALG+VL}. + kIdVpopcntd, //!< Instruction 'vpopcntd' {AVX512_VPOPCNTDQ+VL}. + kIdVpopcntq, //!< Instruction 'vpopcntq' {AVX512_VPOPCNTDQ+VL}. + kIdVpopcntw, //!< Instruction 'vpopcntw' {AVX512_BITALG+VL}. + kIdVpor, //!< Instruction 'vpor' {AVX|AVX2}. + kIdVpord, //!< Instruction 'vpord' {AVX512_F+VL}. + kIdVporq, //!< Instruction 'vporq' {AVX512_F+VL}. + kIdVpperm, //!< Instruction 'vpperm' {XOP}. + kIdVprold, //!< Instruction 'vprold' {AVX512_F+VL}. + kIdVprolq, //!< Instruction 'vprolq' {AVX512_F+VL}. + kIdVprolvd, //!< Instruction 'vprolvd' {AVX512_F+VL}. + kIdVprolvq, //!< Instruction 'vprolvq' {AVX512_F+VL}. + kIdVprord, //!< Instruction 'vprord' {AVX512_F+VL}. + kIdVprorq, //!< Instruction 'vprorq' {AVX512_F+VL}. + kIdVprorvd, //!< Instruction 'vprorvd' {AVX512_F+VL}. + kIdVprorvq, //!< Instruction 'vprorvq' {AVX512_F+VL}. + kIdVprotb, //!< Instruction 'vprotb' {XOP}. + kIdVprotd, //!< Instruction 'vprotd' {XOP}. 
+ kIdVprotq, //!< Instruction 'vprotq' {XOP}. + kIdVprotw, //!< Instruction 'vprotw' {XOP}. + kIdVpsadbw, //!< Instruction 'vpsadbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpscatterdd, //!< Instruction 'vpscatterdd' {AVX512_F+VL}. + kIdVpscatterdq, //!< Instruction 'vpscatterdq' {AVX512_F+VL}. + kIdVpscatterqd, //!< Instruction 'vpscatterqd' {AVX512_F+VL}. + kIdVpscatterqq, //!< Instruction 'vpscatterqq' {AVX512_F+VL}. + kIdVpshab, //!< Instruction 'vpshab' {XOP}. + kIdVpshad, //!< Instruction 'vpshad' {XOP}. + kIdVpshaq, //!< Instruction 'vpshaq' {XOP}. + kIdVpshaw, //!< Instruction 'vpshaw' {XOP}. + kIdVpshlb, //!< Instruction 'vpshlb' {XOP}. + kIdVpshld, //!< Instruction 'vpshld' {XOP}. + kIdVpshldd, //!< Instruction 'vpshldd' {AVX512_VBMI2+VL}. + kIdVpshldq, //!< Instruction 'vpshldq' {AVX512_VBMI2+VL}. + kIdVpshldvd, //!< Instruction 'vpshldvd' {AVX512_VBMI2+VL}. + kIdVpshldvq, //!< Instruction 'vpshldvq' {AVX512_VBMI2+VL}. + kIdVpshldvw, //!< Instruction 'vpshldvw' {AVX512_VBMI2+VL}. + kIdVpshldw, //!< Instruction 'vpshldw' {AVX512_VBMI2+VL}. + kIdVpshlq, //!< Instruction 'vpshlq' {XOP}. + kIdVpshlw, //!< Instruction 'vpshlw' {XOP}. + kIdVpshrdd, //!< Instruction 'vpshrdd' {AVX512_VBMI2+VL}. + kIdVpshrdq, //!< Instruction 'vpshrdq' {AVX512_VBMI2+VL}. + kIdVpshrdvd, //!< Instruction 'vpshrdvd' {AVX512_VBMI2+VL}. + kIdVpshrdvq, //!< Instruction 'vpshrdvq' {AVX512_VBMI2+VL}. + kIdVpshrdvw, //!< Instruction 'vpshrdvw' {AVX512_VBMI2+VL}. + kIdVpshrdw, //!< Instruction 'vpshrdw' {AVX512_VBMI2+VL}. + kIdVpshufb, //!< Instruction 'vpshufb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpshufbitqmb, //!< Instruction 'vpshufbitqmb' {AVX512_BITALG+VL}. + kIdVpshufd, //!< Instruction 'vpshufd' {AVX|AVX2|AVX512_F+VL}. + kIdVpshufhw, //!< Instruction 'vpshufhw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpshuflw, //!< Instruction 'vpshuflw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsignb, //!< Instruction 'vpsignb' {AVX|AVX2}. + kIdVpsignd, //!< Instruction 'vpsignd' {AVX|AVX2}. 
+ kIdVpsignw, //!< Instruction 'vpsignw' {AVX|AVX2}. + kIdVpslld, //!< Instruction 'vpslld' {AVX|AVX2|AVX512_F+VL}. + kIdVpslldq, //!< Instruction 'vpslldq' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsllq, //!< Instruction 'vpsllq' {AVX|AVX2|AVX512_F+VL}. + kIdVpsllvd, //!< Instruction 'vpsllvd' {AVX2|AVX512_F+VL}. + kIdVpsllvq, //!< Instruction 'vpsllvq' {AVX2|AVX512_F+VL}. + kIdVpsllvw, //!< Instruction 'vpsllvw' {AVX512_BW+VL}. + kIdVpsllw, //!< Instruction 'vpsllw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsrad, //!< Instruction 'vpsrad' {AVX|AVX2|AVX512_F+VL}. + kIdVpsraq, //!< Instruction 'vpsraq' {AVX512_F+VL}. + kIdVpsravd, //!< Instruction 'vpsravd' {AVX2|AVX512_F+VL}. + kIdVpsravq, //!< Instruction 'vpsravq' {AVX512_F+VL}. + kIdVpsravw, //!< Instruction 'vpsravw' {AVX512_BW+VL}. + kIdVpsraw, //!< Instruction 'vpsraw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsrld, //!< Instruction 'vpsrld' {AVX|AVX2|AVX512_F+VL}. + kIdVpsrldq, //!< Instruction 'vpsrldq' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsrlq, //!< Instruction 'vpsrlq' {AVX|AVX2|AVX512_F+VL}. + kIdVpsrlvd, //!< Instruction 'vpsrlvd' {AVX2|AVX512_F+VL}. + kIdVpsrlvq, //!< Instruction 'vpsrlvq' {AVX2|AVX512_F+VL}. + kIdVpsrlvw, //!< Instruction 'vpsrlvw' {AVX512_BW+VL}. + kIdVpsrlw, //!< Instruction 'vpsrlw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubb, //!< Instruction 'vpsubb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubd, //!< Instruction 'vpsubd' {AVX|AVX2|AVX512_F+VL}. + kIdVpsubq, //!< Instruction 'vpsubq' {AVX|AVX2|AVX512_F+VL}. + kIdVpsubsb, //!< Instruction 'vpsubsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubsw, //!< Instruction 'vpsubsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubusb, //!< Instruction 'vpsubusb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubusw, //!< Instruction 'vpsubusw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubw, //!< Instruction 'vpsubw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpternlogd, //!< Instruction 'vpternlogd' {AVX512_F+VL}. + kIdVpternlogq, //!< Instruction 'vpternlogq' {AVX512_F+VL}. + kIdVptest, //!< Instruction 'vptest' {AVX}. 
+ kIdVptestmb, //!< Instruction 'vptestmb' {AVX512_BW+VL}. + kIdVptestmd, //!< Instruction 'vptestmd' {AVX512_F+VL}. + kIdVptestmq, //!< Instruction 'vptestmq' {AVX512_F+VL}. + kIdVptestmw, //!< Instruction 'vptestmw' {AVX512_BW+VL}. + kIdVptestnmb, //!< Instruction 'vptestnmb' {AVX512_BW+VL}. + kIdVptestnmd, //!< Instruction 'vptestnmd' {AVX512_F+VL}. + kIdVptestnmq, //!< Instruction 'vptestnmq' {AVX512_F+VL}. + kIdVptestnmw, //!< Instruction 'vptestnmw' {AVX512_BW+VL}. + kIdVpunpckhbw, //!< Instruction 'vpunpckhbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpunpckhdq, //!< Instruction 'vpunpckhdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpunpckhqdq, //!< Instruction 'vpunpckhqdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpunpckhwd, //!< Instruction 'vpunpckhwd' {AVX|AVX2|AVX512_BW+VL}. + kIdVpunpcklbw, //!< Instruction 'vpunpcklbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpunpckldq, //!< Instruction 'vpunpckldq' {AVX|AVX2|AVX512_F+VL}. + kIdVpunpcklqdq, //!< Instruction 'vpunpcklqdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpunpcklwd, //!< Instruction 'vpunpcklwd' {AVX|AVX2|AVX512_BW+VL}. + kIdVpxor, //!< Instruction 'vpxor' {AVX|AVX2}. + kIdVpxord, //!< Instruction 'vpxord' {AVX512_F+VL}. + kIdVpxorq, //!< Instruction 'vpxorq' {AVX512_F+VL}. + kIdVrangepd, //!< Instruction 'vrangepd' {AVX512_DQ+VL}. + kIdVrangeps, //!< Instruction 'vrangeps' {AVX512_DQ+VL}. + kIdVrangesd, //!< Instruction 'vrangesd' {AVX512_DQ}. + kIdVrangess, //!< Instruction 'vrangess' {AVX512_DQ}. + kIdVrcp14pd, //!< Instruction 'vrcp14pd' {AVX512_F+VL}. + kIdVrcp14ps, //!< Instruction 'vrcp14ps' {AVX512_F+VL}. + kIdVrcp14sd, //!< Instruction 'vrcp14sd' {AVX512_F}. + kIdVrcp14ss, //!< Instruction 'vrcp14ss' {AVX512_F}. + kIdVrcp28pd, //!< Instruction 'vrcp28pd' {AVX512_ER}. + kIdVrcp28ps, //!< Instruction 'vrcp28ps' {AVX512_ER}. + kIdVrcp28sd, //!< Instruction 'vrcp28sd' {AVX512_ER}. + kIdVrcp28ss, //!< Instruction 'vrcp28ss' {AVX512_ER}. + kIdVrcpph, //!< Instruction 'vrcpph' {AVX512_FP16}. + kIdVrcpps, //!< Instruction 'vrcpps' {AVX}. 
+ kIdVrcpsh, //!< Instruction 'vrcpsh' {AVX512_FP16}. + kIdVrcpss, //!< Instruction 'vrcpss' {AVX}. + kIdVreducepd, //!< Instruction 'vreducepd' {AVX512_DQ+VL}. + kIdVreduceph, //!< Instruction 'vreduceph' {AVX512_FP16+VL}. + kIdVreduceps, //!< Instruction 'vreduceps' {AVX512_DQ+VL}. + kIdVreducesd, //!< Instruction 'vreducesd' {AVX512_DQ}. + kIdVreducesh, //!< Instruction 'vreducesh' {AVX512_FP16}. + kIdVreducess, //!< Instruction 'vreducess' {AVX512_DQ}. + kIdVrndscalepd, //!< Instruction 'vrndscalepd' {AVX512_F+VL}. + kIdVrndscaleph, //!< Instruction 'vrndscaleph' {AVX512_FP16+VL}. + kIdVrndscaleps, //!< Instruction 'vrndscaleps' {AVX512_F+VL}. + kIdVrndscalesd, //!< Instruction 'vrndscalesd' {AVX512_F}. + kIdVrndscalesh, //!< Instruction 'vrndscalesh' {AVX512_FP16}. + kIdVrndscaless, //!< Instruction 'vrndscaless' {AVX512_F}. + kIdVroundpd, //!< Instruction 'vroundpd' {AVX}. + kIdVroundps, //!< Instruction 'vroundps' {AVX}. + kIdVroundsd, //!< Instruction 'vroundsd' {AVX}. + kIdVroundss, //!< Instruction 'vroundss' {AVX}. + kIdVrsqrt14pd, //!< Instruction 'vrsqrt14pd' {AVX512_F+VL}. + kIdVrsqrt14ps, //!< Instruction 'vrsqrt14ps' {AVX512_F+VL}. + kIdVrsqrt14sd, //!< Instruction 'vrsqrt14sd' {AVX512_F}. + kIdVrsqrt14ss, //!< Instruction 'vrsqrt14ss' {AVX512_F}. + kIdVrsqrt28pd, //!< Instruction 'vrsqrt28pd' {AVX512_ER}. + kIdVrsqrt28ps, //!< Instruction 'vrsqrt28ps' {AVX512_ER}. + kIdVrsqrt28sd, //!< Instruction 'vrsqrt28sd' {AVX512_ER}. + kIdVrsqrt28ss, //!< Instruction 'vrsqrt28ss' {AVX512_ER}. + kIdVrsqrtph, //!< Instruction 'vrsqrtph' {AVX512_FP16+VL}. + kIdVrsqrtps, //!< Instruction 'vrsqrtps' {AVX}. + kIdVrsqrtsh, //!< Instruction 'vrsqrtsh' {AVX512_FP16}. + kIdVrsqrtss, //!< Instruction 'vrsqrtss' {AVX}. + kIdVscalefpd, //!< Instruction 'vscalefpd' {AVX512_F+VL}. + kIdVscalefph, //!< Instruction 'vscalefph' {AVX512_FP16+VL}. + kIdVscalefps, //!< Instruction 'vscalefps' {AVX512_F+VL}. + kIdVscalefsd, //!< Instruction 'vscalefsd' {AVX512_F}. 
+ kIdVscalefsh, //!< Instruction 'vscalefsh' {AVX512_FP16}. + kIdVscalefss, //!< Instruction 'vscalefss' {AVX512_F}. + kIdVscatterdpd, //!< Instruction 'vscatterdpd' {AVX512_F+VL}. + kIdVscatterdps, //!< Instruction 'vscatterdps' {AVX512_F+VL}. + kIdVscatterpf0dpd, //!< Instruction 'vscatterpf0dpd' {AVX512_PF}. + kIdVscatterpf0dps, //!< Instruction 'vscatterpf0dps' {AVX512_PF}. + kIdVscatterpf0qpd, //!< Instruction 'vscatterpf0qpd' {AVX512_PF}. + kIdVscatterpf0qps, //!< Instruction 'vscatterpf0qps' {AVX512_PF}. + kIdVscatterpf1dpd, //!< Instruction 'vscatterpf1dpd' {AVX512_PF}. + kIdVscatterpf1dps, //!< Instruction 'vscatterpf1dps' {AVX512_PF}. + kIdVscatterpf1qpd, //!< Instruction 'vscatterpf1qpd' {AVX512_PF}. + kIdVscatterpf1qps, //!< Instruction 'vscatterpf1qps' {AVX512_PF}. + kIdVscatterqpd, //!< Instruction 'vscatterqpd' {AVX512_F+VL}. + kIdVscatterqps, //!< Instruction 'vscatterqps' {AVX512_F+VL}. + kIdVsha512msg1, //!< Instruction 'vsha512msg1' {AVX & SHA512}. + kIdVsha512msg2, //!< Instruction 'vsha512msg2' {AVX & SHA512}. + kIdVsha512rnds2, //!< Instruction 'vsha512rnds2' {AVX & SHA512}. + kIdVshuff32x4, //!< Instruction 'vshuff32x4' {AVX512_F+VL}. + kIdVshuff64x2, //!< Instruction 'vshuff64x2' {AVX512_F+VL}. + kIdVshufi32x4, //!< Instruction 'vshufi32x4' {AVX512_F+VL}. + kIdVshufi64x2, //!< Instruction 'vshufi64x2' {AVX512_F+VL}. + kIdVshufpd, //!< Instruction 'vshufpd' {AVX|AVX512_F+VL}. + kIdVshufps, //!< Instruction 'vshufps' {AVX|AVX512_F+VL}. + kIdVsm3msg1, //!< Instruction 'vsm3msg1' {AVX & SM3}. + kIdVsm3msg2, //!< Instruction 'vsm3msg2' {AVX & SM3}. + kIdVsm3rnds2, //!< Instruction 'vsm3rnds2' {AVX & SM3}. + kIdVsm4key4, //!< Instruction 'vsm4key4' {AVX & SM4}. + kIdVsm4rnds4, //!< Instruction 'vsm4rnds4' {AVX & SM4}. + kIdVsqrtpd, //!< Instruction 'vsqrtpd' {AVX|AVX512_F+VL}. + kIdVsqrtph, //!< Instruction 'vsqrtph' {AVX512_FP16+VL}. + kIdVsqrtps, //!< Instruction 'vsqrtps' {AVX|AVX512_F+VL}. 
+ kIdVsqrtsd, //!< Instruction 'vsqrtsd' {AVX|AVX512_F}. + kIdVsqrtsh, //!< Instruction 'vsqrtsh' {AVX512_FP16}. + kIdVsqrtss, //!< Instruction 'vsqrtss' {AVX|AVX512_F}. + kIdVstmxcsr, //!< Instruction 'vstmxcsr' {AVX}. + kIdVsubpd, //!< Instruction 'vsubpd' {AVX|AVX512_F+VL}. + kIdVsubph, //!< Instruction 'vsubph' {AVX512_FP16+VL}. + kIdVsubps, //!< Instruction 'vsubps' {AVX|AVX512_F+VL}. + kIdVsubsd, //!< Instruction 'vsubsd' {AVX|AVX512_F}. + kIdVsubsh, //!< Instruction 'vsubsh' {AVX512_FP16}. + kIdVsubss, //!< Instruction 'vsubss' {AVX|AVX512_F}. + kIdVtestpd, //!< Instruction 'vtestpd' {AVX}. + kIdVtestps, //!< Instruction 'vtestps' {AVX}. + kIdVucomisd, //!< Instruction 'vucomisd' {AVX|AVX512_F}. + kIdVucomish, //!< Instruction 'vucomish' {AVX512_FP16}. + kIdVucomiss, //!< Instruction 'vucomiss' {AVX|AVX512_F}. + kIdVunpckhpd, //!< Instruction 'vunpckhpd' {AVX|AVX512_F+VL}. + kIdVunpckhps, //!< Instruction 'vunpckhps' {AVX|AVX512_F+VL}. + kIdVunpcklpd, //!< Instruction 'vunpcklpd' {AVX|AVX512_F+VL}. + kIdVunpcklps, //!< Instruction 'vunpcklps' {AVX|AVX512_F+VL}. + kIdVxorpd, //!< Instruction 'vxorpd' {AVX|AVX512_DQ+VL}. + kIdVxorps, //!< Instruction 'vxorps' {AVX|AVX512_DQ+VL}. + kIdVzeroall, //!< Instruction 'vzeroall' {AVX}. + kIdVzeroupper, //!< Instruction 'vzeroupper' {AVX}. + kIdWbinvd, //!< Instruction 'wbinvd' {I486}. + kIdWbnoinvd, //!< Instruction 'wbnoinvd' {WBNOINVD}. + kIdWrfsbase, //!< Instruction 'wrfsbase' {FSGSBASE} (X64). + kIdWrgsbase, //!< Instruction 'wrgsbase' {FSGSBASE} (X64). + kIdWrmsr, //!< Instruction 'wrmsr' {MSR}. + kIdWrssd, //!< Instruction 'wrssd' {CET_SS}. + kIdWrssq, //!< Instruction 'wrssq' {CET_SS} (X64). + kIdWrussd, //!< Instruction 'wrussd' {CET_SS}. + kIdWrussq, //!< Instruction 'wrussq' {CET_SS} (X64). + kIdXabort, //!< Instruction 'xabort' {RTM}. + kIdXadd, //!< Instruction 'xadd' {I486}. + kIdXbegin, //!< Instruction 'xbegin' {RTM}. + kIdXchg, //!< Instruction 'xchg'. + kIdXend, //!< Instruction 'xend' {RTM}. 
+ kIdXgetbv, //!< Instruction 'xgetbv' {XSAVE}. + kIdXlatb, //!< Instruction 'xlatb'. + kIdXor, //!< Instruction 'xor'. + kIdXorpd, //!< Instruction 'xorpd' {SSE2}. + kIdXorps, //!< Instruction 'xorps' {SSE}. + kIdXresldtrk, //!< Instruction 'xresldtrk' {TSXLDTRK}. + kIdXrstor, //!< Instruction 'xrstor' {XSAVE}. + kIdXrstor64, //!< Instruction 'xrstor64' {XSAVE} (X64). + kIdXrstors, //!< Instruction 'xrstors' {XSAVES}. + kIdXrstors64, //!< Instruction 'xrstors64' {XSAVES} (X64). + kIdXsave, //!< Instruction 'xsave' {XSAVE}. + kIdXsave64, //!< Instruction 'xsave64' {XSAVE} (X64). + kIdXsavec, //!< Instruction 'xsavec' {XSAVEC}. + kIdXsavec64, //!< Instruction 'xsavec64' {XSAVEC} (X64). + kIdXsaveopt, //!< Instruction 'xsaveopt' {XSAVEOPT}. + kIdXsaveopt64, //!< Instruction 'xsaveopt64' {XSAVEOPT} (X64). + kIdXsaves, //!< Instruction 'xsaves' {XSAVES}. + kIdXsaves64, //!< Instruction 'xsaves64' {XSAVES} (X64). + kIdXsetbv, //!< Instruction 'xsetbv' {XSAVE}. + kIdXsusldtrk, //!< Instruction 'xsusldtrk' {TSXLDTRK}. + kIdXtest, //!< Instruction 'xtest' {TSX}. + _kIdCount + // ${InstId:End} + }; + + //! Tests whether the `instId` is defined. + static ASMJIT_INLINE_NODEBUG constexpr bool isDefinedId(InstId instId) noexcept { return instId < _kIdCount; } + + //! \cond + #define ASMJIT_INST_FROM_COND(ID) \ + ID##o, ID##no, ID##b , ID##ae, \ + ID##e, ID##ne, ID##be, ID##a , \ + ID##s, ID##ns, ID##pe, ID##po, \ + ID##l, ID##ge, ID##le, ID##g + + static constexpr uint16_t _jccTable[] = { ASMJIT_INST_FROM_COND(Inst::kIdJ) }; + static constexpr uint16_t _setccTable[] = { ASMJIT_INST_FROM_COND(Inst::kIdSet) }; + static constexpr uint16_t _cmovccTable[] = { ASMJIT_INST_FROM_COND(Inst::kIdCmov) }; + + #undef ASMJIT_INST_FROM_COND + //! \endcond + + //! Translates a condition code `cond` to a `jcc` instruction id. + static ASMJIT_INLINE_NODEBUG constexpr InstId jccFromCond(CondCode cond) noexcept { return _jccTable[uint8_t(cond)]; } + //! 
Translates a condition code `cond` to a `setcc` instruction id. + static ASMJIT_INLINE_NODEBUG constexpr InstId setccFromCond(CondCode cond) noexcept { return _setccTable[uint8_t(cond)]; } + //! Translates a condition code `cond` to a `cmovcc` instruction id. + static ASMJIT_INLINE_NODEBUG constexpr InstId cmovccFromCond(CondCode cond) noexcept { return _cmovccTable[uint8_t(cond)]; } +} // {Inst} + +//! FPU status word bits. +enum class FpuStatusWord : uint16_t { + kNone = 0x0000u, //!< No bits set. + + kInvalid = 0x0001u, //!< Invalid operation. + kDenormalized = 0x0002u, //!< Denormalized operand. + kDivByZero = 0x0004u, //!< Division by zero. + kOverflow = 0x0008u, //!< Overflown. + kUnderflow = 0x0010u, //!< Underflown. + kPrecision = 0x0020u, //!< Precision lost. + kStackFault = 0x0040u, //!< Stack fault. + kInterrupt = 0x0080u, //!< Interrupt. + kC0 = 0x0100u, //!< C0 flag. + kC1 = 0x0200u, //!< C1 flag. + kC2 = 0x0400u, //!< C2 flag. + kTopMask = 0x3800u, //!< Top of the stack (mask). + kC3 = 0x4000u, //!< C3 flag. + kBusy = 0x8000u //!< FPU is busy. +}; +ASMJIT_DEFINE_ENUM_FLAGS(FpuStatusWord) + +//! FPU control word bits. +enum class FpuControlWord : uint16_t { + kNone = 0x0000u, //!< No bits set. + + // Bits 0-5 + // -------- + + kEM_Mask = 0x003Fu, //!< Exception mask (0x3F). + kEM_Invalid = 0x0001u, //!< Invalid operation exception. + kEM_Denormal = 0x0002u, //!< Denormalized operand exception. + kEM_DivByZero = 0x0004u, //!< Division by zero exception. + kEM_Overflow = 0x0008u, //!< Overflow exception. + kEM_Underflow = 0x0010u, //!< Underflow exception. + kEM_Inexact = 0x0020u, //!< Inexact operation exception. + + // Bits 8-9 + // -------- + + kPC_Mask = 0x0300u, //!< Precision control mask. + kPC_Float = 0x0000u, //!< Single precision (24 bits). + kPC_Reserved = 0x0100u, //!< Reserved. + kPC_Double = 0x0200u, //!< Double precision (53 bits). + kPC_Extended = 0x0300u, //!< Extended precision (64 bits). 
+ + // Bits 10-11 + // ---------- + + kRC_Mask = 0x0C00u, //!< Rounding control mask. + kRC_Nearest = 0x0000u, //!< Round to nearest even. + kRC_Down = 0x0400u, //!< Round down (floor). + kRC_Up = 0x0800u, //!< Round up (ceil). + kRC_Truncate = 0x0C00u, //!< Round towards zero (truncate). + + // Bit 12 + // ------ + + kIC_Mask = 0x1000u, //!< Infinity control. + kIC_Projective = 0x0000u, //!< Projective (not supported on X64). + kIC_Affine = 0x1000u //!< Affine (default). +}; +ASMJIT_DEFINE_ENUM_FLAGS(FpuControlWord) + +//! An immediate value that can be used with CMP[PD|PS|SD|SS] instructions. +enum class CmpImm : uint8_t { + kEQ = 0x00u, //!< Equal (Quiet), same as \ref VCmpImm::kEQ_OQ. + kLT = 0x01u, //!< Less (Signaling), same as \ref VCmpImm::kLT_OS. + kLE = 0x02u, //!< Less/Equal (Signaling), same as \ref VCmpImm::kLE_OS. + kUNORD = 0x03u, //!< Unordered (Quiet), same as \ref VCmpImm::kUNORD_Q. + kNEQ = 0x04u, //!< Not Equal (Quiet), same as \ref VCmpImm::kNEQ_UQ. + kNLT = 0x05u, //!< Not Less (Signaling), same as \ref VCmpImm::kNLT_US. + kNLE = 0x06u, //!< Not Less/Equal (Signaling), same as \ref VCmpImm::kNLE_US. + kORD = 0x07u //!< Ordered (Quiet), same as \ref VCmpImm::kORD_Q. +}; + +//! An immediate value that can be used with [V]PCMP[I|E]STR[I|M] instructions. +enum class PCmpStrImm : uint8_t { + // Source Data Format + // ------------------ + + kUB = 0x00u << 0, //!< The source data format is unsigned bytes. + kUW = 0x01u << 0, //!< The source data format is unsigned words. + kSB = 0x02u << 0, //!< The source data format is signed bytes. + kSW = 0x03u << 0, //!< The source data format is signed words. + + // Aggregation Operation + // --------------------- + + kEqualAny = 0x00u << 2, //!< The arithmetic comparison is "equal". + kRanges = 0x01u << 2, //!< The arithmetic comparison is "greater than or equal" between even indexed + //!< elements and "less than or equal" between odd indexed elements. 
+ kEqualEach = 0x02u << 2, //!< The arithmetic comparison is "equal". + kEqualOrdered = 0x03u << 2, //!< The arithmetic comparison is "equal". + + // Polarity + // -------- + + kPosPolarity = 0x00u << 4, //!< IntRes2 = IntRes1. + kNegPolarity = 0x01u << 4, //!< IntRes2 = -1 XOR IntRes1. + kPosMasked = 0x02u << 4, //!< IntRes2 = IntRes1. + kNegMasked = 0x03u << 4, //!< IntRes2[i] = second[i] == invalid ? IntRes1[i] : ~IntRes1[i]. + + // Output Selection (pcmpstri) + // --------------------------- + + kOutputLSI = 0x00u << 6, //!< The index returned to ECX is of the least significant set bit in IntRes2. + kOutputMSI = 0x01u << 6, //!< The index returned to ECX is of the most significant set bit in IntRes2. + + // Output Selection (pcmpstrm) + // --------------------------- + + kBitMask = 0x00u << 6, //!< IntRes2 is returned as the mask to the least significant bits of XMM0. + kIndexMask = 0x01u << 6 //!< IntRes2 is expanded into a byte/word mask and placed in XMM0. +}; +ASMJIT_DEFINE_ENUM_FLAGS(PCmpStrImm) + +//! An immediate value that can be used with ROUND[PD|PS|SD|SS] instructions. +//! +//! \note `kSuppress` is a mask that can be used with any other value. +enum class RoundImm : uint8_t { + kNearest = 0x00u, //!< Round to nearest (even). + kDown = 0x01u, //!< Round to down toward -INF (floor), + kUp = 0x02u, //!< Round to up toward +INF (ceil). + kTrunc = 0x03u, //!< Round toward zero (truncate). + kCurrent = 0x04u, //!< Round to the current rounding mode set (ignores other RC bits). + kSuppress = 0x08u //!< Suppress exceptions (avoids inexact exception, if set). +}; +ASMJIT_DEFINE_ENUM_FLAGS(RoundImm) + +//! An immediate value that can be used with VCMP[PD|PS|SD|SS] instructions (AVX). +//! +//! The first 8 values are compatible with \ref CmpImm. +enum class VCmpImm : uint8_t { + kEQ_OQ = 0x00u, //!< Equal (Quiet , Ordered) , same as \ref CmpImm::kEQ. + kLT_OS = 0x01u, //!< Less (Signaling, Ordered) , same as \ref CmpImm::kLT. 
+ kLE_OS = 0x02u, //!< Less/Equal (Signaling, Ordered) , same as \ref CmpImm::kLE. + kUNORD_Q = 0x03u, //!< Unordered (Quiet) , same as \ref CmpImm::kUNORD. + kNEQ_UQ = 0x04u, //!< Not Equal (Quiet , Unordered), same as \ref CmpImm::kNEQ. + kNLT_US = 0x05u, //!< Not Less (Signaling, Unordered), same as \ref CmpImm::kNLT. + kNLE_US = 0x06u, //!< Not Less/Equal (Signaling, Unordered), same as \ref CmpImm::kNLE. + kORD_Q = 0x07u, //!< Ordered (Quiet) , same as \ref CmpImm::kORD. + kEQ_UQ = 0x08u, //!< Equal (Quiet , Unordered). + kNGE_US = 0x09u, //!< Not Greater/Equal (Signaling, Unordered). + kNGT_US = 0x0Au, //!< Not Greater (Signaling, Unordered). + kFALSE_OQ = 0x0Bu, //!< False (Quiet , Ordered). + kNEQ_OQ = 0x0Cu, //!< Not Equal (Quiet , Ordered). + kGE_OS = 0x0Du, //!< Greater/Equal (Signaling, Ordered). + kGT_OS = 0x0Eu, //!< Greater (Signaling, Ordered). + kTRUE_UQ = 0x0Fu, //!< True (Quiet , Unordered). + kEQ_OS = 0x10u, //!< Equal (Signaling, Ordered). + kLT_OQ = 0x11u, //!< Less (Quiet , Ordered). + kLE_OQ = 0x12u, //!< Less/Equal (Quiet , Ordered). + kUNORD_S = 0x13u, //!< Unordered (Signaling). + kNEQ_US = 0x14u, //!< Not Equal (Signaling, Unordered). + kNLT_UQ = 0x15u, //!< Not Less (Quiet , Unordered). + kNLE_UQ = 0x16u, //!< Not Less/Equal (Quiet , Unordered). + kORD_S = 0x17u, //!< Ordered (Signaling). + kEQ_US = 0x18u, //!< Equal (Signaling, Unordered). + kNGE_UQ = 0x19u, //!< Not Greater/Equal (Quiet , Unordered). + kNGT_UQ = 0x1Au, //!< Not Greater (Quiet , Unordered). + kFALSE_OS = 0x1Bu, //!< False (Signaling, Ordered). + kNEQ_OS = 0x1Cu, //!< Not Equal (Signaling, Ordered). + kGE_OQ = 0x1Du, //!< Greater/Equal (Quiet , Ordered). + kGT_OQ = 0x1Eu, //!< Greater (Quiet , Ordered). + kTRUE_US = 0x1Fu //!< True (Signaling, Unordered). +}; + +//! An immediate value that can be used with VFIXUPIMM[PD|PS|SD|SS] instructions (AVX-512). +//! +//! The final immediate is a combination of all possible control bits. 
+enum class VFixupImm : uint8_t { + kNone = 0x00u, + kZEOnZero = 0x01u, + kIEOnZero = 0x02u, + kZEOnOne = 0x04u, + kIEOnOne = 0x08u, + kIEOnSNaN = 0x10u, + kIEOnNInf = 0x20u, + kIEOnNegative = 0x40u, + kIEOnPInf = 0x80u +}; +ASMJIT_DEFINE_ENUM_FLAGS(VFixupImm) + +//! An immediate value that can be used with VFPCLASS[PD|PS|SD|SS] instructions (AVX-512). +//! +//! The values can be combined together to form the final 8-bit mask. +enum class VFPClassImm : uint8_t { + kNone = 0x00u, + kQNaN = 0x01u, //!< Checks for QNaN. + kPZero = 0x02u, //!< Checks for +0. + kNZero = 0x04u, //!< Checks for -0. + kPInf = 0x08u, //!< Checks for +Inf. + kNInf = 0x10u, //!< Checks for -Inf. + kDenormal = 0x20u, //!< Checks for denormal. + kNegative = 0x40u, //!< Checks for negative finite value. + kSNaN = 0x80u //!< Checks for SNaN. +}; +ASMJIT_DEFINE_ENUM_FLAGS(VFPClassImm) + +//! An immediate value that can be used with VGETMANT[PD|PS|SD|SS] instructions (AVX-512). +//! +//! The value is a combination of a normalization interval and a sign control. +enum class VGetMantImm : uint8_t { + // Normalization Interval + // ---------------------- + + k1To2 = 0x00u, //!< Normalization interval is [1, 2) + k1Div2To2 = 0x01u, //!< Normalization interval is [0.5, 2) + k1Div2To1 = 0x02u, //!< Normalization interval is [0.5, 1) + k3Div4To3Div2 = 0x03u, //!< Normalization interval is [3/4, 3/2) + + // Sign Control + // ------------ + + kSrcSign = 0x00u, //!< Source sign. + kNoSign = 0x04u, //!< Zero sign + kQNaNIfSign = 0x08u //!< QNAN_Indefinite if sign(src) != 0, regardless of `kSignSrc` or `kNoSign`. +}; +ASMJIT_DEFINE_ENUM_FLAGS(VGetMantImm) + +//! A predicate used by VPCMP[U][B|W|D|Q] instructions (AVX-512). +enum class VPCmpImm : uint8_t { + kEQ = 0x00u, //!< Equal. + kLT = 0x01u, //!< Less. + kLE = 0x02u, //!< Less/Equal. + kFALSE = 0x03u, //!< False. + kNE = 0x04u, //!< Not Equal. + kGE = 0x05u, //!< Greater/Equal. + kGT = 0x06u, //!< Greater. + kTRUE = 0x07u //!< True. +}; + +//! 
A predicate used by VPCOM[U][B|W|D|Q] instructions (XOP). +enum class VPComImm : uint8_t { + kLT = 0x00u, //!< Less. + kLE = 0x01u, //!< Less/Equal + kGT = 0x02u, //!< Greater. + kGE = 0x03u, //!< Greater/Equal. + kEQ = 0x04u, //!< Equal. + kNE = 0x05u, //!< Not Equal. + kFALSE = 0x06u, //!< False. + kTRUE = 0x07u //!< True. +}; + +//! A predicate used by VRANGE[PD|PS|SD|SS] instructions (AVX-512). +enum class VRangeImm : uint8_t { + // Selector + // -------- + + kSelectMin = 0x00u, //!< Select minimum value. + kSelectMax = 0x01u, //!< Select maximum value. + kSelectAbsMin = 0x02u, //!< Select minimum absolute value. + kSelectAbsMax = 0x03u, //!< Select maximum absolute value. + + // Sign + // ---- + + kSignSrc1 = 0x00u, //!< Select sign of SRC1. + kSignSrc2 = 0x04u, //!< Select sign of SRC2. + kSign0 = 0x08u, //!< Set sign to 0. + kSign1 = 0x0Cu //!< Set sign to 1. +}; +ASMJIT_DEFINE_ENUM_FLAGS(VRangeImm) + +//! A predicate used by VREDUCE[PD|PS|SD|SS] instructions (AVX-512). +enum class VReduceImm : uint8_t { + kRoundEven = 0x00u, //!< Round to nearest even. + kRoundDown = 0x01u, //!< Round down. + kRoundUp = 0x02u, //!< Round up. + kRoundTrunc = 0x03u, //!< Truncate. + kRoundCurrent = 0x04u, //!< Round to the current mode set. + kSuppress = 0x08u, //!< Suppress exceptions. + kFixedImmMask = 0xF0u //!< Fixed length value mask. +}; +ASMJIT_DEFINE_ENUM_FLAGS(VReduceImm) + +//! Creates a \ref VReduceImm from a combination of `flags` and `fixedPointLength`. +static ASMJIT_INLINE_NODEBUG constexpr VReduceImm vReduceImm(VReduceImm flags, uint32_t fixedPointLength) noexcept { + return flags | VReduceImm(fixedPointLength << 4); +} + +//! A predicate that can be used as an immediate value with VPTERNLOG[D|Q] instruction. +//! +//! There are 3 inputs to the instruction (\ref kA, \ref kB, \ref kC). Ternary logic can define any combination +//! that would be performed on these 3 inputs to get the desired output - any combination of AND, OR, XOR, NOT +//! is possible. +//! 
+//! \sa \ref tLogFromBits and \ref fLogIfElse +enum class TLogImm : uint8_t { + k0 = 0x00u, //!< 0 value. + k1 = 0xFFu, //!< 1 value. + kA = 0xF0u, //!< A value. + kB = 0xCCu, //!< B value. + kC = 0xAAu, //!< C value. + + kNotA = kA ^ k1, //!< `!A` expression. + kNotB = kB ^ k1, //!< `!B` expression. + kNotC = kC ^ k1, //!< `!C` expression. + + kAB = kA & kB, //!< `A & B` expression. + kAC = kA & kC, //!< `A & C` expression. + kBC = kB & kC, //!< `B & C` expression. + kNotAB = kAB ^ k1, //!< `!(A & B)` expression. + kNotAC = kAC ^ k1, //!< `!(A & C)` expression. + kNotBC = kBC ^ k1, //!< `!(B & C)` expression. + + kABC = kAB & kC, //!< `A & B & C` expression. + kNotABC = kABC ^ k1 //!< `!(A & B & C)` expression. +}; +ASMJIT_DEFINE_ENUM_FLAGS(TLogImm) + +//! Creates an immediate that can be used by VPTERNLOG[D|Q] instructions. +static ASMJIT_INLINE_NODEBUG constexpr TLogImm tLogFromBits(uint8_t b000, uint8_t b001, uint8_t b010, uint8_t b011, uint8_t b100, uint8_t b101, uint8_t b110, uint8_t b111) noexcept { + return TLogImm(uint8_t(b000 << 0) | + uint8_t(b001 << 1) | + uint8_t(b010 << 2) | + uint8_t(b011 << 3) | + uint8_t(b100 << 4) | + uint8_t(b101 << 5) | + uint8_t(b110 << 6) | + uint8_t(b111 << 7)); +} + +//! Creates an if/else logic that can be used by VPTERNLOG[D|Q] instructions. +static ASMJIT_INLINE_NODEBUG constexpr TLogImm fLogIfElse(TLogImm condition, TLogImm a, TLogImm b) noexcept { return (condition & a) | (~condition & b); } + +//! Creates a shuffle immediate value that be used with SSE/AVX/AVX-512 instructions to shuffle 2 elements in a vector. +//! +//! \param a Position of the first component [0, 1]. +//! \param b Position of the second component [0, 1]. +//! +//! Shuffle constants can be used to encode an immediate for these instructions: +//! - `shufpd|vshufpd` +static ASMJIT_INLINE_NODEBUG constexpr uint32_t shuffleImm(uint32_t a, uint32_t b) noexcept { + return (a << 1) | b; +} + +//! 
Creates a shuffle immediate value that be used with SSE/AVX/AVX-512 instructions to shuffle 4 elements in a vector. +//! +//! \param a Position of the first component [0, 3]. +//! \param b Position of the second component [0, 3]. +//! \param c Position of the third component [0, 3]. +//! \param d Position of the fourth component [0, 3]. +//! +//! Shuffle constants can be used to encode an immediate for these instructions: +//! - `pshufw` +//! - `pshuflw|vpshuflw` +//! - `pshufhw|vpshufhw` +//! - `pshufd|vpshufd` +//! - `shufps|vshufps` +static ASMJIT_INLINE_NODEBUG constexpr uint32_t shuffleImm(uint32_t a, uint32_t b, uint32_t c, uint32_t d) noexcept { + return (a << 6) | (b << 4) | (c << 2) | d; +} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_X86_X86GLOBALS_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86instdb.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86instdb.h new file mode 100644 index 0000000000000000000000000000000000000000..2bf7d1489575472a483887a15453b9a6226b2b51 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86instdb.h @@ -0,0 +1,563 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86INSTDB_H_INCLUDED +#define ASMJIT_X86_X86INSTDB_H_INCLUDED + +#include "../x86/x86globals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! Instruction database (X86). +namespace InstDB { + +//! Describes which operation mode is supported by an instruction. +enum class Mode : uint8_t { + //! Invalid mode. + kNone = 0x00u, + //! X86 mode supported. + kX86 = 0x01u, + //! X64 mode supported. + kX64 = 0x02u, + //! Both X86 and X64 modes supported. + kAny = 0x03u +}; +ASMJIT_DEFINE_ENUM_FLAGS(Mode) + +//! Converts architecture to operation mode, see \ref Mode. 
+static ASMJIT_INLINE_NODEBUG constexpr Mode modeFromArch(Arch arch) noexcept { + return arch == Arch::kX86 ? Mode::kX86 : + arch == Arch::kX64 ? Mode::kX64 : Mode::kNone; +} + +//! Operand signature flags used by \ref OpSignature. +enum class OpFlags : uint64_t { + //! No operand flags. + kNone = 0u, + + kRegGpbLo = 0x0000000000000001u, //!< Operand can be low 8-bit GPB register. + kRegGpbHi = 0x0000000000000002u, //!< Operand can be high 8-bit GPB register. + kRegGpw = 0x0000000000000004u, //!< Operand can be 16-bit GPW register. + kRegGpd = 0x0000000000000008u, //!< Operand can be 32-bit GPD register. + kRegGpq = 0x0000000000000010u, //!< Operand can be 64-bit GPQ register. + kRegXmm = 0x0000000000000020u, //!< Operand can be 128-bit XMM register. + kRegYmm = 0x0000000000000040u, //!< Operand can be 256-bit YMM register. + kRegZmm = 0x0000000000000080u, //!< Operand can be 512-bit ZMM register. + kRegMm = 0x0000000000000100u, //!< Operand can be 64-bit MM register. + kRegKReg = 0x0000000000000200u, //!< Operand can be 64-bit K register. + kRegSReg = 0x0000000000000400u, //!< Operand can be SReg (segment register). + kRegCReg = 0x0000000000000800u, //!< Operand can be CReg (control register). + kRegDReg = 0x0000000000001000u, //!< Operand can be DReg (debug register). + kRegSt = 0x0000000000002000u, //!< Operand can be 80-bit ST register (X87). + kRegBnd = 0x0000000000004000u, //!< Operand can be 128-bit BND register. + kRegTmm = 0x0000000000008000u, //!< Operand can be 0..8192-bit TMM register. + kRegMask = 0x000000000000FFFFu, //!< Mask of all possible register types. + + kMemUnspecified = 0x0000000000040000u, //!< Operand can be a scalar memory pointer without size. + kMem8 = 0x0000000000080000u, //!< Operand can be an 8-bit memory pointer. + kMem16 = 0x0000000000100000u, //!< Operand can be a 16-bit memory pointer. + kMem32 = 0x0000000000200000u, //!< Operand can be a 32-bit memory pointer. 
+ kMem48 = 0x0000000000400000u, //!< Operand can be a 48-bit memory pointer (FAR pointers only). + kMem64 = 0x0000000000800000u, //!< Operand can be a 64-bit memory pointer. + kMem80 = 0x0000000001000000u, //!< Operand can be an 80-bit memory pointer. + kMem128 = 0x0000000002000000u, //!< Operand can be a 128-bit memory pointer. + kMem256 = 0x0000000004000000u, //!< Operand can be a 256-bit memory pointer. + kMem512 = 0x0000000008000000u, //!< Operand can be a 512-bit memory pointer. + kMem1024 = 0x0000000010000000u, //!< Operand can be a 1024-bit memory pointer. + kMemMask = 0x000000001FFC0000u, //!< Mask of all possible scalar memory types. + + kVm32x = 0x0000000040000000u, //!< Operand can be a vm32x (vector) pointer. + kVm32y = 0x0000000080000000u, //!< Operand can be a vm32y (vector) pointer. + kVm32z = 0x0000000100000000u, //!< Operand can be a vm32z (vector) pointer. + kVm64x = 0x0000000200000000u, //!< Operand can be a vm64x (vector) pointer. + kVm64y = 0x0000000400000000u, //!< Operand can be a vm64y (vector) pointer. + kVm64z = 0x0000000800000000u, //!< Operand can be a vm64z (vector) pointer. + kVmMask = 0x0000000FC0000000u, //!< Mask of all possible vector memory types. + + kImmI4 = 0x0000001000000000u, //!< Operand can be signed 4-bit immediate. + kImmU4 = 0x0000002000000000u, //!< Operand can be unsigned 4-bit immediate. + kImmI8 = 0x0000004000000000u, //!< Operand can be signed 8-bit immediate. + kImmU8 = 0x0000008000000000u, //!< Operand can be unsigned 8-bit immediate. + kImmI16 = 0x0000010000000000u, //!< Operand can be signed 16-bit immediate. + kImmU16 = 0x0000020000000000u, //!< Operand can be unsigned 16-bit immediate. + kImmI32 = 0x0000040000000000u, //!< Operand can be signed 32-bit immediate. + kImmU32 = 0x0000080000000000u, //!< Operand can be unsigned 32-bit immediate. + kImmI64 = 0x0000100000000000u, //!< Operand can be signed 64-bit immediate. + kImmU64 = 0x0000200000000000u, //!< Operand can be unsigned 64-bit immediate. 
+ kImmMask = 0x00003FF000000000u, //!< Mask of all immediate types. + + kRel8 = 0x0000400000000000u, //!< Operand can be relative 8-bit displacement. + kRel32 = 0x0000800000000000u, //!< Operand can be relative 32-bit displacement. + kRelMask = 0x0000C00000000000u, //!< Mask of all relative displacement types. + + kFlagMemBase = 0x0001000000000000u, //!< Flag: Only memory base is allowed (no index, no offset). + kFlagMemDs = 0x0002000000000000u, //!< Flag: Implicit memory operand's DS segment. + kFlagMemEs = 0x0004000000000000u, //!< Flag: Implicit memory operand's ES segment. + + kFlagMib = 0x0008000000000000u, //!< Flag: Operand is MIB (base+index) pointer. + kFlagTMem = 0x0010000000000000u, //!< Flag: Operand is TMEM (sib_mem), AMX memory pointer. + + kFlagImplicit = 0x0080000000000000u, //!< Flag: Operand is implicit. + kFlagMask = 0x009F000000000000u, //!< Mask of all flags. + + //! Contains mask of all registers, memory operands, immediate operands, and displacement operands. + kOpMask = kRegMask | kMemMask | kVmMask | kImmMask | kRelMask +}; +ASMJIT_DEFINE_ENUM_FLAGS(OpFlags) + +//! Operand signature. +//! +//! Contains all possible operand combinations, memory size information, and a fixed register id (or `BaseReg::kIdBad` +//! if fixed id isn't required). +struct OpSignature { + //! \name Members + //! \{ + + uint64_t _flags : 56; + uint64_t _regMask : 8; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns operand signature flags. + inline OpFlags flags() const noexcept { return (OpFlags)_flags; } + + //! Tests whether the given `flag` is set. + inline bool hasFlag(OpFlags flag) const noexcept { return (_flags & uint64_t(flag)) != 0; } + + //! Tests whether this signature contains at least one register operand of any type. + inline bool hasReg() const noexcept { return hasFlag(OpFlags::kRegMask); } + //! Tests whether this signature contains at least one scalar memory operand of any type. 
+ inline bool hasMem() const noexcept { return hasFlag(OpFlags::kMemMask); } + //! Tests whether this signature contains at least one vector memory operand of any type. + inline bool hasVm() const noexcept { return hasFlag(OpFlags::kVmMask); } + //! Tests whether this signature contains at least one immediate operand of any type. + inline bool hasImm() const noexcept { return hasFlag(OpFlags::kImmMask); } + //! Tests whether this signature contains at least one relative displacement operand of any type. + inline bool hasRel() const noexcept { return hasFlag(OpFlags::kRelMask); } + + //! Tests whether the operand is implicit. + inline bool isImplicit() const noexcept { return hasFlag(OpFlags::kFlagImplicit); } + + //! Returns a physical register mask. + inline RegMask regMask() const noexcept { return _regMask; } + + //! \} +}; + +ASMJIT_VARAPI const OpSignature _opSignatureTable[]; + +//! Instruction signature. +//! +//! Contains a sequence of operands' combinations and other metadata that defines a single instruction. This data is +//! used by instruction validator. +struct InstSignature { + //! \name Members + //! \{ + + //! Count of operands in `opIndex` (0..6). + uint8_t _opCount : 3; + //! Architecture modes supported (X86 / X64). + uint8_t _mode : 2; + //! Number of implicit operands. + uint8_t _implicitOpCount : 3; + //! Reserved for future use. + uint8_t _reserved; + //! Indexes to `OpSignature` table. + uint8_t _opSignatureIndexes[Globals::kMaxOpCount]; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns instruction operation mode. + inline Mode mode() const noexcept { return (Mode)_mode; } + //! Tests whether the instruction supports the given operating mode. + inline bool supportsMode(Mode mode) const noexcept { return (uint8_t(_mode) & uint8_t(mode)) != 0; } + + //! Returns the number of operands of this signature. + inline uint32_t opCount() const noexcept { return _opCount; } + //! Returns the number of implicit operands this signature has. 
+ inline uint32_t implicitOpCount() const noexcept { return _implicitOpCount; } + //! Tests whether this instruction signature has at least one implicit operand. + inline bool hasImplicitOperands() const noexcept { return _implicitOpCount != 0; } + + //! Returns indexes to \ref _opSignatureTable for each operand of the instruction. + //! + //! \note The returned array always provides indexes for all operands (see \ref Globals::kMaxOpCount) even if the + //! instruction provides less operands. Undefined operands have always index of zero. + inline const uint8_t* opSignatureIndexes() const noexcept { return _opSignatureIndexes; } + + //! Returns index to \ref _opSignatureTable, corresponding to the requested operand `index` of the instruction. + inline uint8_t opSignatureIndex(size_t index) const noexcept { + ASMJIT_ASSERT(index < Globals::kMaxOpCount); + return _opSignatureIndexes[index]; + } + + //! Returns \ref OpSignature corresponding to the requested operand `index` of the instruction. + inline const OpSignature& opSignature(size_t index) const noexcept { + ASMJIT_ASSERT(index < Globals::kMaxOpCount); + return _opSignatureTable[_opSignatureIndexes[index]]; + } + + //! \} +}; + +ASMJIT_VARAPI const InstSignature _instSignatureTable[]; + +//! Instruction flags. +//! +//! Details about instruction encoding, operation, features, and some limitations. +enum class InstFlags : uint32_t { + //! No flags. + kNone = 0x00000000u, + + // Instruction Family + // ------------------ + // + // Instruction family information. + + //! Instruction that accesses FPU registers. + kFpu = 0x00000100u, + //! Instruction that accesses MMX registers (including 3DNOW and GEODE) and EMMS. + kMmx = 0x00000200u, + //! Instruction that accesses XMM registers (SSE, AVX, AVX512). + kVec = 0x00000400u, + + // FPU Flags + // --------- + // + // Used to tell the encoder which memory operand sizes are encodable. + + //! FPU instruction can address `word_ptr` (shared with M80). 
+ kFpuM16 = 0x00000800u, + //! FPU instruction can address `dword_ptr`. + kFpuM32 = 0x00001000u, + //! FPU instruction can address `qword_ptr`. + kFpuM64 = 0x00002000u, + //! FPU instruction can address `tword_ptr` (shared with M16). + kFpuM80 = 0x00000800u, + + // Prefixes and Encoding Flags + // --------------------------- + // + // These describe optional X86 prefixes that can be used to change the instruction's operation. + + //! Instruction can be prefixed with using the REP(REPE) or REPNE prefix. + kRep = 0x00004000u, + //! Rep prefix is accepted, but it has no effect other than being emitted with the instruction (as an extra byte). + kRepIgnored = 0x00008000u, + //! Instruction can be prefixed with using the LOCK prefix. + kLock = 0x00010000u, + //! Instruction can be prefixed with using the XACQUIRE prefix. + kXAcquire = 0x00020000u, + //! Instruction can be prefixed with using the XRELEASE prefix. + kXRelease = 0x00040000u, + //! Instruction uses MIB (BNDLDX|BNDSTX) to encode two registers. + kMib = 0x00080000u, + //! Instruction uses VSIB instead of legacy SIB. + kVsib = 0x00100000u, + //! Instruction uses TSIB (or SIB_MEM) encoding (MODRM followed by SIB). + kTsib = 0x00200000u, + + // If both `kPrefixVex` and `kPrefixEvex` flags are specified it means that the instructions can be encoded + // by either VEX or EVEX prefix. In that case AsmJit checks global options and also instruction options to decide + // whether to emit VEX or EVEX prefix. + + //! Instruction can be encoded by VEX|XOP (AVX|AVX2|BMI|XOP|...). + kVex = 0x00400000u, + //! Instruction can be encoded by EVEX (AVX512). + kEvex = 0x00800000u, + //! EVEX encoding is preferred over VEX encoding (AVX515_VNNI vs AVX_VNNI). + kPreferEvex = 0x01000000u, + //! EVEX and VEX signatures are compatible. + kEvexCompat = 0x02000000u, + //! EVEX instruction requires K register in the first operand (compare instructions). + kEvexKReg = 0x04000000u, + //! 
EVEX instruction requires two operands and K register as a selector (gather instructions). + kEvexTwoOp = 0x08000000u, + //! VEX instruction that can be transformed to a compatible EVEX instruction. + kEvexTransformable = 0x10000000u, + + // Other Flags + // ----------- + + //! Instruction uses consecutive registers. + //! + //! Used by V4FMADDPS, V4FMADDSS, V4FNMADDPS, V4FNMADDSS, VP4DPWSSD, VP4DPWSSDS, VP2INTERSECTD, and VP2INTERSECTQ + //! instructions + kConsecutiveRegs = 0x20000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(InstFlags) + +//! AVX-512 flags. +enum class Avx512Flags : uint32_t { + //! No AVX-512 flags. + kNone = 0, + + //! Internally used in tables, has no meaning. + k_ = 0x00000000u, + //! Supports masking {k1..k7}. + kK = 0x00000001u, + //! Supports zeroing {z}, must be used together with `kAvx512k`. + kZ = 0x00000002u, + //! Supports 'embedded-rounding' {er} with implicit {sae}, + kER = 0x00000004u, + //! Supports 'suppress-all-exceptions' {sae}. + kSAE = 0x00000008u, + //! Supports 16-bit broadcast 'b16'. + kB16 = 0x00000010u, + //! Supports 32-bit broadcast 'b32'. + kB32 = 0x00000020u, + //! Supports 64-bit broadcast 'b64'. + kB64 = 0x00000040u, + //! Operates on a vector of consecutive registers (AVX512_4FMAPS and AVX512_4VNNIW). + kT4X = 0x00000080u, + + //! Implicit zeroing if {k} masking is used. Using {z} is not valid in this case as it's implicit. + kImplicitZ = 0x00000100, +}; +ASMJIT_DEFINE_ENUM_FLAGS(Avx512Flags) + +//! Instruction common information. +//! +//! Aggregated information shared across one or more instruction. +struct CommonInfo { + //! Instruction flags. + uint32_t _flags; + //! Reserved for future use. + uint32_t _avx512Flags : 11; + //! First `InstSignature` entry in the database. + uint32_t _iSignatureIndex : 11; + //! Number of relevant `ISignature` entries. + uint32_t _iSignatureCount : 5; + //! Instruction control flow category, see \ref InstControlFlow. + uint32_t _controlFlow : 3; + //! 
Specifies what happens if all source operands share the same register. + uint32_t _sameRegHint : 2; + + //! \name Accessors + //! \{ + + //! Returns instruction flags. + ASMJIT_INLINE_NODEBUG InstFlags flags() const noexcept { return (InstFlags)_flags; } + //! Tests whether the instruction has a `flag`. + ASMJIT_INLINE_NODEBUG bool hasFlag(InstFlags flag) const noexcept { return Support::test(_flags, flag); } + + //! Returns instruction AVX-512 flags. + ASMJIT_INLINE_NODEBUG Avx512Flags avx512Flags() const noexcept { return (Avx512Flags)_avx512Flags; } + //! Tests whether the instruction has an AVX-512 `flag`. + ASMJIT_INLINE_NODEBUG bool hasAvx512Flag(Avx512Flags flag) const noexcept { return Support::test(_avx512Flags, flag); } + + //! Tests whether the instruction is FPU instruction. + ASMJIT_INLINE_NODEBUG bool isFpu() const noexcept { return hasFlag(InstFlags::kFpu); } + //! Tests whether the instruction is MMX/3DNOW instruction that accesses MMX registers (includes EMMS and FEMMS). + ASMJIT_INLINE_NODEBUG bool isMmx() const noexcept { return hasFlag(InstFlags::kMmx); } + //! Tests whether the instruction is SSE|AVX|AVX512 instruction that accesses XMM|YMM|ZMM registers. + ASMJIT_INLINE_NODEBUG bool isVec() const noexcept { return hasFlag(InstFlags::kVec); } + //! Tests whether the instruction is SSE+ (SSE4.2, AES, SHA included) instruction that accesses XMM registers. + ASMJIT_INLINE_NODEBUG bool isSse() const noexcept { return (flags() & (InstFlags::kVec | InstFlags::kVex | InstFlags::kEvex)) == InstFlags::kVec; } + //! Tests whether the instruction is AVX+ (FMA included) instruction that accesses XMM|YMM|ZMM registers. + ASMJIT_INLINE_NODEBUG bool isAvx() const noexcept { return isVec() && isVexOrEvex(); } + + //! Tests whether the instruction can be prefixed with LOCK prefix. + ASMJIT_INLINE_NODEBUG bool hasLockPrefix() const noexcept { return hasFlag(InstFlags::kLock); } + //! Tests whether the instruction can be prefixed with REP (REPE|REPZ) prefix. 
+ ASMJIT_INLINE_NODEBUG bool hasRepPrefix() const noexcept { return hasFlag(InstFlags::kRep); } + //! Tests whether the instruction can be prefixed with XACQUIRE prefix. + ASMJIT_INLINE_NODEBUG bool hasXAcquirePrefix() const noexcept { return hasFlag(InstFlags::kXAcquire); } + //! Tests whether the instruction can be prefixed with XRELEASE prefix. + ASMJIT_INLINE_NODEBUG bool hasXReleasePrefix() const noexcept { return hasFlag(InstFlags::kXRelease); } + + //! Tests whether the rep prefix is supported by the instruction, but ignored (has no effect). + ASMJIT_INLINE_NODEBUG bool isRepIgnored() const noexcept { return hasFlag(InstFlags::kRepIgnored); } + //! Tests whether the instruction uses MIB. + ASMJIT_INLINE_NODEBUG bool isMibOp() const noexcept { return hasFlag(InstFlags::kMib); } + //! Tests whether the instruction uses VSIB. + ASMJIT_INLINE_NODEBUG bool isVsibOp() const noexcept { return hasFlag(InstFlags::kVsib); } + //! Tests whether the instruction uses TSIB (AMX, instruction requires MOD+SIB). + ASMJIT_INLINE_NODEBUG bool isTsibOp() const noexcept { return hasFlag(InstFlags::kTsib); } + //! Tests whether the instruction uses VEX (can be set together with EVEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isVex() const noexcept { return hasFlag(InstFlags::kVex); } + //! Tests whether the instruction uses EVEX (can be set together with VEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isEvex() const noexcept { return hasFlag(InstFlags::kEvex); } + //! Tests whether the instruction uses EVEX (can be set together with VEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isVexOrEvex() const noexcept { return hasFlag(InstFlags::kVex | InstFlags::kEvex); } + + //! Tests whether the instruction should prefer EVEX prefix instead of VEX prefix. 
+ ASMJIT_INLINE_NODEBUG bool preferEvex() const noexcept { return hasFlag(InstFlags::kPreferEvex); } + + ASMJIT_INLINE_NODEBUG bool isEvexCompatible() const noexcept { return hasFlag(InstFlags::kEvexCompat); } + ASMJIT_INLINE_NODEBUG bool isEvexKRegOnly() const noexcept { return hasFlag(InstFlags::kEvexKReg); } + ASMJIT_INLINE_NODEBUG bool isEvexTwoOpOnly() const noexcept { return hasFlag(InstFlags::kEvexTwoOp); } + ASMJIT_INLINE_NODEBUG bool isEvexTransformable() const noexcept { return hasFlag(InstFlags::kEvexTransformable); } + + //! Tests whether the instruction supports AVX512 masking {k}. + ASMJIT_INLINE_NODEBUG bool hasAvx512K() const noexcept { return hasAvx512Flag(Avx512Flags::kK); } + //! Tests whether the instruction supports AVX512 zeroing {k}{z}. + ASMJIT_INLINE_NODEBUG bool hasAvx512Z() const noexcept { return hasAvx512Flag(Avx512Flags::kZ); } + //! Tests whether the instruction supports AVX512 embedded-rounding {er}. + ASMJIT_INLINE_NODEBUG bool hasAvx512ER() const noexcept { return hasAvx512Flag(Avx512Flags::kER); } + //! Tests whether the instruction supports AVX512 suppress-all-exceptions {sae}. + ASMJIT_INLINE_NODEBUG bool hasAvx512SAE() const noexcept { return hasAvx512Flag(Avx512Flags::kSAE); } + //! Tests whether the instruction supports AVX512 broadcast (either 32-bit or 64-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B() const noexcept { return hasAvx512Flag(Avx512Flags::kB16 | Avx512Flags::kB32 | Avx512Flags::kB64); } + //! Tests whether the instruction supports AVX512 broadcast (16-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B16() const noexcept { return hasAvx512Flag(Avx512Flags::kB16); } + //! Tests whether the instruction supports AVX512 broadcast (32-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B32() const noexcept { return hasAvx512Flag(Avx512Flags::kB32); } + //! Tests whether the instruction supports AVX512 broadcast (64-bit). 
+ ASMJIT_INLINE_NODEBUG bool hasAvx512B64() const noexcept { return hasAvx512Flag(Avx512Flags::kB64); } + + // Returns the size of the broadcast - either 2, 4, or 8, or 0 if broadcast is not supported. + ASMJIT_INLINE_NODEBUG uint32_t broadcastSize() const noexcept { + constexpr uint32_t kShift = Support::ConstCTZ::value; + return (uint32_t(_avx512Flags) & uint32_t(Avx512Flags::kB16 | Avx512Flags::kB32 | Avx512Flags::kB64)) >> (kShift - 1); + } + + ASMJIT_INLINE_NODEBUG uint32_t signatureIndex() const noexcept { return _iSignatureIndex; } + ASMJIT_INLINE_NODEBUG uint32_t signatureCount() const noexcept { return _iSignatureCount; } + + ASMJIT_INLINE_NODEBUG const InstSignature* signatureData() const noexcept { return _instSignatureTable + _iSignatureIndex; } + ASMJIT_INLINE_NODEBUG const InstSignature* signatureEnd() const noexcept { return _instSignatureTable + _iSignatureIndex + _iSignatureCount; } + + //! Returns a control flow category of the instruction. + ASMJIT_INLINE_NODEBUG InstControlFlow controlFlow() const noexcept { return (InstControlFlow)_controlFlow; } + + //! Returns a hint that can be used when both inputs are the same register. + ASMJIT_INLINE_NODEBUG InstSameRegHint sameRegHint() const noexcept { return (InstSameRegHint)_sameRegHint; } + + //! \} +}; + +ASMJIT_VARAPI const CommonInfo _commonInfoTable[]; + +//! Instruction information. +struct InstInfo { + //! Reserved for future use. + uint32_t _reserved : 14; + //! Index to \ref _commonInfoTable. + uint32_t _commonInfoIndex : 10; + //! Index to \ref _additionalInfoTable. + uint32_t _additionalInfoIndex : 8; + + //! Instruction encoding (internal encoding identifier used by \ref Assembler). + uint8_t _encoding; + //! Main opcode value (0..255). + uint8_t _mainOpcodeValue; + //! Index to \ref _mainOpcodeTable` that is combined with \ref _mainOpcodeValue to form the final opcode. + uint8_t _mainOpcodeIndex; + //! Index to \ref _altOpcodeTable that contains a full alternative opcode. 
+ uint8_t _altOpcodeIndex; + + //! \name Accessors + //! \{ + + //! Returns common information, see \ref CommonInfo. + ASMJIT_INLINE_NODEBUG const CommonInfo& commonInfo() const noexcept { return _commonInfoTable[_commonInfoIndex]; } + + //! Returns instruction flags, see \ref InstFlags. + ASMJIT_INLINE_NODEBUG InstFlags flags() const noexcept { return commonInfo().flags(); } + //! Tests whether the instruction has flag `flag`, see \ref InstFlags. + ASMJIT_INLINE_NODEBUG bool hasFlag(InstFlags flag) const noexcept { return commonInfo().hasFlag(flag); } + + //! Returns instruction AVX-512 flags, see \ref Avx512Flags. + ASMJIT_INLINE_NODEBUG Avx512Flags avx512Flags() const noexcept { return commonInfo().avx512Flags(); } + //! Tests whether the instruction has an AVX-512 `flag`, see \ref Avx512Flags. + ASMJIT_INLINE_NODEBUG bool hasAvx512Flag(Avx512Flags flag) const noexcept { return commonInfo().hasAvx512Flag(flag); } + + //! Tests whether the instruction is FPU instruction. + ASMJIT_INLINE_NODEBUG bool isFpu() const noexcept { return commonInfo().isFpu(); } + //! Tests whether the instruction is MMX/3DNOW instruction that accesses MMX registers (includes EMMS and FEMMS). + ASMJIT_INLINE_NODEBUG bool isMmx() const noexcept { return commonInfo().isMmx(); } + //! Tests whether the instruction is SSE|AVX|AVX512 instruction that accesses XMM|YMM|ZMM registers. + ASMJIT_INLINE_NODEBUG bool isVec() const noexcept { return commonInfo().isVec(); } + //! Tests whether the instruction is SSE+ (SSE4.2, AES, SHA included) instruction that accesses XMM registers. + ASMJIT_INLINE_NODEBUG bool isSse() const noexcept { return commonInfo().isSse(); } + //! Tests whether the instruction is AVX+ (FMA included) instruction that accesses XMM|YMM|ZMM registers. + ASMJIT_INLINE_NODEBUG bool isAvx() const noexcept { return commonInfo().isAvx(); } + + //! Tests whether the instruction can be prefixed with LOCK prefix. 
+ ASMJIT_INLINE_NODEBUG bool hasLockPrefix() const noexcept { return commonInfo().hasLockPrefix(); } + //! Tests whether the instruction can be prefixed with REP (REPE|REPZ) prefix. + ASMJIT_INLINE_NODEBUG bool hasRepPrefix() const noexcept { return commonInfo().hasRepPrefix(); } + //! Tests whether the instruction can be prefixed with XACQUIRE prefix. + ASMJIT_INLINE_NODEBUG bool hasXAcquirePrefix() const noexcept { return commonInfo().hasXAcquirePrefix(); } + //! Tests whether the instruction can be prefixed with XRELEASE prefix. + ASMJIT_INLINE_NODEBUG bool hasXReleasePrefix() const noexcept { return commonInfo().hasXReleasePrefix(); } + + //! Tests whether the rep prefix is supported by the instruction, but ignored (has no effect). + ASMJIT_INLINE_NODEBUG bool isRepIgnored() const noexcept { return commonInfo().isRepIgnored(); } + //! Tests whether the instruction uses MIB. + ASMJIT_INLINE_NODEBUG bool isMibOp() const noexcept { return hasFlag(InstFlags::kMib); } + //! Tests whether the instruction uses VSIB. + ASMJIT_INLINE_NODEBUG bool isVsibOp() const noexcept { return hasFlag(InstFlags::kVsib); } + //! Tests whether the instruction uses VEX (can be set together with EVEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isVex() const noexcept { return hasFlag(InstFlags::kVex); } + //! Tests whether the instruction uses EVEX (can be set together with VEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isEvex() const noexcept { return hasFlag(InstFlags::kEvex); } + //! Tests whether the instruction uses EVEX (can be set together with VEX if both are encodable). 
+ ASMJIT_INLINE_NODEBUG bool isVexOrEvex() const noexcept { return hasFlag(InstFlags::kVex | InstFlags::kEvex); } + + ASMJIT_INLINE_NODEBUG bool isEvexCompatible() const noexcept { return hasFlag(InstFlags::kEvexCompat); } + ASMJIT_INLINE_NODEBUG bool isEvexKRegOnly() const noexcept { return hasFlag(InstFlags::kEvexKReg); } + ASMJIT_INLINE_NODEBUG bool isEvexTwoOpOnly() const noexcept { return hasFlag(InstFlags::kEvexTwoOp); } + ASMJIT_INLINE_NODEBUG bool isEvexTransformable() const noexcept { return hasFlag(InstFlags::kEvexTransformable); } + + //! Tests whether the instruction supports AVX512 masking {k}. + ASMJIT_INLINE_NODEBUG bool hasAvx512K() const noexcept { return hasAvx512Flag(Avx512Flags::kK); } + //! Tests whether the instruction supports AVX512 zeroing {k}{z}. + ASMJIT_INLINE_NODEBUG bool hasAvx512Z() const noexcept { return hasAvx512Flag(Avx512Flags::kZ); } + //! Tests whether the instruction supports AVX512 embedded-rounding {er}. + ASMJIT_INLINE_NODEBUG bool hasAvx512ER() const noexcept { return hasAvx512Flag(Avx512Flags::kER); } + //! Tests whether the instruction supports AVX512 suppress-all-exceptions {sae}. + ASMJIT_INLINE_NODEBUG bool hasAvx512SAE() const noexcept { return hasAvx512Flag(Avx512Flags::kSAE); } + //! Tests whether the instruction supports AVX512 broadcast (either 32-bit or 64-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B() const noexcept { return hasAvx512Flag(Avx512Flags::kB16 | Avx512Flags::kB32 | Avx512Flags::kB64); } + //! Tests whether the instruction supports AVX512 broadcast (16-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B16() const noexcept { return hasAvx512Flag(Avx512Flags::kB16); } + //! Tests whether the instruction supports AVX512 broadcast (32-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B32() const noexcept { return hasAvx512Flag(Avx512Flags::kB32); } + //! Tests whether the instruction supports AVX512 broadcast (64-bit). 
+ ASMJIT_INLINE_NODEBUG bool hasAvx512B64() const noexcept { return hasAvx512Flag(Avx512Flags::kB64); } + + //! Returns a control flow category of the instruction. + ASMJIT_INLINE_NODEBUG InstControlFlow controlFlow() const noexcept { return commonInfo().controlFlow(); } + //! Returns a hint that can be used when both inputs are the same register. + ASMJIT_INLINE_NODEBUG InstSameRegHint sameRegHint() const noexcept { return commonInfo().sameRegHint(); } + + ASMJIT_INLINE_NODEBUG uint32_t signatureIndex() const noexcept { return commonInfo().signatureIndex(); } + ASMJIT_INLINE_NODEBUG uint32_t signatureCount() const noexcept { return commonInfo().signatureCount(); } + + ASMJIT_INLINE_NODEBUG const InstSignature* signatureData() const noexcept { return commonInfo().signatureData(); } + ASMJIT_INLINE_NODEBUG const InstSignature* signatureEnd() const noexcept { return commonInfo().signatureEnd(); } + + //! \} +}; + +ASMJIT_VARAPI const InstInfo _instInfoTable[]; + +static inline const InstInfo& infoById(InstId instId) noexcept { + ASMJIT_ASSERT(Inst::isDefinedId(instId)); + return _instInfoTable[instId]; +} + +//! \cond INTERNAL +static_assert(sizeof(OpSignature) == 8, "InstDB::OpSignature must be 8 bytes long"); +//! \endcond + +} // {InstDB} + +//! 
\} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_X86_X86INSTDB_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86operand.h b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86operand.h new file mode 100644 index 0000000000000000000000000000000000000000..ac56731c98265be60b70a65991d6e25c0761e9a6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/asmjit/x86/x86operand.h @@ -0,0 +1,1145 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86OPERAND_H_INCLUDED +#define ASMJIT_X86_X86OPERAND_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/operand.h" +#include "../core/type.h" +#include "../x86/x86globals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +class Reg; +class Mem; + +class Gp; +class Gpb; +class GpbLo; +class GpbHi; +class Gpw; +class Gpd; +class Gpq; +class Vec; +class Xmm; +class Ymm; +class Zmm; +class Mm; +class KReg; +class SReg; +class CReg; +class DReg; +class St; +class Bnd; +class Tmm; +class Rip; + +//! Register traits (X86). +//! +//! Register traits contains information about a particular register type. It's used by asmjit to setup register +//! information on-the-fly and to populate tables that contain register information (this way it's possible to change +//! register types and groups without having to reorder these tables). +template +struct RegTraits : public BaseRegTraits {}; + +//! 
\cond +// <--------------------+------------------------+------------------------+---+------------------+ +// | Reg-Type | Reg-Group |Sz | TypeId | +// <--------------------+------------------------+------------------------+---+------------------+ +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Rip , RegGroup::kX86_Rip , 0 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_GpbLo , RegGroup::kGp , 1 , TypeId::kInt8 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_GpbHi , RegGroup::kGp , 1 , TypeId::kInt8 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Gpw , RegGroup::kGp , 2 , TypeId::kInt16 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Gpd , RegGroup::kGp , 4 , TypeId::kInt32 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Gpq , RegGroup::kGp , 8 , TypeId::kInt64 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Xmm , RegGroup::kVec , 16, TypeId::kInt32x4 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Ymm , RegGroup::kVec , 32, TypeId::kInt32x8 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Zmm , RegGroup::kVec , 64, TypeId::kInt32x16); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_KReg , RegGroup::kX86_K , 0 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Mm , RegGroup::kX86_MM , 8 , TypeId::kMmx64 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_SReg , RegGroup::kX86_SReg , 2 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_CReg , RegGroup::kX86_CReg , 0 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_DReg , RegGroup::kX86_DReg , 0 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_St , RegGroup::kX86_St , 10, TypeId::kFloat80 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Bnd , RegGroup::kX86_Bnd , 16, TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Tmm , RegGroup::kX86_Tmm , 0 , TypeId::kVoid ); +//! \endcond + +//! Register (X86). +class Reg : public BaseReg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Reg, BaseReg) + + //! Tests whether the register is a GPB register (8-bit). 
+ ASMJIT_INLINE_NODEBUG constexpr bool isGpb() const noexcept { return size() == 1; } + //! Tests whether the register is a low GPB register (8-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpbLo() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a high GPB register (8-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpbHi() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a GPW register (16-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpw() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a GPD register (32-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpd() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a GPQ register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpq() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + //! Tests whether the register is a 32-bit general purpose register, alias of \ref isGpd(). + ASMJIT_INLINE_NODEBUG constexpr bool isGp32() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a 64-bit general purpose register, alias of \ref isGpq() + ASMJIT_INLINE_NODEBUG constexpr bool isGp64() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + //! Tests whether the register is an XMM register (128-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isXmm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a YMM register (256-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isYmm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a ZMM register (512-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isZmm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + //! 
Tests whether the register is a 128-bit vector register, alias of \ref isXmm(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec128() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a 256-bit vector register, alias of \ref isYmm(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec256() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a 512-bit vector register, alias of \ref isZmm(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec512() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + //! Tests whether the register is an MMX register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isMm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a K register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isKReg() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a segment register. + ASMJIT_INLINE_NODEBUG constexpr bool isSReg() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a control register. + ASMJIT_INLINE_NODEBUG constexpr bool isCReg() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a debug register. + ASMJIT_INLINE_NODEBUG constexpr bool isDReg() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is an FPU register (80-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isSt() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a bound register. + ASMJIT_INLINE_NODEBUG constexpr bool isBnd() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a TMM register. + ASMJIT_INLINE_NODEBUG constexpr bool isTmm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! 
Tests whether the register is RIP. + ASMJIT_INLINE_NODEBUG constexpr bool isRip() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + template + ASMJIT_INLINE_NODEBUG void setRegT(uint32_t rId) noexcept { + setSignature(OperandSignature{RegTraits::kSignature}); + setId(rId); + } + + ASMJIT_INLINE_NODEBUG void setTypeAndId(RegType type, uint32_t id) noexcept { + setSignature(signatureOf(type)); + setId(id); + } + + static ASMJIT_INLINE_NODEBUG RegGroup groupOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kX86).regTypeToGroup(type); } + static ASMJIT_INLINE_NODEBUG TypeId typeIdOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kX86).regTypeToTypeId(type); } + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kX86).regTypeToSignature(type); } + + template + static ASMJIT_INLINE_NODEBUG RegGroup groupOfT() noexcept { return RegGroup(RegTraits::kGroup); } + + template + static ASMJIT_INLINE_NODEBUG TypeId typeIdOfT() noexcept { return TypeId(RegTraits::kTypeId); } + + template + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOfT() noexcept { return OperandSignature{RegTraits::kSignature}; } + + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOfVecByType(TypeId typeId) noexcept { + return OperandSignature{typeId <= TypeId::_kVec128End ? uint32_t(RegTraits::kSignature) : + typeId <= TypeId::_kVec256End ? uint32_t(RegTraits::kSignature) : + uint32_t(RegTraits::kSignature)}; + } + + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOfVecBySize(uint32_t size) noexcept { + return OperandSignature{size <= 16 ? uint32_t(RegTraits::kSignature) : + size <= 32 ? uint32_t(RegTraits::kSignature) : + uint32_t(RegTraits::kSignature)}; + } + + //! Tests whether the `op` operand is either a low or high 8-bit GPB register. + static ASMJIT_INLINE_NODEBUG bool isGpb(const Operand_& op) noexcept { + // Check operand type, register group, and size. 
Not interested in register type. + return op.signature().subset(Signature::kOpTypeMask | Signature::kRegGroupMask | Signature::kSizeMask) == + (Signature::fromOpType(OperandType::kReg) | Signature::fromRegGroup(RegGroup::kGp) | Signature::fromSize(1)); + } + + static ASMJIT_INLINE_NODEBUG bool isGpbLo(const Operand_& op) noexcept { return op.as().isGpbLo(); } + static ASMJIT_INLINE_NODEBUG bool isGpbHi(const Operand_& op) noexcept { return op.as().isGpbHi(); } + static ASMJIT_INLINE_NODEBUG bool isGpw(const Operand_& op) noexcept { return op.as().isGpw(); } + static ASMJIT_INLINE_NODEBUG bool isGpd(const Operand_& op) noexcept { return op.as().isGpd(); } + static ASMJIT_INLINE_NODEBUG bool isGpq(const Operand_& op) noexcept { return op.as().isGpq(); } + static ASMJIT_INLINE_NODEBUG bool isXmm(const Operand_& op) noexcept { return op.as().isXmm(); } + static ASMJIT_INLINE_NODEBUG bool isYmm(const Operand_& op) noexcept { return op.as().isYmm(); } + static ASMJIT_INLINE_NODEBUG bool isZmm(const Operand_& op) noexcept { return op.as().isZmm(); } + static ASMJIT_INLINE_NODEBUG bool isMm(const Operand_& op) noexcept { return op.as().isMm(); } + static ASMJIT_INLINE_NODEBUG bool isKReg(const Operand_& op) noexcept { return op.as().isKReg(); } + static ASMJIT_INLINE_NODEBUG bool isSReg(const Operand_& op) noexcept { return op.as().isSReg(); } + static ASMJIT_INLINE_NODEBUG bool isCReg(const Operand_& op) noexcept { return op.as().isCReg(); } + static ASMJIT_INLINE_NODEBUG bool isDReg(const Operand_& op) noexcept { return op.as().isDReg(); } + static ASMJIT_INLINE_NODEBUG bool isSt(const Operand_& op) noexcept { return op.as().isSt(); } + static ASMJIT_INLINE_NODEBUG bool isBnd(const Operand_& op) noexcept { return op.as().isBnd(); } + static ASMJIT_INLINE_NODEBUG bool isTmm(const Operand_& op) noexcept { return op.as().isTmm(); } + static ASMJIT_INLINE_NODEBUG bool isRip(const Operand_& op) noexcept { return op.as().isRip(); } + + static ASMJIT_INLINE_NODEBUG bool 
isGpb(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpb(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpbLo(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpbLo(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpbHi(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpbHi(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpw(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpw(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpd(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpd(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpq(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpq(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isXmm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isXmm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isYmm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isYmm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isZmm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isZmm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isMm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isMm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isKReg(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isKReg(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isSReg(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isSReg(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isCReg(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isCReg(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isDReg(const Operand_& op, uint32_t 
rId) noexcept { return bool(unsigned(isDReg(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isSt(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isSt(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isBnd(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isBnd(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isTmm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isTmm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isRip(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isRip(op)) & unsigned(op.id() == rId)); } +}; + +//! General purpose register (X86). +class Gp : public Reg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Gp, Reg) + + //! Physical id (X86). + //! + //! \note Register indexes have been reduced to only support general purpose registers. There is no need to + //! have enumerations with number suffix that expands to the exactly same value as the suffix value itself. + enum Id : uint32_t { + kIdAx = 0, //!< Physical id of AL|AH|AX|EAX|RAX registers. + kIdCx = 1, //!< Physical id of CL|CH|CX|ECX|RCX registers. + kIdDx = 2, //!< Physical id of DL|DH|DX|EDX|RDX registers. + kIdBx = 3, //!< Physical id of BL|BH|BX|EBX|RBX registers. + kIdSp = 4, //!< Physical id of SPL|SP|ESP|RSP registers. + kIdBp = 5, //!< Physical id of BPL|BP|EBP|RBP registers. + kIdSi = 6, //!< Physical id of SIL|SI|ESI|RSI registers. + kIdDi = 7, //!< Physical id of DIL|DI|EDI|RDI registers. + kIdR8 = 8, //!< Physical id of R8B|R8W|R8D|R8 registers (64-bit only). + kIdR9 = 9, //!< Physical id of R9B|R9W|R9D|R9 registers (64-bit only). + kIdR10 = 10, //!< Physical id of R10B|R10W|R10D|R10 registers (64-bit only). + kIdR11 = 11, //!< Physical id of R11B|R11W|R11D|R11 registers (64-bit only). + kIdR12 = 12, //!< Physical id of R12B|R12W|R12D|R12 registers (64-bit only). 
+ kIdR13 = 13, //!< Physical id of R13B|R13W|R13D|R13 registers (64-bit only). + kIdR14 = 14, //!< Physical id of R14B|R14W|R14D|R14 registers (64-bit only). + kIdR15 = 15 //!< Physical id of R15B|R15W|R15D|R15 registers (64-bit only). + }; + + //! Casts this register to 8-bit (LO) part. + ASMJIT_INLINE_NODEBUG GpbLo r8() const noexcept; + //! Casts this register to 8-bit (LO) part. + ASMJIT_INLINE_NODEBUG GpbLo r8Lo() const noexcept; + //! Casts this register to 8-bit (HI) part. + ASMJIT_INLINE_NODEBUG GpbHi r8Hi() const noexcept; + //! Casts this register to 16-bit. + ASMJIT_INLINE_NODEBUG Gpw r16() const noexcept; + //! Casts this register to 32-bit. + ASMJIT_INLINE_NODEBUG Gpd r32() const noexcept; + //! Casts this register to 64-bit. + ASMJIT_INLINE_NODEBUG Gpq r64() const noexcept; +}; + +//! Vector register (XMM|YMM|ZMM) (X86). +class Vec : public Reg { + ASMJIT_DEFINE_ABSTRACT_REG(Vec, Reg) + + //! Casts this register to XMM (clone). + ASMJIT_INLINE_NODEBUG Xmm xmm() const noexcept; + //! Casts this register to YMM (clone). + ASMJIT_INLINE_NODEBUG Ymm ymm() const noexcept; + //! Casts this register to ZMM (clone). + ASMJIT_INLINE_NODEBUG Zmm zmm() const noexcept; + + //! Casts this register to XMM (clone). + ASMJIT_INLINE_NODEBUG Vec v128() const noexcept; + //! Casts this register to YMM (clone). + ASMJIT_INLINE_NODEBUG Vec v256() const noexcept; + //! Casts this register to ZMM (clone). + ASMJIT_INLINE_NODEBUG Vec v512() const noexcept; + + //! Casts this register to a register that has half the size (or XMM if it's already XMM). + ASMJIT_INLINE_NODEBUG Vec half() const noexcept { + return Vec(type() == RegType::kX86_Zmm ? signatureOfT() : signatureOfT(), id()); + } +}; + +//! Segment register (X86). +class SReg : public Reg { + ASMJIT_DEFINE_FINAL_REG(SReg, Reg, RegTraits) + + //! X86 segment id. + enum Id : uint32_t { + //! No segment (default). + kIdNone = 0, + //! ES segment. + kIdEs = 1, + //! CS segment. + kIdCs = 2, + //! SS segment. 
+ kIdSs = 3, + //! DS segment. + kIdDs = 4, + //! FS segment. + kIdFs = 5, + //! GS segment. + kIdGs = 6, + + //! Count of X86 segment registers supported by AsmJit. + //! + //! \note X86 architecture has 6 segment registers - ES, CS, SS, DS, FS, GS. X64 architecture lowers them down to + //! just FS and GS. AsmJit supports 7 segment registers - all addressable in both X86 and X64 modes and one extra + //! called `SReg::kIdNone`, which is AsmJit specific and means that there is no segment register specified. + kIdCount = 7 + }; +}; + +//! GPB low or high register (X86). +class Gpb : public Gp { ASMJIT_DEFINE_ABSTRACT_REG(Gpb, Gp) }; +//! GPB low register (X86). +class GpbLo : public Gpb { ASMJIT_DEFINE_FINAL_REG(GpbLo, Gpb, RegTraits) }; +//! GPB high register (X86). +class GpbHi : public Gpb { ASMJIT_DEFINE_FINAL_REG(GpbHi, Gpb, RegTraits) }; +//! GPW register (X86). +class Gpw : public Gp { ASMJIT_DEFINE_FINAL_REG(Gpw, Gp, RegTraits) }; +//! GPD register (X86). +class Gpd : public Gp { ASMJIT_DEFINE_FINAL_REG(Gpd, Gp, RegTraits) }; +//! GPQ register (X86_64). +class Gpq : public Gp { ASMJIT_DEFINE_FINAL_REG(Gpq, Gp, RegTraits) }; + +//! 128-bit XMM register (SSE+). +class Xmm : public Vec { + ASMJIT_DEFINE_FINAL_REG(Xmm, Vec, RegTraits) + //! Casts this register to a register that has half the size (XMM). + ASMJIT_INLINE_NODEBUG Xmm half() const noexcept { return Xmm(id()); } +}; + +//! 256-bit YMM register (AVX+). +class Ymm : public Vec { + ASMJIT_DEFINE_FINAL_REG(Ymm, Vec, RegTraits) + //! Casts this register to a register that has half the size (XMM). + ASMJIT_INLINE_NODEBUG Xmm half() const noexcept { return Xmm(id()); } +}; + +//! 512-bit ZMM register (AVX512+). +class Zmm : public Vec { + ASMJIT_DEFINE_FINAL_REG(Zmm, Vec, RegTraits) + //! Casts this register to a register that has half the size (YMM). + ASMJIT_INLINE_NODEBUG Ymm half() const noexcept { return Ymm(id()); } +}; + +//! 64-bit MMX register (MMX+). 
+class Mm : public Reg { ASMJIT_DEFINE_FINAL_REG(Mm, Reg, RegTraits) }; +//! 64-bit K register (AVX512+). +class KReg : public Reg { ASMJIT_DEFINE_FINAL_REG(KReg, Reg, RegTraits) }; +//! 32-bit or 64-bit control register (X86). +class CReg : public Reg { ASMJIT_DEFINE_FINAL_REG(CReg, Reg, RegTraits) }; +//! 32-bit or 64-bit debug register (X86). +class DReg : public Reg { ASMJIT_DEFINE_FINAL_REG(DReg, Reg, RegTraits) }; +//! 80-bit FPU register (X86). +class St : public Reg { ASMJIT_DEFINE_FINAL_REG(St, Reg, RegTraits) }; +//! 128-bit BND register (BND+). +class Bnd : public Reg { ASMJIT_DEFINE_FINAL_REG(Bnd, Reg, RegTraits) }; +//! 8192-bit TMM register (AMX). +class Tmm : public Reg { ASMJIT_DEFINE_FINAL_REG(Tmm, Reg, RegTraits) }; +//! RIP register (X86). +class Rip : public Reg { ASMJIT_DEFINE_FINAL_REG(Rip, Reg, RegTraits) }; + +//! \cond +ASMJIT_INLINE_NODEBUG GpbLo Gp::r8() const noexcept { return GpbLo(id()); } +ASMJIT_INLINE_NODEBUG GpbLo Gp::r8Lo() const noexcept { return GpbLo(id()); } +ASMJIT_INLINE_NODEBUG GpbHi Gp::r8Hi() const noexcept { return GpbHi(id()); } +ASMJIT_INLINE_NODEBUG Gpw Gp::r16() const noexcept { return Gpw(id()); } +ASMJIT_INLINE_NODEBUG Gpd Gp::r32() const noexcept { return Gpd(id()); } +ASMJIT_INLINE_NODEBUG Gpq Gp::r64() const noexcept { return Gpq(id()); } +ASMJIT_INLINE_NODEBUG Xmm Vec::xmm() const noexcept { return Xmm(id()); } +ASMJIT_INLINE_NODEBUG Ymm Vec::ymm() const noexcept { return Ymm(id()); } +ASMJIT_INLINE_NODEBUG Zmm Vec::zmm() const noexcept { return Zmm(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v128() const noexcept { return Xmm(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v256() const noexcept { return Ymm(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v512() const noexcept { return Zmm(id()); } +//! \endcond + +//! \namespace asmjit::x86::regs +//! +//! Registers provided by X86 and X64 ISAs are in both `asmjit::x86` and `asmjit::x86::regs` namespaces so they can +//! be included with using directive. 
For example `using namespace asmjit::x86::regs` would include all registers, +//! but not other X86-specific API, whereas `using namespace asmjit::x86` would include everything X86-specific. +#ifndef _DOXYGEN +namespace regs { +#endif + +//! Creates an 8-bit low GPB register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpbLo gpb(uint32_t rId) noexcept { return GpbLo(rId); } +//! Creates an 8-bit low GPB register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpbLo gpb_lo(uint32_t rId) noexcept { return GpbLo(rId); } +//! Creates an 8-bit high GPB register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpbHi gpb_hi(uint32_t rId) noexcept { return GpbHi(rId); } +//! Creates a 16-bit GPW register operand. +static ASMJIT_INLINE_NODEBUG constexpr Gpw gpw(uint32_t rId) noexcept { return Gpw(rId); } +//! Creates a 32-bit GPD register operand. +static ASMJIT_INLINE_NODEBUG constexpr Gpd gpd(uint32_t rId) noexcept { return Gpd(rId); } +//! Creates a 64-bit GPQ register operand (64-bit). +static ASMJIT_INLINE_NODEBUG constexpr Gpq gpq(uint32_t rId) noexcept { return Gpq(rId); } +//! Creates a 128-bit XMM register operand. +static ASMJIT_INLINE_NODEBUG constexpr Xmm xmm(uint32_t rId) noexcept { return Xmm(rId); } +//! Creates a 256-bit YMM register operand. +static ASMJIT_INLINE_NODEBUG constexpr Ymm ymm(uint32_t rId) noexcept { return Ymm(rId); } +//! Creates a 512-bit ZMM register operand. +static ASMJIT_INLINE_NODEBUG constexpr Zmm zmm(uint32_t rId) noexcept { return Zmm(rId); } +//! Creates a 64-bit Mm register operand. +static ASMJIT_INLINE_NODEBUG constexpr Mm mm(uint32_t rId) noexcept { return Mm(rId); } +//! Creates a 64-bit K register operand. +static ASMJIT_INLINE_NODEBUG constexpr KReg k(uint32_t rId) noexcept { return KReg(rId); } +//! Creates a 32-bit or 64-bit control register operand. +static ASMJIT_INLINE_NODEBUG constexpr CReg cr(uint32_t rId) noexcept { return CReg(rId); } +//! Creates a 32-bit or 64-bit debug register operand. 
+static ASMJIT_INLINE_NODEBUG constexpr DReg dr(uint32_t rId) noexcept { return DReg(rId); } +//! Creates an 80-bit st register operand. +static ASMJIT_INLINE_NODEBUG constexpr St st(uint32_t rId) noexcept { return St(rId); } +//! Creates a 128-bit bound register operand. +static ASMJIT_INLINE_NODEBUG constexpr Bnd bnd(uint32_t rId) noexcept { return Bnd(rId); } +//! Creates a TMM register operand. +static ASMJIT_INLINE_NODEBUG constexpr Tmm tmm(uint32_t rId) noexcept { return Tmm(rId); } + +static constexpr GpbLo al = GpbLo(Gp::kIdAx); +static constexpr GpbLo bl = GpbLo(Gp::kIdBx); +static constexpr GpbLo cl = GpbLo(Gp::kIdCx); +static constexpr GpbLo dl = GpbLo(Gp::kIdDx); +static constexpr GpbLo spl = GpbLo(Gp::kIdSp); +static constexpr GpbLo bpl = GpbLo(Gp::kIdBp); +static constexpr GpbLo sil = GpbLo(Gp::kIdSi); +static constexpr GpbLo dil = GpbLo(Gp::kIdDi); +static constexpr GpbLo r8b = GpbLo(Gp::kIdR8); +static constexpr GpbLo r9b = GpbLo(Gp::kIdR9); +static constexpr GpbLo r10b = GpbLo(Gp::kIdR10); +static constexpr GpbLo r11b = GpbLo(Gp::kIdR11); +static constexpr GpbLo r12b = GpbLo(Gp::kIdR12); +static constexpr GpbLo r13b = GpbLo(Gp::kIdR13); +static constexpr GpbLo r14b = GpbLo(Gp::kIdR14); +static constexpr GpbLo r15b = GpbLo(Gp::kIdR15); + +static constexpr GpbHi ah = GpbHi(Gp::kIdAx); +static constexpr GpbHi bh = GpbHi(Gp::kIdBx); +static constexpr GpbHi ch = GpbHi(Gp::kIdCx); +static constexpr GpbHi dh = GpbHi(Gp::kIdDx); + +static constexpr Gpw ax = Gpw(Gp::kIdAx); +static constexpr Gpw bx = Gpw(Gp::kIdBx); +static constexpr Gpw cx = Gpw(Gp::kIdCx); +static constexpr Gpw dx = Gpw(Gp::kIdDx); +static constexpr Gpw sp = Gpw(Gp::kIdSp); +static constexpr Gpw bp = Gpw(Gp::kIdBp); +static constexpr Gpw si = Gpw(Gp::kIdSi); +static constexpr Gpw di = Gpw(Gp::kIdDi); +static constexpr Gpw r8w = Gpw(Gp::kIdR8); +static constexpr Gpw r9w = Gpw(Gp::kIdR9); +static constexpr Gpw r10w = Gpw(Gp::kIdR10); +static constexpr Gpw r11w = Gpw(Gp::kIdR11); +static 
constexpr Gpw r12w = Gpw(Gp::kIdR12); +static constexpr Gpw r13w = Gpw(Gp::kIdR13); +static constexpr Gpw r14w = Gpw(Gp::kIdR14); +static constexpr Gpw r15w = Gpw(Gp::kIdR15); + +static constexpr Gpd eax = Gpd(Gp::kIdAx); +static constexpr Gpd ebx = Gpd(Gp::kIdBx); +static constexpr Gpd ecx = Gpd(Gp::kIdCx); +static constexpr Gpd edx = Gpd(Gp::kIdDx); +static constexpr Gpd esp = Gpd(Gp::kIdSp); +static constexpr Gpd ebp = Gpd(Gp::kIdBp); +static constexpr Gpd esi = Gpd(Gp::kIdSi); +static constexpr Gpd edi = Gpd(Gp::kIdDi); +static constexpr Gpd r8d = Gpd(Gp::kIdR8); +static constexpr Gpd r9d = Gpd(Gp::kIdR9); +static constexpr Gpd r10d = Gpd(Gp::kIdR10); +static constexpr Gpd r11d = Gpd(Gp::kIdR11); +static constexpr Gpd r12d = Gpd(Gp::kIdR12); +static constexpr Gpd r13d = Gpd(Gp::kIdR13); +static constexpr Gpd r14d = Gpd(Gp::kIdR14); +static constexpr Gpd r15d = Gpd(Gp::kIdR15); + +static constexpr Gpq rax = Gpq(Gp::kIdAx); +static constexpr Gpq rbx = Gpq(Gp::kIdBx); +static constexpr Gpq rcx = Gpq(Gp::kIdCx); +static constexpr Gpq rdx = Gpq(Gp::kIdDx); +static constexpr Gpq rsp = Gpq(Gp::kIdSp); +static constexpr Gpq rbp = Gpq(Gp::kIdBp); +static constexpr Gpq rsi = Gpq(Gp::kIdSi); +static constexpr Gpq rdi = Gpq(Gp::kIdDi); +static constexpr Gpq r8 = Gpq(Gp::kIdR8); +static constexpr Gpq r9 = Gpq(Gp::kIdR9); +static constexpr Gpq r10 = Gpq(Gp::kIdR10); +static constexpr Gpq r11 = Gpq(Gp::kIdR11); +static constexpr Gpq r12 = Gpq(Gp::kIdR12); +static constexpr Gpq r13 = Gpq(Gp::kIdR13); +static constexpr Gpq r14 = Gpq(Gp::kIdR14); +static constexpr Gpq r15 = Gpq(Gp::kIdR15); + +static constexpr Xmm xmm0 = Xmm(0); +static constexpr Xmm xmm1 = Xmm(1); +static constexpr Xmm xmm2 = Xmm(2); +static constexpr Xmm xmm3 = Xmm(3); +static constexpr Xmm xmm4 = Xmm(4); +static constexpr Xmm xmm5 = Xmm(5); +static constexpr Xmm xmm6 = Xmm(6); +static constexpr Xmm xmm7 = Xmm(7); +static constexpr Xmm xmm8 = Xmm(8); +static constexpr Xmm xmm9 = Xmm(9); +static constexpr Xmm 
xmm10 = Xmm(10); +static constexpr Xmm xmm11 = Xmm(11); +static constexpr Xmm xmm12 = Xmm(12); +static constexpr Xmm xmm13 = Xmm(13); +static constexpr Xmm xmm14 = Xmm(14); +static constexpr Xmm xmm15 = Xmm(15); +static constexpr Xmm xmm16 = Xmm(16); +static constexpr Xmm xmm17 = Xmm(17); +static constexpr Xmm xmm18 = Xmm(18); +static constexpr Xmm xmm19 = Xmm(19); +static constexpr Xmm xmm20 = Xmm(20); +static constexpr Xmm xmm21 = Xmm(21); +static constexpr Xmm xmm22 = Xmm(22); +static constexpr Xmm xmm23 = Xmm(23); +static constexpr Xmm xmm24 = Xmm(24); +static constexpr Xmm xmm25 = Xmm(25); +static constexpr Xmm xmm26 = Xmm(26); +static constexpr Xmm xmm27 = Xmm(27); +static constexpr Xmm xmm28 = Xmm(28); +static constexpr Xmm xmm29 = Xmm(29); +static constexpr Xmm xmm30 = Xmm(30); +static constexpr Xmm xmm31 = Xmm(31); + +static constexpr Ymm ymm0 = Ymm(0); +static constexpr Ymm ymm1 = Ymm(1); +static constexpr Ymm ymm2 = Ymm(2); +static constexpr Ymm ymm3 = Ymm(3); +static constexpr Ymm ymm4 = Ymm(4); +static constexpr Ymm ymm5 = Ymm(5); +static constexpr Ymm ymm6 = Ymm(6); +static constexpr Ymm ymm7 = Ymm(7); +static constexpr Ymm ymm8 = Ymm(8); +static constexpr Ymm ymm9 = Ymm(9); +static constexpr Ymm ymm10 = Ymm(10); +static constexpr Ymm ymm11 = Ymm(11); +static constexpr Ymm ymm12 = Ymm(12); +static constexpr Ymm ymm13 = Ymm(13); +static constexpr Ymm ymm14 = Ymm(14); +static constexpr Ymm ymm15 = Ymm(15); +static constexpr Ymm ymm16 = Ymm(16); +static constexpr Ymm ymm17 = Ymm(17); +static constexpr Ymm ymm18 = Ymm(18); +static constexpr Ymm ymm19 = Ymm(19); +static constexpr Ymm ymm20 = Ymm(20); +static constexpr Ymm ymm21 = Ymm(21); +static constexpr Ymm ymm22 = Ymm(22); +static constexpr Ymm ymm23 = Ymm(23); +static constexpr Ymm ymm24 = Ymm(24); +static constexpr Ymm ymm25 = Ymm(25); +static constexpr Ymm ymm26 = Ymm(26); +static constexpr Ymm ymm27 = Ymm(27); +static constexpr Ymm ymm28 = Ymm(28); +static constexpr Ymm ymm29 = Ymm(29); +static 
constexpr Ymm ymm30 = Ymm(30); +static constexpr Ymm ymm31 = Ymm(31); + +static constexpr Zmm zmm0 = Zmm(0); +static constexpr Zmm zmm1 = Zmm(1); +static constexpr Zmm zmm2 = Zmm(2); +static constexpr Zmm zmm3 = Zmm(3); +static constexpr Zmm zmm4 = Zmm(4); +static constexpr Zmm zmm5 = Zmm(5); +static constexpr Zmm zmm6 = Zmm(6); +static constexpr Zmm zmm7 = Zmm(7); +static constexpr Zmm zmm8 = Zmm(8); +static constexpr Zmm zmm9 = Zmm(9); +static constexpr Zmm zmm10 = Zmm(10); +static constexpr Zmm zmm11 = Zmm(11); +static constexpr Zmm zmm12 = Zmm(12); +static constexpr Zmm zmm13 = Zmm(13); +static constexpr Zmm zmm14 = Zmm(14); +static constexpr Zmm zmm15 = Zmm(15); +static constexpr Zmm zmm16 = Zmm(16); +static constexpr Zmm zmm17 = Zmm(17); +static constexpr Zmm zmm18 = Zmm(18); +static constexpr Zmm zmm19 = Zmm(19); +static constexpr Zmm zmm20 = Zmm(20); +static constexpr Zmm zmm21 = Zmm(21); +static constexpr Zmm zmm22 = Zmm(22); +static constexpr Zmm zmm23 = Zmm(23); +static constexpr Zmm zmm24 = Zmm(24); +static constexpr Zmm zmm25 = Zmm(25); +static constexpr Zmm zmm26 = Zmm(26); +static constexpr Zmm zmm27 = Zmm(27); +static constexpr Zmm zmm28 = Zmm(28); +static constexpr Zmm zmm29 = Zmm(29); +static constexpr Zmm zmm30 = Zmm(30); +static constexpr Zmm zmm31 = Zmm(31); + +static constexpr Mm mm0 = Mm(0); +static constexpr Mm mm1 = Mm(1); +static constexpr Mm mm2 = Mm(2); +static constexpr Mm mm3 = Mm(3); +static constexpr Mm mm4 = Mm(4); +static constexpr Mm mm5 = Mm(5); +static constexpr Mm mm6 = Mm(6); +static constexpr Mm mm7 = Mm(7); + +static constexpr KReg k0 = KReg(0); +static constexpr KReg k1 = KReg(1); +static constexpr KReg k2 = KReg(2); +static constexpr KReg k3 = KReg(3); +static constexpr KReg k4 = KReg(4); +static constexpr KReg k5 = KReg(5); +static constexpr KReg k6 = KReg(6); +static constexpr KReg k7 = KReg(7); + +static constexpr SReg no_seg = SReg(SReg::kIdNone); +static constexpr SReg es = SReg(SReg::kIdEs); +static constexpr SReg cs 
= SReg(SReg::kIdCs); +static constexpr SReg ss = SReg(SReg::kIdSs); +static constexpr SReg ds = SReg(SReg::kIdDs); +static constexpr SReg fs = SReg(SReg::kIdFs); +static constexpr SReg gs = SReg(SReg::kIdGs); + +static constexpr CReg cr0 = CReg(0); +static constexpr CReg cr1 = CReg(1); +static constexpr CReg cr2 = CReg(2); +static constexpr CReg cr3 = CReg(3); +static constexpr CReg cr4 = CReg(4); +static constexpr CReg cr5 = CReg(5); +static constexpr CReg cr6 = CReg(6); +static constexpr CReg cr7 = CReg(7); +static constexpr CReg cr8 = CReg(8); +static constexpr CReg cr9 = CReg(9); +static constexpr CReg cr10 = CReg(10); +static constexpr CReg cr11 = CReg(11); +static constexpr CReg cr12 = CReg(12); +static constexpr CReg cr13 = CReg(13); +static constexpr CReg cr14 = CReg(14); +static constexpr CReg cr15 = CReg(15); + +static constexpr DReg dr0 = DReg(0); +static constexpr DReg dr1 = DReg(1); +static constexpr DReg dr2 = DReg(2); +static constexpr DReg dr3 = DReg(3); +static constexpr DReg dr4 = DReg(4); +static constexpr DReg dr5 = DReg(5); +static constexpr DReg dr6 = DReg(6); +static constexpr DReg dr7 = DReg(7); +static constexpr DReg dr8 = DReg(8); +static constexpr DReg dr9 = DReg(9); +static constexpr DReg dr10 = DReg(10); +static constexpr DReg dr11 = DReg(11); +static constexpr DReg dr12 = DReg(12); +static constexpr DReg dr13 = DReg(13); +static constexpr DReg dr14 = DReg(14); +static constexpr DReg dr15 = DReg(15); + +static constexpr St st0 = St(0); +static constexpr St st1 = St(1); +static constexpr St st2 = St(2); +static constexpr St st3 = St(3); +static constexpr St st4 = St(4); +static constexpr St st5 = St(5); +static constexpr St st6 = St(6); +static constexpr St st7 = St(7); + +static constexpr Bnd bnd0 = Bnd(0); +static constexpr Bnd bnd1 = Bnd(1); +static constexpr Bnd bnd2 = Bnd(2); +static constexpr Bnd bnd3 = Bnd(3); + +static constexpr Tmm tmm0 = Tmm(0); +static constexpr Tmm tmm1 = Tmm(1); +static constexpr Tmm tmm2 = Tmm(2); +static 
constexpr Tmm tmm3 = Tmm(3); +static constexpr Tmm tmm4 = Tmm(4); +static constexpr Tmm tmm5 = Tmm(5); +static constexpr Tmm tmm6 = Tmm(6); +static constexpr Tmm tmm7 = Tmm(7); + +static constexpr Rip rip = Rip(0); + +#ifndef _DOXYGEN +} // {regs} + +// Make `x86::regs` accessible through `x86` namespace as well. +using namespace regs; +#endif + +//! Memory operand specific to X86 and X86_64 architecture. +class Mem : public BaseMem { +public: + //! \name Constants + //! \{ + + //! Additional bits of operand's signature used by `x86::Mem`. + enum AdditionalBits : uint32_t { + // Memory address type (2 bits). + // |........|........|XX......|........| + kSignatureMemAddrTypeShift = 14, + kSignatureMemAddrTypeMask = 0x03u << kSignatureMemAddrTypeShift, + + // Memory shift amount (2 bits). + // |........|......XX|........|........| + kSignatureMemShiftValueShift = 16, + kSignatureMemShiftValueMask = 0x03u << kSignatureMemShiftValueShift, + + // Memory segment reg (3 bits). + // |........|...XXX..|........|........| + kSignatureMemSegmentShift = 18, + kSignatureMemSegmentMask = 0x07u << kSignatureMemSegmentShift, + + // Memory broadcast type (3 bits). + // |........|XXX.....|........|........| + kSignatureMemBroadcastShift = 21, + kSignatureMemBroadcastMask = 0x7u << kSignatureMemBroadcastShift + }; + + //! Address type. + enum class AddrType : uint32_t { + //! Default address type, Assembler will select the best type when necessary. + kDefault = 0, + //! Absolute address type. + kAbs = 1, + //! Relative address type. + kRel = 2, + + //! Maximum value of `AddrType`. + kMaxValue = kRel + }; + + //! Memory broadcast type. + enum class Broadcast : uint32_t { + //! No broadcast (regular memory operand). + kNone = 0, + //! Broadcast {1to2}. + k1To2 = 1, + //! Broadcast {1to4}. + k1To4 = 2, + //! Broadcast {1to8}. + k1To8 = 3, + //! Broadcast {1to16}. + k1To16 = 4, + //! Broadcast {1to32}. + k1To32 = 5, + //! Broadcast {1to64}. + k1To64 = 6, + + //! 
Maximum value of `Broadcast`. + kMaxValue = k1To64 + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a default `Mem` operand that points to [0]. + ASMJIT_INLINE_NODEBUG constexpr Mem() noexcept + : BaseMem() {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Mem& other) noexcept + : BaseMem(other) {} + + ASMJIT_INLINE_NODEBUG explicit Mem(Globals::NoInit_) noexcept + : BaseMem(Globals::NoInit) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Signature& signature, uint32_t baseId, uint32_t indexId, int32_t offset) noexcept + : BaseMem(signature, baseId, indexId, offset) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Label& base, int32_t off, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(RegType::kLabelTag) | + Signature::fromSize(size) | + signature, base.id(), 0, off) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Label& base, const BaseReg& index, uint32_t shift, int32_t off, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(RegType::kLabelTag) | + Signature::fromMemIndexType(index.type()) | + Signature::fromValue(shift) | + Signature::fromSize(size) | + signature, base.id(), index.id(), off) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const BaseReg& base, int32_t off, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + Signature::fromSize(size) | + signature, base.id(), 0, off) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const BaseReg& base, const BaseReg& index, uint32_t shift, int32_t off, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + Signature::fromMemIndexType(index.type()) | 
+ Signature::fromValue(shift) | + Signature::fromSize(size) | + signature, base.id(), index.id(), off) {} + + ASMJIT_INLINE_NODEBUG constexpr explicit Mem(uint64_t base, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromSize(size) | + signature, uint32_t(base >> 32), 0, int32_t(uint32_t(base & 0xFFFFFFFFu))) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(uint64_t base, const BaseReg& index, uint32_t shift = 0, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemIndexType(index.type()) | + Signature::fromValue(shift) | + Signature::fromSize(size) | + signature, uint32_t(base >> 32), index.id(), int32_t(uint32_t(base & 0xFFFFFFFFu))) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Mem& operator=(const Mem& other) noexcept = default; + + //! \} + + //! \name Clone + //! \{ + + //! Clones the memory operand. + ASMJIT_INLINE_NODEBUG constexpr Mem clone() const noexcept { return Mem(*this); } + + //! Creates a copy of this memory operand adjusted by `off`. + inline Mem cloneAdjusted(int64_t off) const noexcept { + Mem result(*this); + result.addOffset(off); + return result; + } + + //! Creates a copy of this memory operand resized to `size`. + inline Mem cloneResized(uint32_t size) const noexcept { + Mem result(*this); + result.setSize(size); + return result; + } + + //! Creates a copy of this memory operand with a broadcast `bcst`. + ASMJIT_INLINE_NODEBUG constexpr Mem cloneBroadcasted(Broadcast bcst) const noexcept { + return Mem((_signature & ~Signature{kSignatureMemBroadcastMask}) | Signature::fromValue(bcst), _baseId, _data[0], int32_t(_data[1])); + } + + //! \} + + //! \name Base & Index + //! \{ + + //! Converts memory `baseType` and `baseId` to `x86::Reg` instance. + //! + //! 
The memory must have a valid base register otherwise the result will be wrong. + ASMJIT_INLINE_NODEBUG Reg baseReg() const noexcept { return Reg::fromTypeAndId(baseType(), baseId()); } + + //! Converts memory `indexType` and `indexId` to `x86::Reg` instance. + //! + //! The memory must have a valid index register otherwise the result will be wrong. + ASMJIT_INLINE_NODEBUG Reg indexReg() const noexcept { return Reg::fromTypeAndId(indexType(), indexId()); } + + using BaseMem::setIndex; + + ASMJIT_INLINE_NODEBUG void setIndex(const BaseReg& index, uint32_t shift) noexcept { + setIndex(index); + setShift(shift); + } + + //! \} + + //! \name Memory Size + //! \{ + + //! Tests whether the memory operand specifies a size (i.e. the size is not zero). + ASMJIT_INLINE_NODEBUG constexpr bool hasSize() const noexcept { return _signature.hasField(); } + //! Tests whether the memory operand size matches size `s`. + ASMJIT_INLINE_NODEBUG constexpr bool hasSize(uint32_t s) const noexcept { return size() == s; } + + //! Returns the size of the memory operand in bytes. + //! + //! \note Most instructions would deduce the size of the memory operand, so in most cases it's expected that the + //! returned value would be zero. However, some instruction require the size to select between multiple variations, + //! so in some cases size is required would be non-zero (for example `inc [mem], immediate` requires size to + //! distinguish between 8-bit, 16-bit, 32-bit, and 64-bit increments. + ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return _signature.getField(); } + + //! \} + + //! \name Address Type + //! \{ + + //! Returns the address type of the memory operand. + //! + //! By default, address type of newly created memory operands is always \ref AddrType::kDefault. + ASMJIT_INLINE_NODEBUG constexpr AddrType addrType() const noexcept { return (AddrType)_signature.getField(); } + //! Sets the address type to `addrType`. 
+ ASMJIT_INLINE_NODEBUG void setAddrType(AddrType addrType) noexcept { _signature.setField(uint32_t(addrType)); } + //! Resets the address type to \ref AddrType::kDefault. + ASMJIT_INLINE_NODEBUG void resetAddrType() noexcept { _signature.setField(uint32_t(AddrType::kDefault)); } + + //! Tests whether the address type is \ref AddrType::kAbs. + ASMJIT_INLINE_NODEBUG constexpr bool isAbs() const noexcept { return addrType() == AddrType::kAbs; } + //! Sets the address type to \ref AddrType::kAbs. + ASMJIT_INLINE_NODEBUG void setAbs() noexcept { setAddrType(AddrType::kAbs); } + + //! Tests whether the address type is \ref AddrType::kRel. + ASMJIT_INLINE_NODEBUG constexpr bool isRel() const noexcept { return addrType() == AddrType::kRel; } + //! Sets the address type to \ref AddrType::kRel. + ASMJIT_INLINE_NODEBUG void setRel() noexcept { setAddrType(AddrType::kRel); } + + //! \} + + //! \name Segment + //! \{ + + //! Tests whether the memory operand has a segment override. + ASMJIT_INLINE_NODEBUG constexpr bool hasSegment() const noexcept { return _signature.hasField(); } + //! Returns the associated segment override as `SReg` operand. + ASMJIT_INLINE_NODEBUG constexpr SReg segment() const noexcept { return SReg(segmentId()); } + //! Returns segment override register id, see `SReg::Id`. + ASMJIT_INLINE_NODEBUG constexpr uint32_t segmentId() const noexcept { return _signature.getField(); } + + //! Sets the segment override to `seg`. + ASMJIT_INLINE_NODEBUG void setSegment(const SReg& seg) noexcept { setSegment(seg.id()); } + //! Sets the segment override to `id`. + ASMJIT_INLINE_NODEBUG void setSegment(uint32_t rId) noexcept { _signature.setField(rId); } + //! Resets the segment override. + ASMJIT_INLINE_NODEBUG void resetSegment() noexcept { _signature.setField(0); } + + //! \} + + //! \name Shift + //! \{ + + //! Tests whether the memory operand has shift (aka scale) value. 
+ ASMJIT_INLINE_NODEBUG constexpr bool hasShift() const noexcept { return _signature.hasField(); } + //! Returns the memory operand's shift (aka scale) value. + ASMJIT_INLINE_NODEBUG constexpr uint32_t shift() const noexcept { return _signature.getField(); } + //! Sets the memory operand's shift (aka scale) value. + ASMJIT_INLINE_NODEBUG void setShift(uint32_t shift) noexcept { _signature.setField(shift); } + //! Resets the memory operand's shift (aka scale) value to zero. + ASMJIT_INLINE_NODEBUG void resetShift() noexcept { _signature.setField(0); } + + //! \} + + //! \name Broadcast + //! \{ + + //! Tests whether the memory operand has broadcast {1tox}. + ASMJIT_INLINE_NODEBUG constexpr bool hasBroadcast() const noexcept { return _signature.hasField(); } + //! Returns the memory operand's broadcast. + ASMJIT_INLINE_NODEBUG constexpr Broadcast getBroadcast() const noexcept { return (Broadcast)_signature.getField(); } + //! Sets the memory operand's broadcast. + ASMJIT_INLINE_NODEBUG void setBroadcast(Broadcast b) noexcept { _signature.setField(uint32_t(b)); } + //! Resets the memory operand's broadcast to none. + ASMJIT_INLINE_NODEBUG void resetBroadcast() noexcept { _signature.setField(0); } + + //! Returns a new `Mem` without a broadcast (the possible broadcast is cleared). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to1() const noexcept { return cloneBroadcasted(Broadcast::kNone); } + //! Returns a new `Mem` with {1to2} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to2() const noexcept { return cloneBroadcasted(Broadcast::k1To2); } + //! Returns a new `Mem` with {1to4} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to4() const noexcept { return cloneBroadcasted(Broadcast::k1To4); } + //! Returns a new `Mem` with {1to8} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to8() const noexcept { return cloneBroadcasted(Broadcast::k1To8); } + //! Returns a new `Mem` with {1to16} broadcast (AVX-512). 
+ ASMJIT_INLINE_NODEBUG constexpr Mem _1to16() const noexcept { return cloneBroadcasted(Broadcast::k1To16); } + //! Returns a new `Mem` with {1to32} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to32() const noexcept { return cloneBroadcasted(Broadcast::k1To32); } + //! Returns a new `Mem` with {1to64} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to64() const noexcept { return cloneBroadcasted(Broadcast::k1To64); } + + //! \} +}; + +//! Creates `[base.reg + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, offset, size); +} +//! Creates `[base.reg + (index << shift) + offset]` memory operand (scalar index). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, offset, size); +} +//! Creates `[base.reg + (index << shift) + offset]` memory operand (vector index). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, const Vec& index, uint32_t shift = 0, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, offset, size); +} + +//! Creates `[base + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Label& base, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, offset, size); +} +//! Creates `[base + (index << shift) + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Label& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, offset, size); +} +//! Creates `[base + (index << shift) + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Label& base, const Vec& index, uint32_t shift = 0, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, offset, size); +} + +//! 
Creates `[rip + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Rip& rip_, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(rip_, offset, size); +} + +//! Creates `[base]` absolute memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(uint64_t base, uint32_t size = 0) noexcept { + return Mem(base, size); +} +//! Creates `[base + (index.reg << shift)]` absolute memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(uint64_t base, const Reg& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size); +} +//! Creates `[base + (index.reg << shift)]` absolute memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(uint64_t base, const Vec& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size); +} + +//! Creates `[base]` absolute memory operand (absolute). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_abs(uint64_t base, uint32_t size = 0) noexcept { + return Mem(base, size, OperandSignature::fromValue(Mem::AddrType::kAbs)); +} +//! Creates `[base + (index.reg << shift)]` absolute memory operand (absolute). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_abs(uint64_t base, const Reg& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size, OperandSignature::fromValue(Mem::AddrType::kAbs)); +} +//! Creates `[base + (index.reg << shift)]` absolute memory operand (absolute). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_abs(uint64_t base, const Vec& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size, OperandSignature::fromValue(Mem::AddrType::kAbs)); +} + +//! Creates `[base]` relative memory operand (relative). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_rel(uint64_t base, uint32_t size = 0) noexcept { + return Mem(base, size, OperandSignature::fromValue(Mem::AddrType::kRel)); +} +//! 
Creates `[base + (index.reg << shift)]` relative memory operand (relative). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_rel(uint64_t base, const Reg& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size, OperandSignature::fromValue(Mem::AddrType::kRel)); +} +//! Creates `[base + (index.reg << shift)]` relative memory operand (relative). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_rel(uint64_t base, const Vec& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size, OperandSignature::fromValue(Mem::AddrType::kRel)); +} + +#define ASMJIT_MEM_PTR(FUNC, SIZE) \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Gp& base, int32_t offset = 0) noexcept \ + { return Mem(base, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Gp& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0) noexcept \ + { return Mem(base, index, shift, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Gp& base, const Vec& index, uint32_t shift = 0, int32_t offset = 0) noexcept \ + { return Mem(base, index, shift, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Label& base, int32_t offset = 0) noexcept \ + { return Mem(base, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Label& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0) noexcept \ + { return Mem(base, index, shift, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Rip& rip_, int32_t offset = 0) noexcept \ + { return Mem(rip_, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + uint64_t base) noexcept \ + { return Mem(base, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + uint64_t base, const Gp& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE); } \ + \ + static 
ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + uint64_t base, const Vec& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_abs( \ + uint64_t base) noexcept \ + { return Mem(base, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kAbs)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_abs( \ + uint64_t base, const Gp& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kAbs)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_abs( \ + uint64_t base, const Vec& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kAbs)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_rel( \ + uint64_t base) noexcept \ + { return Mem(base, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kRel)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_rel( \ + uint64_t base, const Gp& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kRel)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_rel( \ + uint64_t base, const Vec& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kRel)); } + +// Definition of memory operand constructors that use platform independent naming. +ASMJIT_MEM_PTR(ptr_8, 1) +ASMJIT_MEM_PTR(ptr_16, 2) +ASMJIT_MEM_PTR(ptr_32, 4) +ASMJIT_MEM_PTR(ptr_48, 6) +ASMJIT_MEM_PTR(ptr_64, 8) +ASMJIT_MEM_PTR(ptr_80, 10) +ASMJIT_MEM_PTR(ptr_128, 16) +ASMJIT_MEM_PTR(ptr_256, 32) +ASMJIT_MEM_PTR(ptr_512, 64) + +// Definition of memory operand constructors that use X86-specific convention. 
+ASMJIT_MEM_PTR(byte_ptr, 1) +ASMJIT_MEM_PTR(word_ptr, 2) +ASMJIT_MEM_PTR(dword_ptr, 4) +ASMJIT_MEM_PTR(fword_ptr, 6) +ASMJIT_MEM_PTR(qword_ptr, 8) +ASMJIT_MEM_PTR(tbyte_ptr, 10) +ASMJIT_MEM_PTR(tword_ptr, 10) +ASMJIT_MEM_PTR(oword_ptr, 16) +ASMJIT_MEM_PTR(dqword_ptr, 16) +ASMJIT_MEM_PTR(qqword_ptr, 32) +ASMJIT_MEM_PTR(xmmword_ptr, 16) +ASMJIT_MEM_PTR(ymmword_ptr, 32) +ASMJIT_MEM_PTR(zmmword_ptr, 64) + +#undef ASMJIT_MEM_PTR + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +//! \cond INTERNAL +ASMJIT_BEGIN_NAMESPACE +ASMJIT_DEFINE_TYPE_ID(x86::Gpb, TypeId::kInt8); +ASMJIT_DEFINE_TYPE_ID(x86::Gpw, TypeId::kInt16); +ASMJIT_DEFINE_TYPE_ID(x86::Gpd, TypeId::kInt32); +ASMJIT_DEFINE_TYPE_ID(x86::Gpq, TypeId::kInt64); +ASMJIT_DEFINE_TYPE_ID(x86::Mm , TypeId::kMmx64); +ASMJIT_DEFINE_TYPE_ID(x86::Xmm, TypeId::kInt32x4); +ASMJIT_DEFINE_TYPE_ID(x86::Ymm, TypeId::kInt32x8); +ASMJIT_DEFINE_TYPE_ID(x86::Zmm, TypeId::kInt32x16); +ASMJIT_END_NAMESPACE +//! \endcond + +#endif // ASMJIT_X86_X86OPERAND_H_INCLUDED diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/core/common.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/core/common.h new file mode 100644 index 0000000000000000000000000000000000000000..5058fdff5860a50d2cd5299eb3abad9e2a937aaa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/core/common.h @@ -0,0 +1,61 @@ +#ifndef CAFFE2_CORE_COMMON_H_ +#define CAFFE2_CORE_COMMON_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __APPLE__ +#include +#endif + +#if defined(_MSC_VER) +#include +#else +#include +#endif + +// Macros used during the build of this caffe2 instance. This header file +// is automatically generated by the cmake script during build. +#include "caffe2/core/macros.h" + +#include + +namespace caffe2 { + +// Using statements for common classes that we refer to in caffe2 very often. 
+// Note that we only place it inside caffe2 so the global namespace is not +// polluted. +/* using override */ +using std::set; +using std::string; +using std::unique_ptr; +using std::vector; + +// Define alignment macro that is cross platform +#if (defined _MSC_VER && !defined NOMINMAX) +#define NOMINMAX +#endif + +using std::make_unique; + +#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) +using ::round; +#else +using std::round; +#endif // defined(__ANDROID__) && !defined(__NDK_MAJOR__) + +// Returns which setting Caffe2 was configured and built with (exported from +// CMake) +TORCH_API const std::map& GetBuildOptions(); + +} // namespace caffe2 + +#endif // CAFFE2_CORE_COMMON_H_ diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/core/macros.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/core/macros.h new file mode 100644 index 0000000000000000000000000000000000000000..6477ed38f5cb5426ed1b397ac0b49a004a316180 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/core/macros.h @@ -0,0 +1,70 @@ +// Automatically generated header file for caffe2 macros. These +// macros are used to build the Caffe2 binary, and if you are +// building a dependent library, they will need to be set as well +// for your program to link correctly. 
+ +#pragma once + +#define CAFFE2_BUILD_SHARED_LIBS +/* #undef CAFFE2_FORCE_FALLBACK_CUDA_MPI */ +/* #undef CAFFE2_HAS_MKL_DNN */ +/* #undef CAFFE2_HAS_MKL_SGEMM_PACK */ +#define CAFFE2_PERF_WITH_AVX +#define CAFFE2_PERF_WITH_AVX2 +/* #undef CAFFE2_THREADPOOL_MAIN_IMBALANCE */ +/* #undef CAFFE2_THREADPOOL_STATS */ +/* #undef CAFFE2_USE_ACCELERATE */ +#define CAFFE2_USE_CUDNN +/* #undef CAFFE2_USE_EIGEN_FOR_BLAS */ +/* #undef CAFFE2_USE_FBCODE */ +/* #undef CAFFE2_USE_GOOGLE_GLOG */ +/* #undef CAFFE2_USE_LITE_PROTO */ +#define CAFFE2_USE_MKL +#define USE_MKLDNN +/* #undef CAFFE2_USE_NVTX */ +/* #undef CAFFE2_USE_ITT */ + +#ifndef EIGEN_MPL2_ONLY +#define EIGEN_MPL2_ONLY +#endif + +// Useful build settings that are recorded in the compiled binary +// torch.__config__.show() +#define CAFFE2_BUILD_STRINGS { \ + {"TORCH_VERSION", "2.8.0"}, \ + {"CXX_COMPILER", "/opt/rh/gcc-toolset-13/root/usr/bin/c++"}, \ + {"CXX_FLAGS", " -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow"}, \ + {"BUILD_TYPE", "Release"}, \ + {"BLAS_INFO", "mkl"}, \ + {"LAPACK_INFO", "mkl"}, \ + {"USE_CUDA", "ON"}, \ + {"USE_ROCM", "OFF"}, \ + {"CUDA_VERSION", "12.8"}, \ + {"ROCM_VERSION", ""}, \ + {"USE_CUDNN", "ON"}, \ + {"COMMIT_SHA", "a1cb3cc05d46d198467bebbb6e8fba50a325d4e7"}, \ + {"CUDNN_VERSION", "9.8.0"}, \ + {"USE_NCCL", "1"}, \ + {"USE_MPI", 
"OFF"}, \ + {"USE_GFLAGS", "OFF"}, \ + {"USE_GLOG", "OFF"}, \ + {"USE_GLOO", "ON"}, \ + {"USE_NNPACK", "ON"}, \ + {"USE_OPENMP", "ON"}, \ + {"FORCE_FALLBACK_CUDA_MPI", ""}, \ + {"HAS_MKL_DNN", ""}, \ + {"HAS_MKL_SGEMM_PACK", ""}, \ + {"PERF_WITH_AVX", "1"}, \ + {"PERF_WITH_AVX2", "1"}, \ + {"USE_ACCELERATE", ""}, \ + {"USE_EIGEN_FOR_BLAS", ""}, \ + {"USE_LITE_PROTO", ""}, \ + {"USE_MKL", "ON"}, \ + {"USE_MKLDNN", "ON"}, \ + {"USE_NVTX", ""}, \ + {"USE_ITT", ""}, \ + {"USE_ROCM_KERNEL_ASSERT", "OFF"}, \ + {"USE_CUSPARSELT", "1"}, \ + {"USE_XPU", "OFF"}, \ + {"USE_XCCL", "OFF"}, \ +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/core/timer.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/core/timer.h new file mode 100644 index 0000000000000000000000000000000000000000..a0384b0dbdbd025fcb7cb22dda84c4a75e209120 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/core/timer.h @@ -0,0 +1,48 @@ +#ifndef CAFFE2_CORE_TIMER_H_ +#define CAFFE2_CORE_TIMER_H_ + +#include + +#include "caffe2/core/common.h" + +namespace caffe2 { + +/** + * @brief A simple timer object for measuring time. + * + * This is a minimal class around a std::chrono::high_resolution_clock that + * serves as a utility class for testing code. + */ +class Timer { + public: + typedef std::chrono::high_resolution_clock clock; + typedef std::chrono::nanoseconds ns; + Timer() { Start(); } + /** + * @brief Starts a timer. + */ + inline void Start() { start_time_ = clock::now(); } + inline float NanoSeconds() { + return static_cast( + std::chrono::duration_cast(clock::now() - start_time_).count()); + } + /** + * @brief Returns the elapsed time in milliseconds. + */ + inline float MilliSeconds() { return NanoSeconds() / 1000000.f; } + /** + * @brief Returns the elapsed time in microseconds. + */ + inline float MicroSeconds() { return NanoSeconds() / 1000.f; } + /** + * @brief Returns the elapsed time in seconds. 
+ */ + inline float Seconds() { return NanoSeconds() / 1000000000.f; } + + protected: + std::chrono::time_point start_time_; + C10_DISABLE_COPY_AND_ASSIGN(Timer); +}; +} + +#endif // CAFFE2_CORE_TIMER_H_ diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/batch_box_cox_vec.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/batch_box_cox_vec.h new file mode 100644 index 0000000000000000000000000000000000000000..ed2e83062d1071714e8fd6de7a424fd434425369 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/batch_box_cox_vec.h @@ -0,0 +1,311 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "vectorizer.h" +#include + +namespace caffe2::details { + +namespace { +void TileIndicesInPlace(std::vector& v, const std::size_t D, const std::size_t K) { + auto n = v.size(); + v.resize(K * n); + for (const auto k : c10::irange(1, K)) { + for (const auto j : c10::irange(n)) { + v[k * n + j] = v[j] + k * D; + } + } +} + +// MKL VML function templates. 
+template +void PackV(const int N, const T* a, const int* ia, T* y); +template +void UnpackV(const int N, const T* a, T* y, const int* iy); + +#define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ + template <> \ + void PackV(const int N, const T* a, const int* ia, T* y) { \ + OriginalFunc(N, a, ia, y); \ + } +DELEGATE_PACKV_FUNCTION(float, vsPackV) +DELEGATE_PACKV_FUNCTION(double, vdPackV) +#undef DELEGATE_PACKV_FUNCTION + +#define DELEGATE_UNPACKV_FUNCTION(T, OriginalFunc) \ + template <> \ + void UnpackV(const int N, const T* a, T* y, const int* iy) { \ + OriginalFunc(N, a, y, iy); \ + } +DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) +DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) +#undef DELEGATE_UNPACKV_FUNCTION + +#ifndef FAST_VECTORIZED_KERNEL +template +void box_cox_zero_lambda( + size_t D, + const T* const self_data, + const T* const lambda2_data, + T k_eps, + T* const output_data) { + int j = 0; + using Vec = at::vec::Vectorized; + constexpr int64_t VLEN = Vec::size(); + auto k_eps_vec = Vec(k_eps); + for(; j + VLEN < D; j += VLEN) { + auto data = Vec::loadu(self_data + j); + auto lambda2 = Vec::loadu(lambda2_data + j); + auto sum = data + lambda2; + auto max = at::vec::max(sum, k_eps_vec); + auto res = max.log(); + res.store(output_data + j); + } + for ( ;j < D; ++j) { + auto sum = self_data[j] + lambda2_data[j]; + auto max = std::max(sum, k_eps); + output_data[j] = std::log(max); + } +} + +template +void box_cox_nonzero_lambda( + int64_t D, + const T* data_ptr, + const T* lambda1_ptr, + const T* lambda2_ptr, + T k_eps, + T* out) { + + int j = 0; + using Vec = at::vec::Vectorized; + constexpr int64_t VLEN = Vec::size(); + auto k_eps_vec = Vec(k_eps); + for(; j + VLEN < D; j += VLEN) { + auto data = Vec::loadu(data_ptr + j); + auto lambda2 = Vec::loadu(lambda2_ptr + j); + auto sum = data + lambda2; + auto max = at::vec::max(sum, k_eps_vec); + auto lambda1 = Vec::loadu(lambda1_ptr + j); + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); + auto pow = 
max.pow(lambda1); + auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1); + res.store(out + j); + } + for ( ;j < D; ++j) { + auto sum = data_ptr[j] + lambda2_ptr[j]; + auto max = std::max(sum, k_eps); + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]); + auto pow = std::pow(max, lambda1_ptr[j]); + out[j] = pow * lambda_over_1 - lambda_over_1; + } +} +#else +template +void box_cox_zero_lambda( + size_t D, + const T* const self_data, + const T* const lambda2_data, + T k_eps, + T* const output_data) { + VECTOR_LOOP for (auto j=0 ;j < D; ++j) { + auto sum = self_data[j] + lambda2_data[j]; + auto max = std::max(sum, k_eps); + output_data[j] = std::log(max); + } +} + +template +void box_cox_nonzero_lambda( + int64_t D, + const T* data_ptr, + const T* lambda1_ptr, + const T* lambda2_ptr, + T k_eps, + T* out) { + + VECTOR_LOOP for (auto j=0 ;j < D; ++j) { + FAST_MATH + auto sum = data_ptr[j] + lambda2_ptr[j]; + auto max = std::max(sum, k_eps); + auto lamda1 = lambda1_ptr[j]; + auto lambda_over_1 = 1 / lamda1; + if constexpr (std::is_same::value) { + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + } + auto pow = std::pow(max, lamda1); + out[j] = pow * lambda_over_1 - lambda_over_1; + } +} +#endif // FAST_VECTORIZED_KERNEL + +template +void box_cox_mixed_lambda( + const T* const self_data, + const std::vector& nonzeros, + const std::vector& zeros, + const T* const lambda1, + const T* const lambda2, + const T* const lambda2_z_, + T k_eps, + T* const buffer, + T* const output_data) { + PackV(nonzeros.size(), self_data, nonzeros.data(), buffer); + box_cox_nonzero_lambda( + nonzeros.size(), buffer, lambda1, lambda2, k_eps, buffer); + UnpackV(nonzeros.size(), buffer, output_data, nonzeros.data()); + + PackV(zeros.size(), self_data, zeros.data(), buffer); + box_cox_zero_lambda( + zeros.size(), buffer, lambda2_z_, k_eps, buffer); + UnpackV(zeros.size(), buffer, output_data, 
zeros.data()); +} + +template +void TileArrayIntoVector( + const T* const a, + const size_t D, + const int K, + std::vector& b) { + b.resize(K * D); + for (const auto k : c10::irange(K)) { + std::copy(a, a + D, b.begin() + k * D); + } +} + +template +void compute_batch_box_cox_vec_fma( + std::size_t N, + std::size_t D, + std::size_t block_size, + const T* self_data, + const T* __restrict lambda1_data, + const T* __restrict lambda2_data, + T* output_data) { + constexpr T k_eps = static_cast(1e-6); + + FOLLY_DECLARE_REUSED(zeros, std::vector); + FOLLY_DECLARE_REUSED(nonzeros, std::vector); + // Don't bother calling reserve; calls after the first will get a + // correctly-sized allocation anyway. + for (const auto j : c10::irange(D)) { + if (lambda1_data[j] == 0) { + zeros.push_back(j); + } else { + nonzeros.push_back(j); + } + } + + // Process K rows at a time for effective vectorization with small rows. + const auto K = std::min(N, (block_size + D - 1) / D); + + FOLLY_DECLARE_REUSED(lambda1_, std::vector); + FOLLY_DECLARE_REUSED(lambda2_, std::vector); + FOLLY_DECLARE_REUSED(lambda2_z_, std::vector); + + if (nonzeros.size() == D) { + // ((x + lambda2)^lambda1 - 1)/lambda1, if lambda1 != 0 + size_t i = 0; + if (K > 1) { + TileArrayIntoVector(lambda1_data, D, K, lambda1_); + TileArrayIntoVector(lambda2_data, D, K, lambda2_); + DCHECK_EQ(K * D, lambda1_.size()); + DCHECK_EQ(K * D, lambda2_.size()); + for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { + box_cox_nonzero_lambda( + K * D, + self_data, + lambda1_.data(), + lambda2_.data(), + k_eps, + output_data); + } + } + for (; i < N; i++, self_data += D, output_data += D) { + box_cox_nonzero_lambda( + D, self_data, lambda1_data, lambda2_data, k_eps, output_data); + } + } else if (zeros.size() == D) { + // ln(x + lambda2), if lambda1 == 0 + size_t i = 0; + if (K > 1) { + TileArrayIntoVector(lambda2_data, D, K, lambda2_z_); + DCHECK_EQ(K * D, lambda2_z_.size()); + for (; i < N - K + 1; i += K, 
self_data += K * D, output_data += K * D) { + box_cox_zero_lambda( + K * D, self_data, lambda2_z_.data(), k_eps, output_data); + } + } + for (; i < N; i++, self_data += D, output_data += D) { + box_cox_zero_lambda( + D, self_data, lambda2_data, k_eps, output_data); + } + } else { + // mix zeros and nonzeros + const size_t n = nonzeros.size(); + if (K > 1) { + TileIndicesInPlace(nonzeros, 0, K); + TileIndicesInPlace(zeros, 0, K); + } + + FOLLY_DECLARE_REUSED(buffer, std::vector); + + buffer.resize(std::max(nonzeros.size(), zeros.size())); + lambda1_.resize(nonzeros.size()); + lambda2_.resize(nonzeros.size()); + lambda2_z_.resize(zeros.size()); + PackV(nonzeros.size(), lambda1_data, nonzeros.data(), lambda1_.data()); + PackV(nonzeros.size(), lambda2_data, nonzeros.data(), lambda2_.data()); + PackV(zeros.size(), lambda2_data, zeros.data(), lambda2_z_.data()); + + size_t i = 0; + if (K > 1) { + // Truncate to original size, and re-tile with offsets this time. + nonzeros.resize(n); + DCHECK_GT(D, n); + zeros.resize(D - n); + TileIndicesInPlace(nonzeros, D, K); + TileIndicesInPlace(zeros, D, K); + DCHECK_EQ(nonzeros.size(), lambda1_.size()); + DCHECK_EQ(nonzeros.size(), lambda2_.size()); + DCHECK_EQ(zeros.size(), lambda2_z_.size()); + + for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { + box_cox_mixed_lambda( + self_data, + nonzeros, + zeros, + lambda1_.data(), + lambda2_.data(), + lambda2_z_.data(), + k_eps, + buffer.data(), + output_data); + } + // Truncate to original size. 
+ nonzeros.resize(n); + zeros.resize(D - n); + } + for (; i < N; i++, self_data += D, output_data += D) { + box_cox_mixed_lambda( + self_data, + nonzeros, + zeros, + lambda1_.data(), + lambda2_.data(), + lambda2_z_.data(), + k_eps, + buffer.data(), + output_data); + } + } +} +} // namespace + +} // namespace caffe2::details diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/common.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/common.h new file mode 100644 index 0000000000000000000000000000000000000000..6e069861b28d21d24893c188b35e06c7a093f5bd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/common.h @@ -0,0 +1,140 @@ +// !!!! PLEASE READ !!!! +// Minimize (transitively) included headers from _avx*.cc because some of the +// functions defined in the headers compiled with platform dependent compiler +// options can be reused by other translation units generating illegal +// instruction run-time error. + +// Common utilities for writing performance kernels and easy dispatching of +// different backends. +/* +The general workflow shall be as follows, say we want to +implement a functionality called void foo(int a, float b). + +In foo.h, do: + void foo(int a, float b); + +In foo_avx512.cc, do: + void foo__avx512(int a, float b) { + [actual avx512 implementation] + } + +In foo_avx2.cc, do: + void foo__avx2(int a, float b) { + [actual avx2 implementation] + } + +In foo_avx.cc, do: + void foo__avx(int a, float b) { + [actual avx implementation] + } + +In foo.cc, do: + // The base implementation should *always* be provided. + void foo__base(int a, float b) { + [base, possibly slow implementation] + } + decltype(foo__base) foo__avx512; + decltype(foo__base) foo__avx2; + decltype(foo__base) foo__avx; + void foo(int a, float b) { + // You should always order things by their preference, faster + // implementations earlier in the function. 
+ AVX512_DO(foo, a, b); + AVX2_DO(foo, a, b); + AVX_DO(foo, a, b); + BASE_DO(foo, a, b); + } + +*/ +// Details: this functionality basically covers the cases for both build time +// and run time architecture support. +// +// During build time: +// The build system should provide flags CAFFE2_PERF_WITH_AVX512, +// CAFFE2_PERF_WITH_AVX2, and CAFFE2_PERF_WITH_AVX that corresponds to the +// __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__, and __AVX__ flags the +// compiler provides. Note that we do not use the compiler flags but rely on +// the build system flags, because the common files (like foo.cc above) will +// always be built without __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__ +// and __AVX__. +// During run time: +// we use cpuinfo to identify cpu support and run the proper functions. + +#pragma once +#if defined(CAFFE2_PERF_WITH_SVE) || defined(CAFFE2_PERF_WITH_AVX512) || \ + defined(CAFFE2_PERF_WITH_AVX2) || defined(CAFFE2_PERF_WITH_AVX) +#include +#endif + +// DO macros: these should be used in your entry function, similar to foo() +// above, that routes implementations based on CPU capability. + +#define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__); + +#ifdef CAFFE2_PERF_WITH_SVE +#define SVE_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_arm_sve(); \ + if (isDo) { \ + return funcname##__sve(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_SVE +#define SVE_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_SVE + +#ifdef CAFFE2_PERF_WITH_AVX512 +#define AVX512_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && \ + cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512dq() && \ + cpuinfo_has_x86_avx512vl(); \ + if (isDo) { \ + return funcname##__avx512(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_AVX512 +#define AVX512_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_AVX512 + +#ifdef CAFFE2_PERF_WITH_AVX2 +#define AVX2_DO(funcname, ...) 
\ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx2(); \ + if (isDo) { \ + return funcname##__avx2(__VA_ARGS__); \ + } \ + } +#define AVX2_FMA_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx2() && \ + cpuinfo_has_x86_fma3(); \ + if (isDo) { \ + return funcname##__avx2_fma(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_AVX2 +#define AVX2_DO(funcname, ...) +#define AVX2_FMA_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_AVX2 + +#ifdef CAFFE2_PERF_WITH_AVX +#define AVX_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx(); \ + if (isDo) { \ + return funcname##__avx(__VA_ARGS__); \ + } \ + } +#define AVX_F16C_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx() && \ + cpuinfo_has_x86_f16c(); \ + if (isDo) { \ + return funcname##__avx_f16c(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_AVX +#define AVX_DO(funcname, ...) +#define AVX_F16C_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_AVX diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/embedding_lookup_idx.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/embedding_lookup_idx.h new file mode 100644 index 0000000000000000000000000000000000000000..67573fb21fa30bb6403e03b6ca13338c28c4fd7f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/embedding_lookup_idx.h @@ -0,0 +1,57 @@ +#pragma once + +#include + +namespace caffe2 { + +// clang-format off +/** + * Embedding lookup with reduction. 
+ * + * `input` of size data_size * block_size + * `indices` of size index_size + * `offsets` of size output_size + * `weights` nullptr or array of size index_size + * `out` of size output_size * block_size + * + * Behavior is roughly equivalent to pseudocode: + * + * pos = 0 + * for (i = 0..output_size-1) + * for (k = 0..block_size-1) + * out[i*block_size + k] = 0 + * start_offset = offsets[i] + * end_offset = offsets[i+1] + * length = end_offset - start_offset + * for (j = start_offset..end_offset-1) + * for (k = 0..block_size-1) + * out[i*block_size + k] += input[indices[pos]*block_size + k] * + * (weights ? weights[IS_WEIGHT_POSITIONAL ? j - start_offset : pos] : 1.0) + * pos += 1 + * if (normalize_weights && length > 0) + * for (k = 0..block_size-1) + * out[i*block_size + k] /= length + * + * TODO: make this API also take "offsets" rather than "lengths" to match the + * API for PyTorch's EmbeddingBag + */ +// clang-format on +template < + typename IndexType, + typename InType, + typename OutType, + bool IS_WEIGHT_POSITIONAL = false> +void EmbeddingLookupIdx( + const std::int64_t block_size, + const std::int64_t output_size, + const std::int64_t index_size, + const std::int64_t data_size, + const InType* input, + const IndexType* indices, + const IndexType* offsets, + const float* weights, // optional, can be null for non-weighted sum + const float* scale_bias, // optional scale & bias params for uint8 input + bool normalize_by_lengths, + OutType* out); + +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/crc_alt.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/crc_alt.h new file mode 100644 index 0000000000000000000000000000000000000000..9d1c4f1dc7ddc8997f7cc1297ef20d74de67afe0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/crc_alt.h @@ -0,0 +1,1343 @@ +#pragma once + +// ////////////////////////////////////////////////////////// +// Crc32.h +// 
Copyright (c) 2011-2019 Stephan Brumme. All rights reserved. +// Slicing-by-16 contributed by Bulat Ziganshin +// Tableless bytewise CRC contributed by Hagai Gold +// see http://create.stephan-brumme.com/disclaimer.html +// + +// if running on an embedded system, you might consider shrinking the +// big Crc32Lookup table by undefining these lines: +#define CRC32_USE_LOOKUP_TABLE_BYTE +#define CRC32_USE_LOOKUP_TABLE_SLICING_BY_4 +#define CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 +#define CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +// - crc32_bitwise doesn't need it at all +// - crc32_halfbyte has its own small lookup table +// - crc32_1byte_tableless and crc32_1byte_tableless2 don't need it at all +// - crc32_1byte needs only Crc32Lookup[0] +// - crc32_4bytes needs only Crc32Lookup[0..3] +// - crc32_8bytes needs only Crc32Lookup[0..7] +// - crc32_4x8bytes needs only Crc32Lookup[0..7] +// - crc32_16bytes needs all of Crc32Lookup +// using the aforementioned #defines the table is automatically fitted to your needs + +// uint8_t, uint32_t, int32_t +#include +// size_t +#include + +// crc32_fast selects the fastest algorithm depending on flags (CRC32_USE_LOOKUP_...) 
+/// compute CRC32 using the fastest algorithm for large datasets on modern CPUs +uint32_t crc32_fast (const void* data, size_t length, uint32_t previousCrc32 = 0); + +/// merge two CRC32 such that result = crc32(dataB, lengthB, crc32(dataA, lengthA)) +uint32_t crc32_combine (uint32_t crcA, uint32_t crcB, size_t lengthB); + +/// compute CRC32 (bitwise algorithm) +uint32_t crc32_bitwise (const void* data, size_t length, uint32_t previousCrc32 = 0); +/// compute CRC32 (half-byte algoritm) +uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32 = 0); + +#ifdef CRC32_USE_LOOKUP_TABLE_BYTE +/// compute CRC32 (standard algorithm) +uint32_t crc32_1byte (const void* data, size_t length, uint32_t previousCrc32 = 0); +#endif + +/// compute CRC32 (byte algorithm) without lookup tables +uint32_t crc32_1byte_tableless (const void* data, size_t length, uint32_t previousCrc32 = 0); +/// compute CRC32 (byte algorithm) without lookup tables +uint32_t crc32_1byte_tableless2(const void* data, size_t length, uint32_t previousCrc32 = 0); + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_4 +/// compute CRC32 (Slicing-by-4 algorithm) +uint32_t crc32_4bytes (const void* data, size_t length, uint32_t previousCrc32 = 0); +#endif + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 +/// compute CRC32 (Slicing-by-8 algorithm) +uint32_t crc32_8bytes (const void* data, size_t length, uint32_t previousCrc32 = 0); +/// compute CRC32 (Slicing-by-8 algorithm), unroll inner loop 4 times +uint32_t crc32_4x8bytes(const void* data, size_t length, uint32_t previousCrc32 = 0); +#endif + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +/// compute CRC32 (Slicing-by-16 algorithm) +uint32_t crc32_16bytes (const void* data, size_t length, uint32_t previousCrc32 = 0); +/// compute CRC32 (Slicing-by-16 algorithm, prefetch upcoming data blocks) +uint32_t crc32_16bytes_prefetch(const void* data, size_t length, uint32_t previousCrc32 = 0, size_t prefetchAhead = 256); +#endif + +// 
////////////////////////////////////////////////////////// +// Crc32.cpp +// Copyright (c) 2011-2019 Stephan Brumme. All rights reserved. +// Slicing-by-16 contributed by Bulat Ziganshin +// Tableless bytewise CRC contributed by Hagai Gold +// see http://create.stephan-brumme.com/disclaimer.html +// + +// if running on an embedded system, you might consider shrinking the +// big Crc32Lookup table: +// - crc32_bitwise doesn't need it at all +// - crc32_halfbyte has its own small lookup table +// - crc32_1byte needs only Crc32Lookup[0] +// - crc32_4bytes needs only Crc32Lookup[0..3] +// - crc32_8bytes needs only Crc32Lookup[0..7] +// - crc32_4x8bytes needs only Crc32Lookup[0..7] +// - crc32_16bytes needs all of Crc32Lookup + + +#ifndef __LITTLE_ENDIAN + #define __LITTLE_ENDIAN 1234 +#endif +#ifndef __BIG_ENDIAN + #define __BIG_ENDIAN 4321 +#endif + +// define endianess and some integer data types +#if defined(_MSC_VER) || defined(__MINGW32__) + // Windows always little endian + #define __BYTE_ORDER __LITTLE_ENDIAN + + // intrinsics / prefetching + #if defined(_M_ARM64) + #include + #else + #include + #endif + + #ifdef __MINGW32__ + #define PREFETCH(location) __builtin_prefetch(location) + #else + #if defined(_M_ARM64) + #define PREFETCH(location) __prefetch(location) + #else + #define PREFETCH(location) _mm_prefetch(location, _MM_HINT_T0) + #endif + #endif +#elif defined(__APPLE__) + #include + #if TARGET_IPHONE_SIMULATOR + #define __BYTE_ORDER __LITTLE_ENDIAN + #elif TARGET_OS_IPHONE + #define __BYTE_ORDER __LITTLE_ENDIAN + #elif TARGET_OS_MAC + #include + #if defined(__BIG_ENDIAN__) + #define __BYTE_ORDER __BIG_ENDIAN + #endif + #if defined(__LITTLE_ENDIAN__) + #define __BYTE_ORDER __LITTLE_ENDIAN + #endif + #else + # error "Unknown Apple platform" + #endif +#elif defined(__ARMEB__) + #define __BYTE_ORDER __BIG_ENDIAN +#elif (defined(__BYTE_ORDER__) and !defined(__BYTE_ORDER)) + #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + #define __BYTE_ORDER __BIG_ENDIAN + #else 
+ #define __BYTE_ORDER __LITTLE_ENDIAN + #endif +#else + // defines __BYTE_ORDER as __LITTLE_ENDIAN or __BIG_ENDIAN + #include +#endif + +// intrinsics / prefetching +#ifdef __GNUC__ + #define PREFETCH(location) __builtin_prefetch(location) +#else +#ifndef PREFETCH + // no prefetching + #define PREFETCH(location) ; +#endif +#endif + +// abort if byte order is undefined +#ifndef __BYTE_ORDER +#error undefined byte order, compile with -D__BYTE_ORDER=1234 (if little endian) or -D__BYTE_ORDER=4321 (big endian) +#endif + + +namespace +{ + /// zlib's CRC32 polynomial + const uint32_t Polynomial = 0xEDB88320; + + /// swap endianess + static inline uint32_t swap(uint32_t x) + { + #if defined(__GNUC__) || defined(__clang__) + return __builtin_bswap32(x); + #else + return (x >> 24) | + ((x >> 8) & 0x0000FF00) | + ((x << 8) & 0x00FF0000) | + (x << 24); + #endif + } + + /// Slicing-By-16 + #ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 + const size_t MaxSlice = 16; + #elif defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) + const size_t MaxSlice = 8; + #elif defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_4) + const size_t MaxSlice = 4; + #elif defined(CRC32_USE_LOOKUP_TABLE_BYTE) + const size_t MaxSlice = 1; + #else + #define NO_LUT // don't need Crc32Lookup at all + #endif + +} // anonymous namespace + +#ifndef NO_LUT +/// forward declaration, table is at the end of this file +extern const uint32_t Crc32Lookup[MaxSlice][256]; // extern is needed to keep compiler happy +#endif + + +/// compute CRC32 (bitwise algorithm) +uint32_t crc32_bitwise(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint8_t* current = (const uint8_t*) data; + + while (length-- != 0) + { + crc ^= *current++; + + for (int j = 0; j < 8; j++) + { + // branch-free + crc = (crc >> 1) ^ (-int32_t(crc & 1) & Polynomial); + + // branching, much slower: + //if (crc & 1) + // crc = (crc >> 1) ^ Polynomial; + //else + // crc = crc >> 1; + 
} + } + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +/// compute CRC32 (half-byte algoritm) +uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint8_t* current = (const uint8_t*) data; + + /// look-up table for half-byte, same as crc32Lookup[0][16*i] + static const uint32_t Crc32Lookup16[16] = + { + 0x00000000,0x1DB71064,0x3B6E20C8,0x26D930AC,0x76DC4190,0x6B6B51F4,0x4DB26158,0x5005713C, + 0xEDB88320,0xF00F9344,0xD6D6A3E8,0xCB61B38C,0x9B64C2B0,0x86D3D2D4,0xA00AE278,0xBDBDF21C + }; + + while (length-- != 0) + { + crc = Crc32Lookup16[(crc ^ *current ) & 0x0F] ^ (crc >> 4); + crc = Crc32Lookup16[(crc ^ (*current >> 4)) & 0x0F] ^ (crc >> 4); + current++; + } + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +#ifdef CRC32_USE_LOOKUP_TABLE_BYTE +/// compute CRC32 (standard algorithm) +uint32_t crc32_1byte(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint8_t* current = (const uint8_t*) data; + + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *current++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} +#endif + + +/// compute CRC32 (byte algorithm) without lookup tables +uint32_t crc32_1byte_tableless(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint8_t* current = (const uint8_t*) data; + + while (length-- != 0) + { + uint8_t s = uint8_t(crc) ^ *current++; + + // Hagai Gold made me aware of this table-less algorithm and send me code + + // polynomial 0xEDB88320 can be written in binary as 11101101101110001000001100100000b + // reverse the bits (or just assume bit 0 is the first one) + // and we have bits set at position 0, 1, 2, 4, 5, 7, 8, 10, 11, 12, 16, 22, 23, 26 + // => those are the shift offsets: + //crc = (crc >> 8) ^ + // t ^ + // (t 
>> 1) ^ (t >> 2) ^ (t >> 4) ^ (t >> 5) ^ // == y + // (t >> 7) ^ (t >> 8) ^ (t >> 10) ^ (t >> 11) ^ // == y >> 6 + // (t >> 12) ^ (t >> 16) ^ // == z + // (t >> 22) ^ (t >> 26) ^ // == z >> 10 + // (t >> 23); + + // the fastest I can come up with: + uint32_t low = (s ^ (s << 6)) & 0xFF; + uint32_t a = (low * ((1 << 23) + (1 << 14) + (1 << 2))); + crc = (crc >> 8) ^ + (low * ((1 << 24) + (1 << 16) + (1 << 8))) ^ + a ^ + (a >> 1) ^ + (low * ((1 << 20) + (1 << 12) )) ^ + (low << 19) ^ + (low << 17) ^ + (low >> 2); + + // Hagai's code: + /*uint32_t t = (s ^ (s << 6)) << 24; + // some temporaries to optimize XOR + uint32_t x = (t >> 1) ^ (t >> 2); + uint32_t y = x ^ (x >> 3); + uint32_t z = (t >> 12) ^ (t >> 16); + crc = (crc >> 8) ^ + t ^ (t >> 23) ^ + y ^ (y >> 6) ^ + z ^ (z >> 10);*/ + } + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +/// compute CRC32 (byte algorithm) without lookup tables +uint32_t crc32_1byte_tableless2(const void* data, size_t length, uint32_t previousCrc32) +{ + int32_t crc = ~previousCrc32; // note: signed integer, right shift distributes sign bit into lower bits + const uint8_t* current = (const uint8_t*) data; + + while (length-- != 0) + { + crc = crc ^ *current++; + + uint32_t c = (((crc << 31) >> 31) & ((Polynomial >> 7) ^ (Polynomial >> 1))) ^ + (((crc << 30) >> 31) & ((Polynomial >> 6) ^ Polynomial)) ^ + (((crc << 29) >> 31) & (Polynomial >> 5)) ^ + (((crc << 28) >> 31) & (Polynomial >> 4)) ^ + (((crc << 27) >> 31) & (Polynomial >> 3)) ^ + (((crc << 26) >> 31) & (Polynomial >> 2)) ^ + (((crc << 25) >> 31) & (Polynomial >> 1)) ^ + (((crc << 24) >> 31) & Polynomial); + + crc = ((uint32_t)crc >> 8) ^ c; // convert to unsigned integer before right shift + } + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_4 +/// compute CRC32 (Slicing-by-4 algorithm) +uint32_t crc32_4bytes(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 
0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // process four bytes at once (Slicing-by-4) + while (length >= 4) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + crc = Crc32Lookup[0][ one & 0xFF] ^ + Crc32Lookup[1][(one>> 8) & 0xFF] ^ + Crc32Lookup[2][(one>>16) & 0xFF] ^ + Crc32Lookup[3][(one>>24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + crc = Crc32Lookup[0][(one>>24) & 0xFF] ^ + Crc32Lookup[1][(one>>16) & 0xFF] ^ + Crc32Lookup[2][(one>> 8) & 0xFF] ^ + Crc32Lookup[3][ one & 0xFF]; +#endif + + length -= 4; + } + + const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 3 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} +#endif + + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 +/// compute CRC32 (Slicing-by-8 algorithm) +uint32_t crc32_8bytes(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // process eight bytes at once (Slicing-by-8) + while (length >= 8) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + uint32_t two = *current++; + crc = Crc32Lookup[0][ two & 0xFF] ^ + Crc32Lookup[1][(two>> 8) & 0xFF] ^ + Crc32Lookup[2][(two>>16) & 0xFF] ^ + Crc32Lookup[3][(two>>24) & 0xFF] ^ + Crc32Lookup[4][ one & 0xFF] ^ + Crc32Lookup[5][(one>> 8) & 0xFF] ^ + Crc32Lookup[6][(one>>16) & 0xFF] ^ + Crc32Lookup[7][(one>>24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + uint32_t two = *current++; + crc = Crc32Lookup[0][(two>>24) & 0xFF] ^ + Crc32Lookup[1][(two>>16) & 0xFF] ^ + Crc32Lookup[2][(two>> 8) & 0xFF] ^ + Crc32Lookup[3][ two & 0xFF] ^ + Crc32Lookup[4][(one>>24) & 0xFF] ^ + Crc32Lookup[5][(one>>16) & 0xFF] ^ + Crc32Lookup[6][(one>> 8) & 0xFF] ^ + Crc32Lookup[7][ one & 0xFF]; +#endif + + length -= 8; + } + + 
const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 7 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +/// compute CRC32 (Slicing-by-8 algorithm), unroll inner loop 4 times +uint32_t crc32_4x8bytes(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // enabling optimization (at least -O2) automatically unrolls the inner for-loop + const size_t Unroll = 4; + const size_t BytesAtOnce = 8 * Unroll; + + // process 4x eight bytes at once (Slicing-by-8) + while (length >= BytesAtOnce) + { + for (size_t unrolling = 0; unrolling < Unroll; unrolling++) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + uint32_t two = *current++; + crc = Crc32Lookup[0][ two & 0xFF] ^ + Crc32Lookup[1][(two>> 8) & 0xFF] ^ + Crc32Lookup[2][(two>>16) & 0xFF] ^ + Crc32Lookup[3][(two>>24) & 0xFF] ^ + Crc32Lookup[4][ one & 0xFF] ^ + Crc32Lookup[5][(one>> 8) & 0xFF] ^ + Crc32Lookup[6][(one>>16) & 0xFF] ^ + Crc32Lookup[7][(one>>24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + uint32_t two = *current++; + crc = Crc32Lookup[0][(two>>24) & 0xFF] ^ + Crc32Lookup[1][(two>>16) & 0xFF] ^ + Crc32Lookup[2][(two>> 8) & 0xFF] ^ + Crc32Lookup[3][ two & 0xFF] ^ + Crc32Lookup[4][(one>>24) & 0xFF] ^ + Crc32Lookup[5][(one>>16) & 0xFF] ^ + Crc32Lookup[6][(one>> 8) & 0xFF] ^ + Crc32Lookup[7][ one & 0xFF]; +#endif + + } + + length -= BytesAtOnce; + } + + const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 31 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} +#endif // CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 + + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +/// compute 
CRC32 (Slicing-by-16 algorithm) +uint32_t crc32_16bytes(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // enabling optimization (at least -O2) automatically unrolls the inner for-loop + const size_t Unroll = 4; + const size_t BytesAtOnce = 16 * Unroll; + + while (length >= BytesAtOnce) + { + for (size_t unrolling = 0; unrolling < Unroll; unrolling++) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + uint32_t two = *current++; + uint32_t three = *current++; + uint32_t four = *current++; + crc = Crc32Lookup[ 0][ four & 0xFF] ^ + Crc32Lookup[ 1][(four >> 8) & 0xFF] ^ + Crc32Lookup[ 2][(four >> 16) & 0xFF] ^ + Crc32Lookup[ 3][(four >> 24) & 0xFF] ^ + Crc32Lookup[ 4][ three & 0xFF] ^ + Crc32Lookup[ 5][(three >> 8) & 0xFF] ^ + Crc32Lookup[ 6][(three >> 16) & 0xFF] ^ + Crc32Lookup[ 7][(three >> 24) & 0xFF] ^ + Crc32Lookup[ 8][ two & 0xFF] ^ + Crc32Lookup[ 9][(two >> 8) & 0xFF] ^ + Crc32Lookup[10][(two >> 16) & 0xFF] ^ + Crc32Lookup[11][(two >> 24) & 0xFF] ^ + Crc32Lookup[12][ one & 0xFF] ^ + Crc32Lookup[13][(one >> 8) & 0xFF] ^ + Crc32Lookup[14][(one >> 16) & 0xFF] ^ + Crc32Lookup[15][(one >> 24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + uint32_t two = *current++; + uint32_t three = *current++; + uint32_t four = *current++; + crc = Crc32Lookup[ 0][(four >> 24) & 0xFF] ^ + Crc32Lookup[ 1][(four >> 16) & 0xFF] ^ + Crc32Lookup[ 2][(four >> 8) & 0xFF] ^ + Crc32Lookup[ 3][ four & 0xFF] ^ + Crc32Lookup[ 4][(three >> 24) & 0xFF] ^ + Crc32Lookup[ 5][(three >> 16) & 0xFF] ^ + Crc32Lookup[ 6][(three >> 8) & 0xFF] ^ + Crc32Lookup[ 7][ three & 0xFF] ^ + Crc32Lookup[ 8][(two >> 24) & 0xFF] ^ + Crc32Lookup[ 9][(two >> 16) & 0xFF] ^ + Crc32Lookup[10][(two >> 8) & 0xFF] ^ + Crc32Lookup[11][ two & 0xFF] ^ + Crc32Lookup[12][(one >> 24) & 0xFF] ^ + Crc32Lookup[13][(one >> 16) & 0xFF] ^ + Crc32Lookup[14][(one >> 
8) & 0xFF] ^ + Crc32Lookup[15][ one & 0xFF]; +#endif + } + + length -= BytesAtOnce; + } + + const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 63 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +/// compute CRC32 (Slicing-by-16 algorithm, prefetch upcoming data blocks) +uint32_t crc32_16bytes_prefetch(const void* data, size_t length, uint32_t previousCrc32, size_t prefetchAhead) +{ + // CRC code is identical to crc32_16bytes (including unrolling), only added prefetching + // 256 bytes look-ahead seems to be the sweet spot on Core i7 CPUs + + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // enabling optimization (at least -O2) automatically unrolls the for-loop + const size_t Unroll = 4; + const size_t BytesAtOnce = 16 * Unroll; + + while (length >= BytesAtOnce + prefetchAhead) + { + PREFETCH(((const char*) current) + prefetchAhead); + + for (size_t unrolling = 0; unrolling < Unroll; unrolling++) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + uint32_t two = *current++; + uint32_t three = *current++; + uint32_t four = *current++; + crc = Crc32Lookup[ 0][ four & 0xFF] ^ + Crc32Lookup[ 1][(four >> 8) & 0xFF] ^ + Crc32Lookup[ 2][(four >> 16) & 0xFF] ^ + Crc32Lookup[ 3][(four >> 24) & 0xFF] ^ + Crc32Lookup[ 4][ three & 0xFF] ^ + Crc32Lookup[ 5][(three >> 8) & 0xFF] ^ + Crc32Lookup[ 6][(three >> 16) & 0xFF] ^ + Crc32Lookup[ 7][(three >> 24) & 0xFF] ^ + Crc32Lookup[ 8][ two & 0xFF] ^ + Crc32Lookup[ 9][(two >> 8) & 0xFF] ^ + Crc32Lookup[10][(two >> 16) & 0xFF] ^ + Crc32Lookup[11][(two >> 24) & 0xFF] ^ + Crc32Lookup[12][ one & 0xFF] ^ + Crc32Lookup[13][(one >> 8) & 0xFF] ^ + Crc32Lookup[14][(one >> 16) & 0xFF] ^ + Crc32Lookup[15][(one >> 24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + uint32_t two = *current++; 
+ uint32_t three = *current++; + uint32_t four = *current++; + crc = Crc32Lookup[ 0][(four >> 24) & 0xFF] ^ + Crc32Lookup[ 1][(four >> 16) & 0xFF] ^ + Crc32Lookup[ 2][(four >> 8) & 0xFF] ^ + Crc32Lookup[ 3][ four & 0xFF] ^ + Crc32Lookup[ 4][(three >> 24) & 0xFF] ^ + Crc32Lookup[ 5][(three >> 16) & 0xFF] ^ + Crc32Lookup[ 6][(three >> 8) & 0xFF] ^ + Crc32Lookup[ 7][ three & 0xFF] ^ + Crc32Lookup[ 8][(two >> 24) & 0xFF] ^ + Crc32Lookup[ 9][(two >> 16) & 0xFF] ^ + Crc32Lookup[10][(two >> 8) & 0xFF] ^ + Crc32Lookup[11][ two & 0xFF] ^ + Crc32Lookup[12][(one >> 24) & 0xFF] ^ + Crc32Lookup[13][(one >> 16) & 0xFF] ^ + Crc32Lookup[14][(one >> 8) & 0xFF] ^ + Crc32Lookup[15][ one & 0xFF]; +#endif + } + + length -= BytesAtOnce; + } + + const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 63 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} +#endif + + +/// compute CRC32 using the fastest algorithm for large datasets on modern CPUs +uint32_t crc32_fast(const void* data, size_t length, uint32_t previousCrc32) +{ +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 + return crc32_16bytes (data, length, previousCrc32); +#elif defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) + return crc32_8bytes (data, length, previousCrc32); +#elif defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_4) + return crc32_4bytes (data, length, previousCrc32); +#elif defined(CRC32_USE_LOOKUP_TABLE_BYTE) + return crc32_1byte (data, length, previousCrc32); +#else + return crc32_halfbyte(data, length, previousCrc32); +#endif +} + + +/// merge two CRC32 such that result = crc32(dataB, lengthB, crc32(dataA, lengthA)) +uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB) +{ + // based on Mark Adler's crc_combine from + // https://github.com/madler/pigz/blob/master/pigz.c + + // main idea: + // - if you have two equally-sized blocks A and B, + // then you can create a block C = A 
^ B + // which has the property crc(C) = crc(A) ^ crc(B) + // - if you append length(B) zeros to A and call it A' (think of it as AAAA000) + // and prepend length(A) zeros to B and call it B' (think of it as 0000BBB) + // then exists a C' = A' ^ B' + // - remember: if you XOR someting with zero, it remains unchanged: X ^ 0 = X + // - that means C' = A concat B so that crc(A concat B) = crc(C') = crc(A') ^ crc(B') + // - the trick is to compute crc(A') based on crc(A) + // and crc(B') based on crc(B) + // - since B' starts with many zeros, the crc of those initial zeros is still zero + // - that means crc(B') = crc(B) + // - unfortunately the trailing zeros of A' change the crc, so usually crc(A') != crc(A) + // - the following code is a fast algorithm to compute crc(A') + // - starting with crc(A) and appending length(B) zeros, needing just log2(length(B)) iterations + // - the details are explained by the original author at + // https://stackoverflow.com/questions/23122312/crc-calculation-of-a-mostly-static-data-stream/23126768 + // + // notes: + // - I squeezed everything into one function to keep global namespace clean (original code two helper functions) + // - most original comments are still in place, I added comments where these helper functions where made inline code + // - performance-wise there isn't any differenze to the original zlib/pigz code + + // degenerated case + if (lengthB == 0) + return crcA; + + /// CRC32 => 32 bits + const uint32_t CrcBits = 32; + + uint32_t odd [CrcBits]; // odd-power-of-two zeros operator + uint32_t even[CrcBits]; // even-power-of-two zeros operator + + // put operator for one zero bit in odd + odd[0] = Polynomial; // CRC-32 polynomial + for (uint32_t i = 1; i < CrcBits; i++) + odd[i] = 1 << (i - 1); + + // put operator for two zero bits in even + // same as gf2_matrix_square(even, odd); + for (uint32_t i = 0; i < CrcBits; i++) + { + uint32_t vec = odd[i]; + even[i] = 0; + for (int j = 0; vec != 0; j++, vec >>= 1) + if (vec 
& 1) + even[i] ^= odd[j]; + } + // put operator for four zero bits in odd + // same as gf2_matrix_square(odd, even); + for (uint32_t i = 0; i < CrcBits; i++) + { + uint32_t vec = even[i]; + odd[i] = 0; + for (int j = 0; vec != 0; j++, vec >>= 1) + if (vec & 1) + odd[i] ^= even[j]; + } + + // the following loop becomes much shorter if I keep swapping even and odd + uint32_t* a = even; + uint32_t* b = odd; + // apply secondLength zeros to firstCrc32 + for (; lengthB > 0; lengthB >>= 1) + { + // same as gf2_matrix_square(a, b); + for (uint32_t i = 0; i < CrcBits; i++) + { + uint32_t vec = b[i]; + a[i] = 0; + for (int j = 0; vec != 0; j++, vec >>= 1) + if (vec & 1) + a[i] ^= b[j]; + } + + // apply zeros operator for this bit + if (lengthB & 1) + { + // same as firstCrc32 = gf2_matrix_times(a, firstCrc32); + uint32_t sum = 0; + for (int i = 0; crcA != 0; i++, crcA >>= 1) + if (crcA & 1) + sum ^= a[i]; + crcA = sum; + } + + // switch even and odd + uint32_t* t = a; a = b; b = t; + } + + // return combined crc + return crcA ^ crcB; +} + + +// ////////////////////////////////////////////////////////// +// constants + + +#ifndef NO_LUT +/// look-up table, already declared above +const uint32_t Crc32Lookup[MaxSlice][256] = +{ + //// same algorithm as crc32_bitwise + //for (int i = 0; i <= 0xFF; i++) + //{ + // uint32_t crc = i; + // for (int j = 0; j < 8; j++) + // crc = (crc >> 1) ^ ((crc & 1) * Polynomial); + // Crc32Lookup[0][i] = crc; + //} + //// ... and the following slicing-by-8 algorithm (from Intel): + //// http://www.intel.com/technology/comms/perfnet/download/CRC_generators.pdf + //// http://sourceforge.net/projects/slicing-by-8/ + //for (int slice = 1; slice < MaxSlice; slice++) + // Crc32Lookup[slice][i] = (Crc32Lookup[slice - 1][i] >> 8) ^ Crc32Lookup[0][Crc32Lookup[slice - 1][i] & 0xFF]; + { + // note: the first number of every second row corresponds to the half-byte look-up table ! 
+ 0x00000000,0x77073096,0xEE0E612C,0x990951BA,0x076DC419,0x706AF48F,0xE963A535,0x9E6495A3, + 0x0EDB8832,0x79DCB8A4,0xE0D5E91E,0x97D2D988,0x09B64C2B,0x7EB17CBD,0xE7B82D07,0x90BF1D91, + 0x1DB71064,0x6AB020F2,0xF3B97148,0x84BE41DE,0x1ADAD47D,0x6DDDE4EB,0xF4D4B551,0x83D385C7, + 0x136C9856,0x646BA8C0,0xFD62F97A,0x8A65C9EC,0x14015C4F,0x63066CD9,0xFA0F3D63,0x8D080DF5, + 0x3B6E20C8,0x4C69105E,0xD56041E4,0xA2677172,0x3C03E4D1,0x4B04D447,0xD20D85FD,0xA50AB56B, + 0x35B5A8FA,0x42B2986C,0xDBBBC9D6,0xACBCF940,0x32D86CE3,0x45DF5C75,0xDCD60DCF,0xABD13D59, + 0x26D930AC,0x51DE003A,0xC8D75180,0xBFD06116,0x21B4F4B5,0x56B3C423,0xCFBA9599,0xB8BDA50F, + 0x2802B89E,0x5F058808,0xC60CD9B2,0xB10BE924,0x2F6F7C87,0x58684C11,0xC1611DAB,0xB6662D3D, + 0x76DC4190,0x01DB7106,0x98D220BC,0xEFD5102A,0x71B18589,0x06B6B51F,0x9FBFE4A5,0xE8B8D433, + 0x7807C9A2,0x0F00F934,0x9609A88E,0xE10E9818,0x7F6A0DBB,0x086D3D2D,0x91646C97,0xE6635C01, + 0x6B6B51F4,0x1C6C6162,0x856530D8,0xF262004E,0x6C0695ED,0x1B01A57B,0x8208F4C1,0xF50FC457, + 0x65B0D9C6,0x12B7E950,0x8BBEB8EA,0xFCB9887C,0x62DD1DDF,0x15DA2D49,0x8CD37CF3,0xFBD44C65, + 0x4DB26158,0x3AB551CE,0xA3BC0074,0xD4BB30E2,0x4ADFA541,0x3DD895D7,0xA4D1C46D,0xD3D6F4FB, + 0x4369E96A,0x346ED9FC,0xAD678846,0xDA60B8D0,0x44042D73,0x33031DE5,0xAA0A4C5F,0xDD0D7CC9, + 0x5005713C,0x270241AA,0xBE0B1010,0xC90C2086,0x5768B525,0x206F85B3,0xB966D409,0xCE61E49F, + 0x5EDEF90E,0x29D9C998,0xB0D09822,0xC7D7A8B4,0x59B33D17,0x2EB40D81,0xB7BD5C3B,0xC0BA6CAD, + 0xEDB88320,0x9ABFB3B6,0x03B6E20C,0x74B1D29A,0xEAD54739,0x9DD277AF,0x04DB2615,0x73DC1683, + 0xE3630B12,0x94643B84,0x0D6D6A3E,0x7A6A5AA8,0xE40ECF0B,0x9309FF9D,0x0A00AE27,0x7D079EB1, + 0xF00F9344,0x8708A3D2,0x1E01F268,0x6906C2FE,0xF762575D,0x806567CB,0x196C3671,0x6E6B06E7, + 0xFED41B76,0x89D32BE0,0x10DA7A5A,0x67DD4ACC,0xF9B9DF6F,0x8EBEEFF9,0x17B7BE43,0x60B08ED5, + 0xD6D6A3E8,0xA1D1937E,0x38D8C2C4,0x4FDFF252,0xD1BB67F1,0xA6BC5767,0x3FB506DD,0x48B2364B, + 
0xD80D2BDA,0xAF0A1B4C,0x36034AF6,0x41047A60,0xDF60EFC3,0xA867DF55,0x316E8EEF,0x4669BE79, + 0xCB61B38C,0xBC66831A,0x256FD2A0,0x5268E236,0xCC0C7795,0xBB0B4703,0x220216B9,0x5505262F, + 0xC5BA3BBE,0xB2BD0B28,0x2BB45A92,0x5CB36A04,0xC2D7FFA7,0xB5D0CF31,0x2CD99E8B,0x5BDEAE1D, + 0x9B64C2B0,0xEC63F226,0x756AA39C,0x026D930A,0x9C0906A9,0xEB0E363F,0x72076785,0x05005713, + 0x95BF4A82,0xE2B87A14,0x7BB12BAE,0x0CB61B38,0x92D28E9B,0xE5D5BE0D,0x7CDCEFB7,0x0BDBDF21, + 0x86D3D2D4,0xF1D4E242,0x68DDB3F8,0x1FDA836E,0x81BE16CD,0xF6B9265B,0x6FB077E1,0x18B74777, + 0x88085AE6,0xFF0F6A70,0x66063BCA,0x11010B5C,0x8F659EFF,0xF862AE69,0x616BFFD3,0x166CCF45, + 0xA00AE278,0xD70DD2EE,0x4E048354,0x3903B3C2,0xA7672661,0xD06016F7,0x4969474D,0x3E6E77DB, + 0xAED16A4A,0xD9D65ADC,0x40DF0B66,0x37D83BF0,0xA9BCAE53,0xDEBB9EC5,0x47B2CF7F,0x30B5FFE9, + 0xBDBDF21C,0xCABAC28A,0x53B39330,0x24B4A3A6,0xBAD03605,0xCDD70693,0x54DE5729,0x23D967BF, + 0xB3667A2E,0xC4614AB8,0x5D681B02,0x2A6F2B94,0xB40BBE37,0xC30C8EA1,0x5A05DF1B,0x2D02EF8D, + } + +#if defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_4) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_16) + // beyond this point only relevant for Slicing-by-4, Slicing-by-8 and Slicing-by-16 + ,{ + 0x00000000,0x191B3141,0x32366282,0x2B2D53C3,0x646CC504,0x7D77F445,0x565AA786,0x4F4196C7, + 0xC8D98A08,0xD1C2BB49,0xFAEFE88A,0xE3F4D9CB,0xACB54F0C,0xB5AE7E4D,0x9E832D8E,0x87981CCF, + 0x4AC21251,0x53D92310,0x78F470D3,0x61EF4192,0x2EAED755,0x37B5E614,0x1C98B5D7,0x05838496, + 0x821B9859,0x9B00A918,0xB02DFADB,0xA936CB9A,0xE6775D5D,0xFF6C6C1C,0xD4413FDF,0xCD5A0E9E, + 0x958424A2,0x8C9F15E3,0xA7B24620,0xBEA97761,0xF1E8E1A6,0xE8F3D0E7,0xC3DE8324,0xDAC5B265, + 0x5D5DAEAA,0x44469FEB,0x6F6BCC28,0x7670FD69,0x39316BAE,0x202A5AEF,0x0B07092C,0x121C386D, + 0xDF4636F3,0xC65D07B2,0xED705471,0xF46B6530,0xBB2AF3F7,0xA231C2B6,0x891C9175,0x9007A034, + 0x179FBCFB,0x0E848DBA,0x25A9DE79,0x3CB2EF38,0x73F379FF,0x6AE848BE,0x41C51B7D,0x58DE2A3C, + 
0xF0794F05,0xE9627E44,0xC24F2D87,0xDB541CC6,0x94158A01,0x8D0EBB40,0xA623E883,0xBF38D9C2, + 0x38A0C50D,0x21BBF44C,0x0A96A78F,0x138D96CE,0x5CCC0009,0x45D73148,0x6EFA628B,0x77E153CA, + 0xBABB5D54,0xA3A06C15,0x888D3FD6,0x91960E97,0xDED79850,0xC7CCA911,0xECE1FAD2,0xF5FACB93, + 0x7262D75C,0x6B79E61D,0x4054B5DE,0x594F849F,0x160E1258,0x0F152319,0x243870DA,0x3D23419B, + 0x65FD6BA7,0x7CE65AE6,0x57CB0925,0x4ED03864,0x0191AEA3,0x188A9FE2,0x33A7CC21,0x2ABCFD60, + 0xAD24E1AF,0xB43FD0EE,0x9F12832D,0x8609B26C,0xC94824AB,0xD05315EA,0xFB7E4629,0xE2657768, + 0x2F3F79F6,0x362448B7,0x1D091B74,0x04122A35,0x4B53BCF2,0x52488DB3,0x7965DE70,0x607EEF31, + 0xE7E6F3FE,0xFEFDC2BF,0xD5D0917C,0xCCCBA03D,0x838A36FA,0x9A9107BB,0xB1BC5478,0xA8A76539, + 0x3B83984B,0x2298A90A,0x09B5FAC9,0x10AECB88,0x5FEF5D4F,0x46F46C0E,0x6DD93FCD,0x74C20E8C, + 0xF35A1243,0xEA412302,0xC16C70C1,0xD8774180,0x9736D747,0x8E2DE606,0xA500B5C5,0xBC1B8484, + 0x71418A1A,0x685ABB5B,0x4377E898,0x5A6CD9D9,0x152D4F1E,0x0C367E5F,0x271B2D9C,0x3E001CDD, + 0xB9980012,0xA0833153,0x8BAE6290,0x92B553D1,0xDDF4C516,0xC4EFF457,0xEFC2A794,0xF6D996D5, + 0xAE07BCE9,0xB71C8DA8,0x9C31DE6B,0x852AEF2A,0xCA6B79ED,0xD37048AC,0xF85D1B6F,0xE1462A2E, + 0x66DE36E1,0x7FC507A0,0x54E85463,0x4DF36522,0x02B2F3E5,0x1BA9C2A4,0x30849167,0x299FA026, + 0xE4C5AEB8,0xFDDE9FF9,0xD6F3CC3A,0xCFE8FD7B,0x80A96BBC,0x99B25AFD,0xB29F093E,0xAB84387F, + 0x2C1C24B0,0x350715F1,0x1E2A4632,0x07317773,0x4870E1B4,0x516BD0F5,0x7A468336,0x635DB277, + 0xCBFAD74E,0xD2E1E60F,0xF9CCB5CC,0xE0D7848D,0xAF96124A,0xB68D230B,0x9DA070C8,0x84BB4189, + 0x03235D46,0x1A386C07,0x31153FC4,0x280E0E85,0x674F9842,0x7E54A903,0x5579FAC0,0x4C62CB81, + 0x8138C51F,0x9823F45E,0xB30EA79D,0xAA1596DC,0xE554001B,0xFC4F315A,0xD7626299,0xCE7953D8, + 0x49E14F17,0x50FA7E56,0x7BD72D95,0x62CC1CD4,0x2D8D8A13,0x3496BB52,0x1FBBE891,0x06A0D9D0, + 0x5E7EF3EC,0x4765C2AD,0x6C48916E,0x7553A02F,0x3A1236E8,0x230907A9,0x0824546A,0x113F652B, + 0x96A779E4,0x8FBC48A5,0xA4911B66,0xBD8A2A27,0xF2CBBCE0,0xEBD08DA1,0xC0FDDE62,0xD9E6EF23, 
+ 0x14BCE1BD,0x0DA7D0FC,0x268A833F,0x3F91B27E,0x70D024B9,0x69CB15F8,0x42E6463B,0x5BFD777A, + 0xDC656BB5,0xC57E5AF4,0xEE530937,0xF7483876,0xB809AEB1,0xA1129FF0,0x8A3FCC33,0x9324FD72, + }, + + { + 0x00000000,0x01C26A37,0x0384D46E,0x0246BE59,0x0709A8DC,0x06CBC2EB,0x048D7CB2,0x054F1685, + 0x0E1351B8,0x0FD13B8F,0x0D9785D6,0x0C55EFE1,0x091AF964,0x08D89353,0x0A9E2D0A,0x0B5C473D, + 0x1C26A370,0x1DE4C947,0x1FA2771E,0x1E601D29,0x1B2F0BAC,0x1AED619B,0x18ABDFC2,0x1969B5F5, + 0x1235F2C8,0x13F798FF,0x11B126A6,0x10734C91,0x153C5A14,0x14FE3023,0x16B88E7A,0x177AE44D, + 0x384D46E0,0x398F2CD7,0x3BC9928E,0x3A0BF8B9,0x3F44EE3C,0x3E86840B,0x3CC03A52,0x3D025065, + 0x365E1758,0x379C7D6F,0x35DAC336,0x3418A901,0x3157BF84,0x3095D5B3,0x32D36BEA,0x331101DD, + 0x246BE590,0x25A98FA7,0x27EF31FE,0x262D5BC9,0x23624D4C,0x22A0277B,0x20E69922,0x2124F315, + 0x2A78B428,0x2BBADE1F,0x29FC6046,0x283E0A71,0x2D711CF4,0x2CB376C3,0x2EF5C89A,0x2F37A2AD, + 0x709A8DC0,0x7158E7F7,0x731E59AE,0x72DC3399,0x7793251C,0x76514F2B,0x7417F172,0x75D59B45, + 0x7E89DC78,0x7F4BB64F,0x7D0D0816,0x7CCF6221,0x798074A4,0x78421E93,0x7A04A0CA,0x7BC6CAFD, + 0x6CBC2EB0,0x6D7E4487,0x6F38FADE,0x6EFA90E9,0x6BB5866C,0x6A77EC5B,0x68315202,0x69F33835, + 0x62AF7F08,0x636D153F,0x612BAB66,0x60E9C151,0x65A6D7D4,0x6464BDE3,0x662203BA,0x67E0698D, + 0x48D7CB20,0x4915A117,0x4B531F4E,0x4A917579,0x4FDE63FC,0x4E1C09CB,0x4C5AB792,0x4D98DDA5, + 0x46C49A98,0x4706F0AF,0x45404EF6,0x448224C1,0x41CD3244,0x400F5873,0x4249E62A,0x438B8C1D, + 0x54F16850,0x55330267,0x5775BC3E,0x56B7D609,0x53F8C08C,0x523AAABB,0x507C14E2,0x51BE7ED5, + 0x5AE239E8,0x5B2053DF,0x5966ED86,0x58A487B1,0x5DEB9134,0x5C29FB03,0x5E6F455A,0x5FAD2F6D, + 0xE1351B80,0xE0F771B7,0xE2B1CFEE,0xE373A5D9,0xE63CB35C,0xE7FED96B,0xE5B86732,0xE47A0D05, + 0xEF264A38,0xEEE4200F,0xECA29E56,0xED60F461,0xE82FE2E4,0xE9ED88D3,0xEBAB368A,0xEA695CBD, + 0xFD13B8F0,0xFCD1D2C7,0xFE976C9E,0xFF5506A9,0xFA1A102C,0xFBD87A1B,0xF99EC442,0xF85CAE75, + 
0xF300E948,0xF2C2837F,0xF0843D26,0xF1465711,0xF4094194,0xF5CB2BA3,0xF78D95FA,0xF64FFFCD, + 0xD9785D60,0xD8BA3757,0xDAFC890E,0xDB3EE339,0xDE71F5BC,0xDFB39F8B,0xDDF521D2,0xDC374BE5, + 0xD76B0CD8,0xD6A966EF,0xD4EFD8B6,0xD52DB281,0xD062A404,0xD1A0CE33,0xD3E6706A,0xD2241A5D, + 0xC55EFE10,0xC49C9427,0xC6DA2A7E,0xC7184049,0xC25756CC,0xC3953CFB,0xC1D382A2,0xC011E895, + 0xCB4DAFA8,0xCA8FC59F,0xC8C97BC6,0xC90B11F1,0xCC440774,0xCD866D43,0xCFC0D31A,0xCE02B92D, + 0x91AF9640,0x906DFC77,0x922B422E,0x93E92819,0x96A63E9C,0x976454AB,0x9522EAF2,0x94E080C5, + 0x9FBCC7F8,0x9E7EADCF,0x9C381396,0x9DFA79A1,0x98B56F24,0x99770513,0x9B31BB4A,0x9AF3D17D, + 0x8D893530,0x8C4B5F07,0x8E0DE15E,0x8FCF8B69,0x8A809DEC,0x8B42F7DB,0x89044982,0x88C623B5, + 0x839A6488,0x82580EBF,0x801EB0E6,0x81DCDAD1,0x8493CC54,0x8551A663,0x8717183A,0x86D5720D, + 0xA9E2D0A0,0xA820BA97,0xAA6604CE,0xABA46EF9,0xAEEB787C,0xAF29124B,0xAD6FAC12,0xACADC625, + 0xA7F18118,0xA633EB2F,0xA4755576,0xA5B73F41,0xA0F829C4,0xA13A43F3,0xA37CFDAA,0xA2BE979D, + 0xB5C473D0,0xB40619E7,0xB640A7BE,0xB782CD89,0xB2CDDB0C,0xB30FB13B,0xB1490F62,0xB08B6555, + 0xBBD72268,0xBA15485F,0xB853F606,0xB9919C31,0xBCDE8AB4,0xBD1CE083,0xBF5A5EDA,0xBE9834ED, + }, + + { + 0x00000000,0xB8BC6765,0xAA09C88B,0x12B5AFEE,0x8F629757,0x37DEF032,0x256B5FDC,0x9DD738B9, + 0xC5B428EF,0x7D084F8A,0x6FBDE064,0xD7018701,0x4AD6BFB8,0xF26AD8DD,0xE0DF7733,0x58631056, + 0x5019579F,0xE8A530FA,0xFA109F14,0x42ACF871,0xDF7BC0C8,0x67C7A7AD,0x75720843,0xCDCE6F26, + 0x95AD7F70,0x2D111815,0x3FA4B7FB,0x8718D09E,0x1ACFE827,0xA2738F42,0xB0C620AC,0x087A47C9, + 0xA032AF3E,0x188EC85B,0x0A3B67B5,0xB28700D0,0x2F503869,0x97EC5F0C,0x8559F0E2,0x3DE59787, + 0x658687D1,0xDD3AE0B4,0xCF8F4F5A,0x7733283F,0xEAE41086,0x525877E3,0x40EDD80D,0xF851BF68, + 0xF02BF8A1,0x48979FC4,0x5A22302A,0xE29E574F,0x7F496FF6,0xC7F50893,0xD540A77D,0x6DFCC018, + 0x359FD04E,0x8D23B72B,0x9F9618C5,0x272A7FA0,0xBAFD4719,0x0241207C,0x10F48F92,0xA848E8F7, + 
0x9B14583D,0x23A83F58,0x311D90B6,0x89A1F7D3,0x1476CF6A,0xACCAA80F,0xBE7F07E1,0x06C36084, + 0x5EA070D2,0xE61C17B7,0xF4A9B859,0x4C15DF3C,0xD1C2E785,0x697E80E0,0x7BCB2F0E,0xC377486B, + 0xCB0D0FA2,0x73B168C7,0x6104C729,0xD9B8A04C,0x446F98F5,0xFCD3FF90,0xEE66507E,0x56DA371B, + 0x0EB9274D,0xB6054028,0xA4B0EFC6,0x1C0C88A3,0x81DBB01A,0x3967D77F,0x2BD27891,0x936E1FF4, + 0x3B26F703,0x839A9066,0x912F3F88,0x299358ED,0xB4446054,0x0CF80731,0x1E4DA8DF,0xA6F1CFBA, + 0xFE92DFEC,0x462EB889,0x549B1767,0xEC277002,0x71F048BB,0xC94C2FDE,0xDBF98030,0x6345E755, + 0x6B3FA09C,0xD383C7F9,0xC1366817,0x798A0F72,0xE45D37CB,0x5CE150AE,0x4E54FF40,0xF6E89825, + 0xAE8B8873,0x1637EF16,0x048240F8,0xBC3E279D,0x21E91F24,0x99557841,0x8BE0D7AF,0x335CB0CA, + 0xED59B63B,0x55E5D15E,0x47507EB0,0xFFEC19D5,0x623B216C,0xDA874609,0xC832E9E7,0x708E8E82, + 0x28ED9ED4,0x9051F9B1,0x82E4565F,0x3A58313A,0xA78F0983,0x1F336EE6,0x0D86C108,0xB53AA66D, + 0xBD40E1A4,0x05FC86C1,0x1749292F,0xAFF54E4A,0x322276F3,0x8A9E1196,0x982BBE78,0x2097D91D, + 0x78F4C94B,0xC048AE2E,0xD2FD01C0,0x6A4166A5,0xF7965E1C,0x4F2A3979,0x5D9F9697,0xE523F1F2, + 0x4D6B1905,0xF5D77E60,0xE762D18E,0x5FDEB6EB,0xC2098E52,0x7AB5E937,0x680046D9,0xD0BC21BC, + 0x88DF31EA,0x3063568F,0x22D6F961,0x9A6A9E04,0x07BDA6BD,0xBF01C1D8,0xADB46E36,0x15080953, + 0x1D724E9A,0xA5CE29FF,0xB77B8611,0x0FC7E174,0x9210D9CD,0x2AACBEA8,0x38191146,0x80A57623, + 0xD8C66675,0x607A0110,0x72CFAEFE,0xCA73C99B,0x57A4F122,0xEF189647,0xFDAD39A9,0x45115ECC, + 0x764DEE06,0xCEF18963,0xDC44268D,0x64F841E8,0xF92F7951,0x41931E34,0x5326B1DA,0xEB9AD6BF, + 0xB3F9C6E9,0x0B45A18C,0x19F00E62,0xA14C6907,0x3C9B51BE,0x842736DB,0x96929935,0x2E2EFE50, + 0x2654B999,0x9EE8DEFC,0x8C5D7112,0x34E11677,0xA9362ECE,0x118A49AB,0x033FE645,0xBB838120, + 0xE3E09176,0x5B5CF613,0x49E959FD,0xF1553E98,0x6C820621,0xD43E6144,0xC68BCEAA,0x7E37A9CF, + 0xD67F4138,0x6EC3265D,0x7C7689B3,0xC4CAEED6,0x591DD66F,0xE1A1B10A,0xF3141EE4,0x4BA87981, + 0x13CB69D7,0xAB770EB2,0xB9C2A15C,0x017EC639,0x9CA9FE80,0x241599E5,0x36A0360B,0x8E1C516E, 
+ 0x866616A7,0x3EDA71C2,0x2C6FDE2C,0x94D3B949,0x090481F0,0xB1B8E695,0xA30D497B,0x1BB12E1E, + 0x43D23E48,0xFB6E592D,0xE9DBF6C3,0x516791A6,0xCCB0A91F,0x740CCE7A,0x66B96194,0xDE0506F1, + } +#endif // defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_4) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_16) +#if defined (CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_16) + // beyond this point only relevant for Slicing-by-8 and Slicing-by-16 + ,{ + 0x00000000,0x3D6029B0,0x7AC05360,0x47A07AD0,0xF580A6C0,0xC8E08F70,0x8F40F5A0,0xB220DC10, + 0x30704BC1,0x0D106271,0x4AB018A1,0x77D03111,0xC5F0ED01,0xF890C4B1,0xBF30BE61,0x825097D1, + 0x60E09782,0x5D80BE32,0x1A20C4E2,0x2740ED52,0x95603142,0xA80018F2,0xEFA06222,0xD2C04B92, + 0x5090DC43,0x6DF0F5F3,0x2A508F23,0x1730A693,0xA5107A83,0x98705333,0xDFD029E3,0xE2B00053, + 0xC1C12F04,0xFCA106B4,0xBB017C64,0x866155D4,0x344189C4,0x0921A074,0x4E81DAA4,0x73E1F314, + 0xF1B164C5,0xCCD14D75,0x8B7137A5,0xB6111E15,0x0431C205,0x3951EBB5,0x7EF19165,0x4391B8D5, + 0xA121B886,0x9C419136,0xDBE1EBE6,0xE681C256,0x54A11E46,0x69C137F6,0x2E614D26,0x13016496, + 0x9151F347,0xAC31DAF7,0xEB91A027,0xD6F18997,0x64D15587,0x59B17C37,0x1E1106E7,0x23712F57, + 0x58F35849,0x659371F9,0x22330B29,0x1F532299,0xAD73FE89,0x9013D739,0xD7B3ADE9,0xEAD38459, + 0x68831388,0x55E33A38,0x124340E8,0x2F236958,0x9D03B548,0xA0639CF8,0xE7C3E628,0xDAA3CF98, + 0x3813CFCB,0x0573E67B,0x42D39CAB,0x7FB3B51B,0xCD93690B,0xF0F340BB,0xB7533A6B,0x8A3313DB, + 0x0863840A,0x3503ADBA,0x72A3D76A,0x4FC3FEDA,0xFDE322CA,0xC0830B7A,0x872371AA,0xBA43581A, + 0x9932774D,0xA4525EFD,0xE3F2242D,0xDE920D9D,0x6CB2D18D,0x51D2F83D,0x167282ED,0x2B12AB5D, + 0xA9423C8C,0x9422153C,0xD3826FEC,0xEEE2465C,0x5CC29A4C,0x61A2B3FC,0x2602C92C,0x1B62E09C, + 0xF9D2E0CF,0xC4B2C97F,0x8312B3AF,0xBE729A1F,0x0C52460F,0x31326FBF,0x7692156F,0x4BF23CDF, + 0xC9A2AB0E,0xF4C282BE,0xB362F86E,0x8E02D1DE,0x3C220DCE,0x0142247E,0x46E25EAE,0x7B82771E, + 
0xB1E6B092,0x8C869922,0xCB26E3F2,0xF646CA42,0x44661652,0x79063FE2,0x3EA64532,0x03C66C82, + 0x8196FB53,0xBCF6D2E3,0xFB56A833,0xC6368183,0x74165D93,0x49767423,0x0ED60EF3,0x33B62743, + 0xD1062710,0xEC660EA0,0xABC67470,0x96A65DC0,0x248681D0,0x19E6A860,0x5E46D2B0,0x6326FB00, + 0xE1766CD1,0xDC164561,0x9BB63FB1,0xA6D61601,0x14F6CA11,0x2996E3A1,0x6E369971,0x5356B0C1, + 0x70279F96,0x4D47B626,0x0AE7CCF6,0x3787E546,0x85A73956,0xB8C710E6,0xFF676A36,0xC2074386, + 0x4057D457,0x7D37FDE7,0x3A978737,0x07F7AE87,0xB5D77297,0x88B75B27,0xCF1721F7,0xF2770847, + 0x10C70814,0x2DA721A4,0x6A075B74,0x576772C4,0xE547AED4,0xD8278764,0x9F87FDB4,0xA2E7D404, + 0x20B743D5,0x1DD76A65,0x5A7710B5,0x67173905,0xD537E515,0xE857CCA5,0xAFF7B675,0x92979FC5, + 0xE915E8DB,0xD475C16B,0x93D5BBBB,0xAEB5920B,0x1C954E1B,0x21F567AB,0x66551D7B,0x5B3534CB, + 0xD965A31A,0xE4058AAA,0xA3A5F07A,0x9EC5D9CA,0x2CE505DA,0x11852C6A,0x562556BA,0x6B457F0A, + 0x89F57F59,0xB49556E9,0xF3352C39,0xCE550589,0x7C75D999,0x4115F029,0x06B58AF9,0x3BD5A349, + 0xB9853498,0x84E51D28,0xC34567F8,0xFE254E48,0x4C059258,0x7165BBE8,0x36C5C138,0x0BA5E888, + 0x28D4C7DF,0x15B4EE6F,0x521494BF,0x6F74BD0F,0xDD54611F,0xE03448AF,0xA794327F,0x9AF41BCF, + 0x18A48C1E,0x25C4A5AE,0x6264DF7E,0x5F04F6CE,0xED242ADE,0xD044036E,0x97E479BE,0xAA84500E, + 0x4834505D,0x755479ED,0x32F4033D,0x0F942A8D,0xBDB4F69D,0x80D4DF2D,0xC774A5FD,0xFA148C4D, + 0x78441B9C,0x4524322C,0x028448FC,0x3FE4614C,0x8DC4BD5C,0xB0A494EC,0xF704EE3C,0xCA64C78C, + }, + + { + 0x00000000,0xCB5CD3A5,0x4DC8A10B,0x869472AE,0x9B914216,0x50CD91B3,0xD659E31D,0x1D0530B8, + 0xEC53826D,0x270F51C8,0xA19B2366,0x6AC7F0C3,0x77C2C07B,0xBC9E13DE,0x3A0A6170,0xF156B2D5, + 0x03D6029B,0xC88AD13E,0x4E1EA390,0x85427035,0x9847408D,0x531B9328,0xD58FE186,0x1ED33223, + 0xEF8580F6,0x24D95353,0xA24D21FD,0x6911F258,0x7414C2E0,0xBF481145,0x39DC63EB,0xF280B04E, + 0x07AC0536,0xCCF0D693,0x4A64A43D,0x81387798,0x9C3D4720,0x57619485,0xD1F5E62B,0x1AA9358E, + 
0xEBFF875B,0x20A354FE,0xA6372650,0x6D6BF5F5,0x706EC54D,0xBB3216E8,0x3DA66446,0xF6FAB7E3, + 0x047A07AD,0xCF26D408,0x49B2A6A6,0x82EE7503,0x9FEB45BB,0x54B7961E,0xD223E4B0,0x197F3715, + 0xE82985C0,0x23755665,0xA5E124CB,0x6EBDF76E,0x73B8C7D6,0xB8E41473,0x3E7066DD,0xF52CB578, + 0x0F580A6C,0xC404D9C9,0x4290AB67,0x89CC78C2,0x94C9487A,0x5F959BDF,0xD901E971,0x125D3AD4, + 0xE30B8801,0x28575BA4,0xAEC3290A,0x659FFAAF,0x789ACA17,0xB3C619B2,0x35526B1C,0xFE0EB8B9, + 0x0C8E08F7,0xC7D2DB52,0x4146A9FC,0x8A1A7A59,0x971F4AE1,0x5C439944,0xDAD7EBEA,0x118B384F, + 0xE0DD8A9A,0x2B81593F,0xAD152B91,0x6649F834,0x7B4CC88C,0xB0101B29,0x36846987,0xFDD8BA22, + 0x08F40F5A,0xC3A8DCFF,0x453CAE51,0x8E607DF4,0x93654D4C,0x58399EE9,0xDEADEC47,0x15F13FE2, + 0xE4A78D37,0x2FFB5E92,0xA96F2C3C,0x6233FF99,0x7F36CF21,0xB46A1C84,0x32FE6E2A,0xF9A2BD8F, + 0x0B220DC1,0xC07EDE64,0x46EAACCA,0x8DB67F6F,0x90B34FD7,0x5BEF9C72,0xDD7BEEDC,0x16273D79, + 0xE7718FAC,0x2C2D5C09,0xAAB92EA7,0x61E5FD02,0x7CE0CDBA,0xB7BC1E1F,0x31286CB1,0xFA74BF14, + 0x1EB014D8,0xD5ECC77D,0x5378B5D3,0x98246676,0x852156CE,0x4E7D856B,0xC8E9F7C5,0x03B52460, + 0xF2E396B5,0x39BF4510,0xBF2B37BE,0x7477E41B,0x6972D4A3,0xA22E0706,0x24BA75A8,0xEFE6A60D, + 0x1D661643,0xD63AC5E6,0x50AEB748,0x9BF264ED,0x86F75455,0x4DAB87F0,0xCB3FF55E,0x006326FB, + 0xF135942E,0x3A69478B,0xBCFD3525,0x77A1E680,0x6AA4D638,0xA1F8059D,0x276C7733,0xEC30A496, + 0x191C11EE,0xD240C24B,0x54D4B0E5,0x9F886340,0x828D53F8,0x49D1805D,0xCF45F2F3,0x04192156, + 0xF54F9383,0x3E134026,0xB8873288,0x73DBE12D,0x6EDED195,0xA5820230,0x2316709E,0xE84AA33B, + 0x1ACA1375,0xD196C0D0,0x5702B27E,0x9C5E61DB,0x815B5163,0x4A0782C6,0xCC93F068,0x07CF23CD, + 0xF6999118,0x3DC542BD,0xBB513013,0x700DE3B6,0x6D08D30E,0xA65400AB,0x20C07205,0xEB9CA1A0, + 0x11E81EB4,0xDAB4CD11,0x5C20BFBF,0x977C6C1A,0x8A795CA2,0x41258F07,0xC7B1FDA9,0x0CED2E0C, + 0xFDBB9CD9,0x36E74F7C,0xB0733DD2,0x7B2FEE77,0x662ADECF,0xAD760D6A,0x2BE27FC4,0xE0BEAC61, + 0x123E1C2F,0xD962CF8A,0x5FF6BD24,0x94AA6E81,0x89AF5E39,0x42F38D9C,0xC467FF32,0x0F3B2C97, 
+ 0xFE6D9E42,0x35314DE7,0xB3A53F49,0x78F9ECEC,0x65FCDC54,0xAEA00FF1,0x28347D5F,0xE368AEFA, + 0x16441B82,0xDD18C827,0x5B8CBA89,0x90D0692C,0x8DD55994,0x46898A31,0xC01DF89F,0x0B412B3A, + 0xFA1799EF,0x314B4A4A,0xB7DF38E4,0x7C83EB41,0x6186DBF9,0xAADA085C,0x2C4E7AF2,0xE712A957, + 0x15921919,0xDECECABC,0x585AB812,0x93066BB7,0x8E035B0F,0x455F88AA,0xC3CBFA04,0x089729A1, + 0xF9C19B74,0x329D48D1,0xB4093A7F,0x7F55E9DA,0x6250D962,0xA90C0AC7,0x2F987869,0xE4C4ABCC, + }, + + { + 0x00000000,0xA6770BB4,0x979F1129,0x31E81A9D,0xF44F2413,0x52382FA7,0x63D0353A,0xC5A73E8E, + 0x33EF4E67,0x959845D3,0xA4705F4E,0x020754FA,0xC7A06A74,0x61D761C0,0x503F7B5D,0xF64870E9, + 0x67DE9CCE,0xC1A9977A,0xF0418DE7,0x56368653,0x9391B8DD,0x35E6B369,0x040EA9F4,0xA279A240, + 0x5431D2A9,0xF246D91D,0xC3AEC380,0x65D9C834,0xA07EF6BA,0x0609FD0E,0x37E1E793,0x9196EC27, + 0xCFBD399C,0x69CA3228,0x582228B5,0xFE552301,0x3BF21D8F,0x9D85163B,0xAC6D0CA6,0x0A1A0712, + 0xFC5277FB,0x5A257C4F,0x6BCD66D2,0xCDBA6D66,0x081D53E8,0xAE6A585C,0x9F8242C1,0x39F54975, + 0xA863A552,0x0E14AEE6,0x3FFCB47B,0x998BBFCF,0x5C2C8141,0xFA5B8AF5,0xCBB39068,0x6DC49BDC, + 0x9B8CEB35,0x3DFBE081,0x0C13FA1C,0xAA64F1A8,0x6FC3CF26,0xC9B4C492,0xF85CDE0F,0x5E2BD5BB, + 0x440B7579,0xE27C7ECD,0xD3946450,0x75E36FE4,0xB044516A,0x16335ADE,0x27DB4043,0x81AC4BF7, + 0x77E43B1E,0xD19330AA,0xE07B2A37,0x460C2183,0x83AB1F0D,0x25DC14B9,0x14340E24,0xB2430590, + 0x23D5E9B7,0x85A2E203,0xB44AF89E,0x123DF32A,0xD79ACDA4,0x71EDC610,0x4005DC8D,0xE672D739, + 0x103AA7D0,0xB64DAC64,0x87A5B6F9,0x21D2BD4D,0xE47583C3,0x42028877,0x73EA92EA,0xD59D995E, + 0x8BB64CE5,0x2DC14751,0x1C295DCC,0xBA5E5678,0x7FF968F6,0xD98E6342,0xE86679DF,0x4E11726B, + 0xB8590282,0x1E2E0936,0x2FC613AB,0x89B1181F,0x4C162691,0xEA612D25,0xDB8937B8,0x7DFE3C0C, + 0xEC68D02B,0x4A1FDB9F,0x7BF7C102,0xDD80CAB6,0x1827F438,0xBE50FF8C,0x8FB8E511,0x29CFEEA5, + 0xDF879E4C,0x79F095F8,0x48188F65,0xEE6F84D1,0x2BC8BA5F,0x8DBFB1EB,0xBC57AB76,0x1A20A0C2, + 
0x8816EAF2,0x2E61E146,0x1F89FBDB,0xB9FEF06F,0x7C59CEE1,0xDA2EC555,0xEBC6DFC8,0x4DB1D47C, + 0xBBF9A495,0x1D8EAF21,0x2C66B5BC,0x8A11BE08,0x4FB68086,0xE9C18B32,0xD82991AF,0x7E5E9A1B, + 0xEFC8763C,0x49BF7D88,0x78576715,0xDE206CA1,0x1B87522F,0xBDF0599B,0x8C184306,0x2A6F48B2, + 0xDC27385B,0x7A5033EF,0x4BB82972,0xEDCF22C6,0x28681C48,0x8E1F17FC,0xBFF70D61,0x198006D5, + 0x47ABD36E,0xE1DCD8DA,0xD034C247,0x7643C9F3,0xB3E4F77D,0x1593FCC9,0x247BE654,0x820CEDE0, + 0x74449D09,0xD23396BD,0xE3DB8C20,0x45AC8794,0x800BB91A,0x267CB2AE,0x1794A833,0xB1E3A387, + 0x20754FA0,0x86024414,0xB7EA5E89,0x119D553D,0xD43A6BB3,0x724D6007,0x43A57A9A,0xE5D2712E, + 0x139A01C7,0xB5ED0A73,0x840510EE,0x22721B5A,0xE7D525D4,0x41A22E60,0x704A34FD,0xD63D3F49, + 0xCC1D9F8B,0x6A6A943F,0x5B828EA2,0xFDF58516,0x3852BB98,0x9E25B02C,0xAFCDAAB1,0x09BAA105, + 0xFFF2D1EC,0x5985DA58,0x686DC0C5,0xCE1ACB71,0x0BBDF5FF,0xADCAFE4B,0x9C22E4D6,0x3A55EF62, + 0xABC30345,0x0DB408F1,0x3C5C126C,0x9A2B19D8,0x5F8C2756,0xF9FB2CE2,0xC813367F,0x6E643DCB, + 0x982C4D22,0x3E5B4696,0x0FB35C0B,0xA9C457BF,0x6C636931,0xCA146285,0xFBFC7818,0x5D8B73AC, + 0x03A0A617,0xA5D7ADA3,0x943FB73E,0x3248BC8A,0xF7EF8204,0x519889B0,0x6070932D,0xC6079899, + 0x304FE870,0x9638E3C4,0xA7D0F959,0x01A7F2ED,0xC400CC63,0x6277C7D7,0x539FDD4A,0xF5E8D6FE, + 0x647E3AD9,0xC209316D,0xF3E12BF0,0x55962044,0x90311ECA,0x3646157E,0x07AE0FE3,0xA1D90457, + 0x579174BE,0xF1E67F0A,0xC00E6597,0x66796E23,0xA3DE50AD,0x05A95B19,0x34414184,0x92364A30, + }, + + { + 0x00000000,0xCCAA009E,0x4225077D,0x8E8F07E3,0x844A0EFA,0x48E00E64,0xC66F0987,0x0AC50919, + 0xD3E51BB5,0x1F4F1B2B,0x91C01CC8,0x5D6A1C56,0x57AF154F,0x9B0515D1,0x158A1232,0xD92012AC, + 0x7CBB312B,0xB01131B5,0x3E9E3656,0xF23436C8,0xF8F13FD1,0x345B3F4F,0xBAD438AC,0x767E3832, + 0xAF5E2A9E,0x63F42A00,0xED7B2DE3,0x21D12D7D,0x2B142464,0xE7BE24FA,0x69312319,0xA59B2387, + 0xF9766256,0x35DC62C8,0xBB53652B,0x77F965B5,0x7D3C6CAC,0xB1966C32,0x3F196BD1,0xF3B36B4F, + 
0x2A9379E3,0xE639797D,0x68B67E9E,0xA41C7E00,0xAED97719,0x62737787,0xECFC7064,0x205670FA, + 0x85CD537D,0x496753E3,0xC7E85400,0x0B42549E,0x01875D87,0xCD2D5D19,0x43A25AFA,0x8F085A64, + 0x562848C8,0x9A824856,0x140D4FB5,0xD8A74F2B,0xD2624632,0x1EC846AC,0x9047414F,0x5CED41D1, + 0x299DC2ED,0xE537C273,0x6BB8C590,0xA712C50E,0xADD7CC17,0x617DCC89,0xEFF2CB6A,0x2358CBF4, + 0xFA78D958,0x36D2D9C6,0xB85DDE25,0x74F7DEBB,0x7E32D7A2,0xB298D73C,0x3C17D0DF,0xF0BDD041, + 0x5526F3C6,0x998CF358,0x1703F4BB,0xDBA9F425,0xD16CFD3C,0x1DC6FDA2,0x9349FA41,0x5FE3FADF, + 0x86C3E873,0x4A69E8ED,0xC4E6EF0E,0x084CEF90,0x0289E689,0xCE23E617,0x40ACE1F4,0x8C06E16A, + 0xD0EBA0BB,0x1C41A025,0x92CEA7C6,0x5E64A758,0x54A1AE41,0x980BAEDF,0x1684A93C,0xDA2EA9A2, + 0x030EBB0E,0xCFA4BB90,0x412BBC73,0x8D81BCED,0x8744B5F4,0x4BEEB56A,0xC561B289,0x09CBB217, + 0xAC509190,0x60FA910E,0xEE7596ED,0x22DF9673,0x281A9F6A,0xE4B09FF4,0x6A3F9817,0xA6959889, + 0x7FB58A25,0xB31F8ABB,0x3D908D58,0xF13A8DC6,0xFBFF84DF,0x37558441,0xB9DA83A2,0x7570833C, + 0x533B85DA,0x9F918544,0x111E82A7,0xDDB48239,0xD7718B20,0x1BDB8BBE,0x95548C5D,0x59FE8CC3, + 0x80DE9E6F,0x4C749EF1,0xC2FB9912,0x0E51998C,0x04949095,0xC83E900B,0x46B197E8,0x8A1B9776, + 0x2F80B4F1,0xE32AB46F,0x6DA5B38C,0xA10FB312,0xABCABA0B,0x6760BA95,0xE9EFBD76,0x2545BDE8, + 0xFC65AF44,0x30CFAFDA,0xBE40A839,0x72EAA8A7,0x782FA1BE,0xB485A120,0x3A0AA6C3,0xF6A0A65D, + 0xAA4DE78C,0x66E7E712,0xE868E0F1,0x24C2E06F,0x2E07E976,0xE2ADE9E8,0x6C22EE0B,0xA088EE95, + 0x79A8FC39,0xB502FCA7,0x3B8DFB44,0xF727FBDA,0xFDE2F2C3,0x3148F25D,0xBFC7F5BE,0x736DF520, + 0xD6F6D6A7,0x1A5CD639,0x94D3D1DA,0x5879D144,0x52BCD85D,0x9E16D8C3,0x1099DF20,0xDC33DFBE, + 0x0513CD12,0xC9B9CD8C,0x4736CA6F,0x8B9CCAF1,0x8159C3E8,0x4DF3C376,0xC37CC495,0x0FD6C40B, + 0x7AA64737,0xB60C47A9,0x3883404A,0xF42940D4,0xFEEC49CD,0x32464953,0xBCC94EB0,0x70634E2E, + 0xA9435C82,0x65E95C1C,0xEB665BFF,0x27CC5B61,0x2D095278,0xE1A352E6,0x6F2C5505,0xA386559B, + 0x061D761C,0xCAB77682,0x44387161,0x889271FF,0x825778E6,0x4EFD7878,0xC0727F9B,0x0CD87F05, 
+ 0xD5F86DA9,0x19526D37,0x97DD6AD4,0x5B776A4A,0x51B26353,0x9D1863CD,0x1397642E,0xDF3D64B0, + 0x83D02561,0x4F7A25FF,0xC1F5221C,0x0D5F2282,0x079A2B9B,0xCB302B05,0x45BF2CE6,0x89152C78, + 0x50353ED4,0x9C9F3E4A,0x121039A9,0xDEBA3937,0xD47F302E,0x18D530B0,0x965A3753,0x5AF037CD, + 0xFF6B144A,0x33C114D4,0xBD4E1337,0x71E413A9,0x7B211AB0,0xB78B1A2E,0x39041DCD,0xF5AE1D53, + 0x2C8E0FFF,0xE0240F61,0x6EAB0882,0xA201081C,0xA8C40105,0x646E019B,0xEAE10678,0x264B06E6, + } +#endif // CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 || CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 + // beyond this point only relevant for Slicing-by-16 + ,{ + 0x00000000,0x177B1443,0x2EF62886,0x398D3CC5,0x5DEC510C,0x4A97454F,0x731A798A,0x64616DC9, + 0xBBD8A218,0xACA3B65B,0x952E8A9E,0x82559EDD,0xE634F314,0xF14FE757,0xC8C2DB92,0xDFB9CFD1, + 0xACC04271,0xBBBB5632,0x82366AF7,0x954D7EB4,0xF12C137D,0xE657073E,0xDFDA3BFB,0xC8A12FB8, + 0x1718E069,0x0063F42A,0x39EEC8EF,0x2E95DCAC,0x4AF4B165,0x5D8FA526,0x640299E3,0x73798DA0, + 0x82F182A3,0x958A96E0,0xAC07AA25,0xBB7CBE66,0xDF1DD3AF,0xC866C7EC,0xF1EBFB29,0xE690EF6A, + 0x392920BB,0x2E5234F8,0x17DF083D,0x00A41C7E,0x64C571B7,0x73BE65F4,0x4A335931,0x5D484D72, + 0x2E31C0D2,0x394AD491,0x00C7E854,0x17BCFC17,0x73DD91DE,0x64A6859D,0x5D2BB958,0x4A50AD1B, + 0x95E962CA,0x82927689,0xBB1F4A4C,0xAC645E0F,0xC80533C6,0xDF7E2785,0xE6F31B40,0xF1880F03, + 0xDE920307,0xC9E91744,0xF0642B81,0xE71F3FC2,0x837E520B,0x94054648,0xAD887A8D,0xBAF36ECE, + 0x654AA11F,0x7231B55C,0x4BBC8999,0x5CC79DDA,0x38A6F013,0x2FDDE450,0x1650D895,0x012BCCD6, + 0x72524176,0x65295535,0x5CA469F0,0x4BDF7DB3,0x2FBE107A,0x38C50439,0x014838FC,0x16332CBF, + 0xC98AE36E,0xDEF1F72D,0xE77CCBE8,0xF007DFAB,0x9466B262,0x831DA621,0xBA909AE4,0xADEB8EA7, + 0x5C6381A4,0x4B1895E7,0x7295A922,0x65EEBD61,0x018FD0A8,0x16F4C4EB,0x2F79F82E,0x3802EC6D, + 0xE7BB23BC,0xF0C037FF,0xC94D0B3A,0xDE361F79,0xBA5772B0,0xAD2C66F3,0x94A15A36,0x83DA4E75, + 
0xF0A3C3D5,0xE7D8D796,0xDE55EB53,0xC92EFF10,0xAD4F92D9,0xBA34869A,0x83B9BA5F,0x94C2AE1C, + 0x4B7B61CD,0x5C00758E,0x658D494B,0x72F65D08,0x169730C1,0x01EC2482,0x38611847,0x2F1A0C04, + 0x6655004F,0x712E140C,0x48A328C9,0x5FD83C8A,0x3BB95143,0x2CC24500,0x154F79C5,0x02346D86, + 0xDD8DA257,0xCAF6B614,0xF37B8AD1,0xE4009E92,0x8061F35B,0x971AE718,0xAE97DBDD,0xB9ECCF9E, + 0xCA95423E,0xDDEE567D,0xE4636AB8,0xF3187EFB,0x97791332,0x80020771,0xB98F3BB4,0xAEF42FF7, + 0x714DE026,0x6636F465,0x5FBBC8A0,0x48C0DCE3,0x2CA1B12A,0x3BDAA569,0x025799AC,0x152C8DEF, + 0xE4A482EC,0xF3DF96AF,0xCA52AA6A,0xDD29BE29,0xB948D3E0,0xAE33C7A3,0x97BEFB66,0x80C5EF25, + 0x5F7C20F4,0x480734B7,0x718A0872,0x66F11C31,0x029071F8,0x15EB65BB,0x2C66597E,0x3B1D4D3D, + 0x4864C09D,0x5F1FD4DE,0x6692E81B,0x71E9FC58,0x15889191,0x02F385D2,0x3B7EB917,0x2C05AD54, + 0xF3BC6285,0xE4C776C6,0xDD4A4A03,0xCA315E40,0xAE503389,0xB92B27CA,0x80A61B0F,0x97DD0F4C, + 0xB8C70348,0xAFBC170B,0x96312BCE,0x814A3F8D,0xE52B5244,0xF2504607,0xCBDD7AC2,0xDCA66E81, + 0x031FA150,0x1464B513,0x2DE989D6,0x3A929D95,0x5EF3F05C,0x4988E41F,0x7005D8DA,0x677ECC99, + 0x14074139,0x037C557A,0x3AF169BF,0x2D8A7DFC,0x49EB1035,0x5E900476,0x671D38B3,0x70662CF0, + 0xAFDFE321,0xB8A4F762,0x8129CBA7,0x9652DFE4,0xF233B22D,0xE548A66E,0xDCC59AAB,0xCBBE8EE8, + 0x3A3681EB,0x2D4D95A8,0x14C0A96D,0x03BBBD2E,0x67DAD0E7,0x70A1C4A4,0x492CF861,0x5E57EC22, + 0x81EE23F3,0x969537B0,0xAF180B75,0xB8631F36,0xDC0272FF,0xCB7966BC,0xF2F45A79,0xE58F4E3A, + 0x96F6C39A,0x818DD7D9,0xB800EB1C,0xAF7BFF5F,0xCB1A9296,0xDC6186D5,0xE5ECBA10,0xF297AE53, + 0x2D2E6182,0x3A5575C1,0x03D84904,0x14A35D47,0x70C2308E,0x67B924CD,0x5E341808,0x494F0C4B, + }, + + { + 0x00000000,0xEFC26B3E,0x04F5D03D,0xEB37BB03,0x09EBA07A,0xE629CB44,0x0D1E7047,0xE2DC1B79, + 0x13D740F4,0xFC152BCA,0x172290C9,0xF8E0FBF7,0x1A3CE08E,0xF5FE8BB0,0x1EC930B3,0xF10B5B8D, + 0x27AE81E8,0xC86CEAD6,0x235B51D5,0xCC993AEB,0x2E452192,0xC1874AAC,0x2AB0F1AF,0xC5729A91, + 
0x3479C11C,0xDBBBAA22,0x308C1121,0xDF4E7A1F,0x3D926166,0xD2500A58,0x3967B15B,0xD6A5DA65, + 0x4F5D03D0,0xA09F68EE,0x4BA8D3ED,0xA46AB8D3,0x46B6A3AA,0xA974C894,0x42437397,0xAD8118A9, + 0x5C8A4324,0xB348281A,0x587F9319,0xB7BDF827,0x5561E35E,0xBAA38860,0x51943363,0xBE56585D, + 0x68F38238,0x8731E906,0x6C065205,0x83C4393B,0x61182242,0x8EDA497C,0x65EDF27F,0x8A2F9941, + 0x7B24C2CC,0x94E6A9F2,0x7FD112F1,0x901379CF,0x72CF62B6,0x9D0D0988,0x763AB28B,0x99F8D9B5, + 0x9EBA07A0,0x71786C9E,0x9A4FD79D,0x758DBCA3,0x9751A7DA,0x7893CCE4,0x93A477E7,0x7C661CD9, + 0x8D6D4754,0x62AF2C6A,0x89989769,0x665AFC57,0x8486E72E,0x6B448C10,0x80733713,0x6FB15C2D, + 0xB9148648,0x56D6ED76,0xBDE15675,0x52233D4B,0xB0FF2632,0x5F3D4D0C,0xB40AF60F,0x5BC89D31, + 0xAAC3C6BC,0x4501AD82,0xAE361681,0x41F47DBF,0xA32866C6,0x4CEA0DF8,0xA7DDB6FB,0x481FDDC5, + 0xD1E70470,0x3E256F4E,0xD512D44D,0x3AD0BF73,0xD80CA40A,0x37CECF34,0xDCF97437,0x333B1F09, + 0xC2304484,0x2DF22FBA,0xC6C594B9,0x2907FF87,0xCBDBE4FE,0x24198FC0,0xCF2E34C3,0x20EC5FFD, + 0xF6498598,0x198BEEA6,0xF2BC55A5,0x1D7E3E9B,0xFFA225E2,0x10604EDC,0xFB57F5DF,0x14959EE1, + 0xE59EC56C,0x0A5CAE52,0xE16B1551,0x0EA97E6F,0xEC756516,0x03B70E28,0xE880B52B,0x0742DE15, + 0xE6050901,0x09C7623F,0xE2F0D93C,0x0D32B202,0xEFEEA97B,0x002CC245,0xEB1B7946,0x04D91278, + 0xF5D249F5,0x1A1022CB,0xF12799C8,0x1EE5F2F6,0xFC39E98F,0x13FB82B1,0xF8CC39B2,0x170E528C, + 0xC1AB88E9,0x2E69E3D7,0xC55E58D4,0x2A9C33EA,0xC8402893,0x278243AD,0xCCB5F8AE,0x23779390, + 0xD27CC81D,0x3DBEA323,0xD6891820,0x394B731E,0xDB976867,0x34550359,0xDF62B85A,0x30A0D364, + 0xA9580AD1,0x469A61EF,0xADADDAEC,0x426FB1D2,0xA0B3AAAB,0x4F71C195,0xA4467A96,0x4B8411A8, + 0xBA8F4A25,0x554D211B,0xBE7A9A18,0x51B8F126,0xB364EA5F,0x5CA68161,0xB7913A62,0x5853515C, + 0x8EF68B39,0x6134E007,0x8A035B04,0x65C1303A,0x871D2B43,0x68DF407D,0x83E8FB7E,0x6C2A9040, + 0x9D21CBCD,0x72E3A0F3,0x99D41BF0,0x761670CE,0x94CA6BB7,0x7B080089,0x903FBB8A,0x7FFDD0B4, + 0x78BF0EA1,0x977D659F,0x7C4ADE9C,0x9388B5A2,0x7154AEDB,0x9E96C5E5,0x75A17EE6,0x9A6315D8, 
+ 0x6B684E55,0x84AA256B,0x6F9D9E68,0x805FF556,0x6283EE2F,0x8D418511,0x66763E12,0x89B4552C, + 0x5F118F49,0xB0D3E477,0x5BE45F74,0xB426344A,0x56FA2F33,0xB938440D,0x520FFF0E,0xBDCD9430, + 0x4CC6CFBD,0xA304A483,0x48331F80,0xA7F174BE,0x452D6FC7,0xAAEF04F9,0x41D8BFFA,0xAE1AD4C4, + 0x37E20D71,0xD820664F,0x3317DD4C,0xDCD5B672,0x3E09AD0B,0xD1CBC635,0x3AFC7D36,0xD53E1608, + 0x24354D85,0xCBF726BB,0x20C09DB8,0xCF02F686,0x2DDEEDFF,0xC21C86C1,0x292B3DC2,0xC6E956FC, + 0x104C8C99,0xFF8EE7A7,0x14B95CA4,0xFB7B379A,0x19A72CE3,0xF66547DD,0x1D52FCDE,0xF29097E0, + 0x039BCC6D,0xEC59A753,0x076E1C50,0xE8AC776E,0x0A706C17,0xE5B20729,0x0E85BC2A,0xE147D714, + }, + + { + 0x00000000,0xC18EDFC0,0x586CB9C1,0x99E26601,0xB0D97382,0x7157AC42,0xE8B5CA43,0x293B1583, + 0xBAC3E145,0x7B4D3E85,0xE2AF5884,0x23218744,0x0A1A92C7,0xCB944D07,0x52762B06,0x93F8F4C6, + 0xAEF6C4CB,0x6F781B0B,0xF69A7D0A,0x3714A2CA,0x1E2FB749,0xDFA16889,0x46430E88,0x87CDD148, + 0x1435258E,0xD5BBFA4E,0x4C599C4F,0x8DD7438F,0xA4EC560C,0x656289CC,0xFC80EFCD,0x3D0E300D, + 0x869C8FD7,0x47125017,0xDEF03616,0x1F7EE9D6,0x3645FC55,0xF7CB2395,0x6E294594,0xAFA79A54, + 0x3C5F6E92,0xFDD1B152,0x6433D753,0xA5BD0893,0x8C861D10,0x4D08C2D0,0xD4EAA4D1,0x15647B11, + 0x286A4B1C,0xE9E494DC,0x7006F2DD,0xB1882D1D,0x98B3389E,0x593DE75E,0xC0DF815F,0x01515E9F, + 0x92A9AA59,0x53277599,0xCAC51398,0x0B4BCC58,0x2270D9DB,0xE3FE061B,0x7A1C601A,0xBB92BFDA, + 0xD64819EF,0x17C6C62F,0x8E24A02E,0x4FAA7FEE,0x66916A6D,0xA71FB5AD,0x3EFDD3AC,0xFF730C6C, + 0x6C8BF8AA,0xAD05276A,0x34E7416B,0xF5699EAB,0xDC528B28,0x1DDC54E8,0x843E32E9,0x45B0ED29, + 0x78BEDD24,0xB93002E4,0x20D264E5,0xE15CBB25,0xC867AEA6,0x09E97166,0x900B1767,0x5185C8A7, + 0xC27D3C61,0x03F3E3A1,0x9A1185A0,0x5B9F5A60,0x72A44FE3,0xB32A9023,0x2AC8F622,0xEB4629E2, + 0x50D49638,0x915A49F8,0x08B82FF9,0xC936F039,0xE00DE5BA,0x21833A7A,0xB8615C7B,0x79EF83BB, + 0xEA17777D,0x2B99A8BD,0xB27BCEBC,0x73F5117C,0x5ACE04FF,0x9B40DB3F,0x02A2BD3E,0xC32C62FE, + 
0xFE2252F3,0x3FAC8D33,0xA64EEB32,0x67C034F2,0x4EFB2171,0x8F75FEB1,0x169798B0,0xD7194770, + 0x44E1B3B6,0x856F6C76,0x1C8D0A77,0xDD03D5B7,0xF438C034,0x35B61FF4,0xAC5479F5,0x6DDAA635, + 0x77E1359F,0xB66FEA5F,0x2F8D8C5E,0xEE03539E,0xC738461D,0x06B699DD,0x9F54FFDC,0x5EDA201C, + 0xCD22D4DA,0x0CAC0B1A,0x954E6D1B,0x54C0B2DB,0x7DFBA758,0xBC757898,0x25971E99,0xE419C159, + 0xD917F154,0x18992E94,0x817B4895,0x40F59755,0x69CE82D6,0xA8405D16,0x31A23B17,0xF02CE4D7, + 0x63D41011,0xA25ACFD1,0x3BB8A9D0,0xFA367610,0xD30D6393,0x1283BC53,0x8B61DA52,0x4AEF0592, + 0xF17DBA48,0x30F36588,0xA9110389,0x689FDC49,0x41A4C9CA,0x802A160A,0x19C8700B,0xD846AFCB, + 0x4BBE5B0D,0x8A3084CD,0x13D2E2CC,0xD25C3D0C,0xFB67288F,0x3AE9F74F,0xA30B914E,0x62854E8E, + 0x5F8B7E83,0x9E05A143,0x07E7C742,0xC6691882,0xEF520D01,0x2EDCD2C1,0xB73EB4C0,0x76B06B00, + 0xE5489FC6,0x24C64006,0xBD242607,0x7CAAF9C7,0x5591EC44,0x941F3384,0x0DFD5585,0xCC738A45, + 0xA1A92C70,0x6027F3B0,0xF9C595B1,0x384B4A71,0x11705FF2,0xD0FE8032,0x491CE633,0x889239F3, + 0x1B6ACD35,0xDAE412F5,0x430674F4,0x8288AB34,0xABB3BEB7,0x6A3D6177,0xF3DF0776,0x3251D8B6, + 0x0F5FE8BB,0xCED1377B,0x5733517A,0x96BD8EBA,0xBF869B39,0x7E0844F9,0xE7EA22F8,0x2664FD38, + 0xB59C09FE,0x7412D63E,0xEDF0B03F,0x2C7E6FFF,0x05457A7C,0xC4CBA5BC,0x5D29C3BD,0x9CA71C7D, + 0x2735A3A7,0xE6BB7C67,0x7F591A66,0xBED7C5A6,0x97ECD025,0x56620FE5,0xCF8069E4,0x0E0EB624, + 0x9DF642E2,0x5C789D22,0xC59AFB23,0x041424E3,0x2D2F3160,0xECA1EEA0,0x754388A1,0xB4CD5761, + 0x89C3676C,0x484DB8AC,0xD1AFDEAD,0x1021016D,0x391A14EE,0xF894CB2E,0x6176AD2F,0xA0F872EF, + 0x33008629,0xF28E59E9,0x6B6C3FE8,0xAAE2E028,0x83D9F5AB,0x42572A6B,0xDBB54C6A,0x1A3B93AA, + }, + + { + 0x00000000,0x9BA54C6F,0xEC3B9E9F,0x779ED2F0,0x03063B7F,0x98A37710,0xEF3DA5E0,0x7498E98F, + 0x060C76FE,0x9DA93A91,0xEA37E861,0x7192A40E,0x050A4D81,0x9EAF01EE,0xE931D31E,0x72949F71, + 0x0C18EDFC,0x97BDA193,0xE0237363,0x7B863F0C,0x0F1ED683,0x94BB9AEC,0xE325481C,0x78800473, + 
0x0A149B02,0x91B1D76D,0xE62F059D,0x7D8A49F2,0x0912A07D,0x92B7EC12,0xE5293EE2,0x7E8C728D, + 0x1831DBF8,0x83949797,0xF40A4567,0x6FAF0908,0x1B37E087,0x8092ACE8,0xF70C7E18,0x6CA93277, + 0x1E3DAD06,0x8598E169,0xF2063399,0x69A37FF6,0x1D3B9679,0x869EDA16,0xF10008E6,0x6AA54489, + 0x14293604,0x8F8C7A6B,0xF812A89B,0x63B7E4F4,0x172F0D7B,0x8C8A4114,0xFB1493E4,0x60B1DF8B, + 0x122540FA,0x89800C95,0xFE1EDE65,0x65BB920A,0x11237B85,0x8A8637EA,0xFD18E51A,0x66BDA975, + 0x3063B7F0,0xABC6FB9F,0xDC58296F,0x47FD6500,0x33658C8F,0xA8C0C0E0,0xDF5E1210,0x44FB5E7F, + 0x366FC10E,0xADCA8D61,0xDA545F91,0x41F113FE,0x3569FA71,0xAECCB61E,0xD95264EE,0x42F72881, + 0x3C7B5A0C,0xA7DE1663,0xD040C493,0x4BE588FC,0x3F7D6173,0xA4D82D1C,0xD346FFEC,0x48E3B383, + 0x3A772CF2,0xA1D2609D,0xD64CB26D,0x4DE9FE02,0x3971178D,0xA2D45BE2,0xD54A8912,0x4EEFC57D, + 0x28526C08,0xB3F72067,0xC469F297,0x5FCCBEF8,0x2B545777,0xB0F11B18,0xC76FC9E8,0x5CCA8587, + 0x2E5E1AF6,0xB5FB5699,0xC2658469,0x59C0C806,0x2D582189,0xB6FD6DE6,0xC163BF16,0x5AC6F379, + 0x244A81F4,0xBFEFCD9B,0xC8711F6B,0x53D45304,0x274CBA8B,0xBCE9F6E4,0xCB772414,0x50D2687B, + 0x2246F70A,0xB9E3BB65,0xCE7D6995,0x55D825FA,0x2140CC75,0xBAE5801A,0xCD7B52EA,0x56DE1E85, + 0x60C76FE0,0xFB62238F,0x8CFCF17F,0x1759BD10,0x63C1549F,0xF86418F0,0x8FFACA00,0x145F866F, + 0x66CB191E,0xFD6E5571,0x8AF08781,0x1155CBEE,0x65CD2261,0xFE686E0E,0x89F6BCFE,0x1253F091, + 0x6CDF821C,0xF77ACE73,0x80E41C83,0x1B4150EC,0x6FD9B963,0xF47CF50C,0x83E227FC,0x18476B93, + 0x6AD3F4E2,0xF176B88D,0x86E86A7D,0x1D4D2612,0x69D5CF9D,0xF27083F2,0x85EE5102,0x1E4B1D6D, + 0x78F6B418,0xE353F877,0x94CD2A87,0x0F6866E8,0x7BF08F67,0xE055C308,0x97CB11F8,0x0C6E5D97, + 0x7EFAC2E6,0xE55F8E89,0x92C15C79,0x09641016,0x7DFCF999,0xE659B5F6,0x91C76706,0x0A622B69, + 0x74EE59E4,0xEF4B158B,0x98D5C77B,0x03708B14,0x77E8629B,0xEC4D2EF4,0x9BD3FC04,0x0076B06B, + 0x72E22F1A,0xE9476375,0x9ED9B185,0x057CFDEA,0x71E41465,0xEA41580A,0x9DDF8AFA,0x067AC695, + 0x50A4D810,0xCB01947F,0xBC9F468F,0x273A0AE0,0x53A2E36F,0xC807AF00,0xBF997DF0,0x243C319F, 
+ 0x56A8AEEE,0xCD0DE281,0xBA933071,0x21367C1E,0x55AE9591,0xCE0BD9FE,0xB9950B0E,0x22304761, + 0x5CBC35EC,0xC7197983,0xB087AB73,0x2B22E71C,0x5FBA0E93,0xC41F42FC,0xB381900C,0x2824DC63, + 0x5AB04312,0xC1150F7D,0xB68BDD8D,0x2D2E91E2,0x59B6786D,0xC2133402,0xB58DE6F2,0x2E28AA9D, + 0x489503E8,0xD3304F87,0xA4AE9D77,0x3F0BD118,0x4B933897,0xD03674F8,0xA7A8A608,0x3C0DEA67, + 0x4E997516,0xD53C3979,0xA2A2EB89,0x3907A7E6,0x4D9F4E69,0xD63A0206,0xA1A4D0F6,0x3A019C99, + 0x448DEE14,0xDF28A27B,0xA8B6708B,0x33133CE4,0x478BD56B,0xDC2E9904,0xABB04BF4,0x3015079B, + 0x428198EA,0xD924D485,0xAEBA0675,0x351F4A1A,0x4187A395,0xDA22EFFA,0xADBC3D0A,0x36197165, + }, + + { + 0x00000000,0xDD96D985,0x605CB54B,0xBDCA6CCE,0xC0B96A96,0x1D2FB313,0xA0E5DFDD,0x7D730658, + 0x5A03D36D,0x87950AE8,0x3A5F6626,0xE7C9BFA3,0x9ABAB9FB,0x472C607E,0xFAE60CB0,0x2770D535, + 0xB407A6DA,0x69917F5F,0xD45B1391,0x09CDCA14,0x74BECC4C,0xA92815C9,0x14E27907,0xC974A082, + 0xEE0475B7,0x3392AC32,0x8E58C0FC,0x53CE1979,0x2EBD1F21,0xF32BC6A4,0x4EE1AA6A,0x937773EF, + 0xB37E4BF5,0x6EE89270,0xD322FEBE,0x0EB4273B,0x73C72163,0xAE51F8E6,0x139B9428,0xCE0D4DAD, + 0xE97D9898,0x34EB411D,0x89212DD3,0x54B7F456,0x29C4F20E,0xF4522B8B,0x49984745,0x940E9EC0, + 0x0779ED2F,0xDAEF34AA,0x67255864,0xBAB381E1,0xC7C087B9,0x1A565E3C,0xA79C32F2,0x7A0AEB77, + 0x5D7A3E42,0x80ECE7C7,0x3D268B09,0xE0B0528C,0x9DC354D4,0x40558D51,0xFD9FE19F,0x2009381A, + 0xBD8D91AB,0x601B482E,0xDDD124E0,0x0047FD65,0x7D34FB3D,0xA0A222B8,0x1D684E76,0xC0FE97F3, + 0xE78E42C6,0x3A189B43,0x87D2F78D,0x5A442E08,0x27372850,0xFAA1F1D5,0x476B9D1B,0x9AFD449E, + 0x098A3771,0xD41CEEF4,0x69D6823A,0xB4405BBF,0xC9335DE7,0x14A58462,0xA96FE8AC,0x74F93129, + 0x5389E41C,0x8E1F3D99,0x33D55157,0xEE4388D2,0x93308E8A,0x4EA6570F,0xF36C3BC1,0x2EFAE244, + 0x0EF3DA5E,0xD36503DB,0x6EAF6F15,0xB339B690,0xCE4AB0C8,0x13DC694D,0xAE160583,0x7380DC06, + 0x54F00933,0x8966D0B6,0x34ACBC78,0xE93A65FD,0x944963A5,0x49DFBA20,0xF415D6EE,0x29830F6B, + 
0xBAF47C84,0x6762A501,0xDAA8C9CF,0x073E104A,0x7A4D1612,0xA7DBCF97,0x1A11A359,0xC7877ADC, + 0xE0F7AFE9,0x3D61766C,0x80AB1AA2,0x5D3DC327,0x204EC57F,0xFDD81CFA,0x40127034,0x9D84A9B1, + 0xA06A2517,0x7DFCFC92,0xC036905C,0x1DA049D9,0x60D34F81,0xBD459604,0x008FFACA,0xDD19234F, + 0xFA69F67A,0x27FF2FFF,0x9A354331,0x47A39AB4,0x3AD09CEC,0xE7464569,0x5A8C29A7,0x871AF022, + 0x146D83CD,0xC9FB5A48,0x74313686,0xA9A7EF03,0xD4D4E95B,0x094230DE,0xB4885C10,0x691E8595, + 0x4E6E50A0,0x93F88925,0x2E32E5EB,0xF3A43C6E,0x8ED73A36,0x5341E3B3,0xEE8B8F7D,0x331D56F8, + 0x13146EE2,0xCE82B767,0x7348DBA9,0xAEDE022C,0xD3AD0474,0x0E3BDDF1,0xB3F1B13F,0x6E6768BA, + 0x4917BD8F,0x9481640A,0x294B08C4,0xF4DDD141,0x89AED719,0x54380E9C,0xE9F26252,0x3464BBD7, + 0xA713C838,0x7A8511BD,0xC74F7D73,0x1AD9A4F6,0x67AAA2AE,0xBA3C7B2B,0x07F617E5,0xDA60CE60, + 0xFD101B55,0x2086C2D0,0x9D4CAE1E,0x40DA779B,0x3DA971C3,0xE03FA846,0x5DF5C488,0x80631D0D, + 0x1DE7B4BC,0xC0716D39,0x7DBB01F7,0xA02DD872,0xDD5EDE2A,0x00C807AF,0xBD026B61,0x6094B2E4, + 0x47E467D1,0x9A72BE54,0x27B8D29A,0xFA2E0B1F,0x875D0D47,0x5ACBD4C2,0xE701B80C,0x3A976189, + 0xA9E01266,0x7476CBE3,0xC9BCA72D,0x142A7EA8,0x695978F0,0xB4CFA175,0x0905CDBB,0xD493143E, + 0xF3E3C10B,0x2E75188E,0x93BF7440,0x4E29ADC5,0x335AAB9D,0xEECC7218,0x53061ED6,0x8E90C753, + 0xAE99FF49,0x730F26CC,0xCEC54A02,0x13539387,0x6E2095DF,0xB3B64C5A,0x0E7C2094,0xD3EAF911, + 0xF49A2C24,0x290CF5A1,0x94C6996F,0x495040EA,0x342346B2,0xE9B59F37,0x547FF3F9,0x89E92A7C, + 0x1A9E5993,0xC7088016,0x7AC2ECD8,0xA754355D,0xDA273305,0x07B1EA80,0xBA7B864E,0x67ED5FCB, + 0x409D8AFE,0x9D0B537B,0x20C13FB5,0xFD57E630,0x8024E068,0x5DB239ED,0xE0785523,0x3DEE8CA6, + }, + + { + 0x00000000,0x9D0FE176,0xE16EC4AD,0x7C6125DB,0x19AC8F1B,0x84A36E6D,0xF8C24BB6,0x65CDAAC0, + 0x33591E36,0xAE56FF40,0xD237DA9B,0x4F383BED,0x2AF5912D,0xB7FA705B,0xCB9B5580,0x5694B4F6, + 0x66B23C6C,0xFBBDDD1A,0x87DCF8C1,0x1AD319B7,0x7F1EB377,0xE2115201,0x9E7077DA,0x037F96AC, + 
0x55EB225A,0xC8E4C32C,0xB485E6F7,0x298A0781,0x4C47AD41,0xD1484C37,0xAD2969EC,0x3026889A, + 0xCD6478D8,0x506B99AE,0x2C0ABC75,0xB1055D03,0xD4C8F7C3,0x49C716B5,0x35A6336E,0xA8A9D218, + 0xFE3D66EE,0x63328798,0x1F53A243,0x825C4335,0xE791E9F5,0x7A9E0883,0x06FF2D58,0x9BF0CC2E, + 0xABD644B4,0x36D9A5C2,0x4AB88019,0xD7B7616F,0xB27ACBAF,0x2F752AD9,0x53140F02,0xCE1BEE74, + 0x988F5A82,0x0580BBF4,0x79E19E2F,0xE4EE7F59,0x8123D599,0x1C2C34EF,0x604D1134,0xFD42F042, + 0x41B9F7F1,0xDCB61687,0xA0D7335C,0x3DD8D22A,0x581578EA,0xC51A999C,0xB97BBC47,0x24745D31, + 0x72E0E9C7,0xEFEF08B1,0x938E2D6A,0x0E81CC1C,0x6B4C66DC,0xF64387AA,0x8A22A271,0x172D4307, + 0x270BCB9D,0xBA042AEB,0xC6650F30,0x5B6AEE46,0x3EA74486,0xA3A8A5F0,0xDFC9802B,0x42C6615D, + 0x1452D5AB,0x895D34DD,0xF53C1106,0x6833F070,0x0DFE5AB0,0x90F1BBC6,0xEC909E1D,0x719F7F6B, + 0x8CDD8F29,0x11D26E5F,0x6DB34B84,0xF0BCAAF2,0x95710032,0x087EE144,0x741FC49F,0xE91025E9, + 0xBF84911F,0x228B7069,0x5EEA55B2,0xC3E5B4C4,0xA6281E04,0x3B27FF72,0x4746DAA9,0xDA493BDF, + 0xEA6FB345,0x77605233,0x0B0177E8,0x960E969E,0xF3C33C5E,0x6ECCDD28,0x12ADF8F3,0x8FA21985, + 0xD936AD73,0x44394C05,0x385869DE,0xA55788A8,0xC09A2268,0x5D95C31E,0x21F4E6C5,0xBCFB07B3, + 0x8373EFE2,0x1E7C0E94,0x621D2B4F,0xFF12CA39,0x9ADF60F9,0x07D0818F,0x7BB1A454,0xE6BE4522, + 0xB02AF1D4,0x2D2510A2,0x51443579,0xCC4BD40F,0xA9867ECF,0x34899FB9,0x48E8BA62,0xD5E75B14, + 0xE5C1D38E,0x78CE32F8,0x04AF1723,0x99A0F655,0xFC6D5C95,0x6162BDE3,0x1D039838,0x800C794E, + 0xD698CDB8,0x4B972CCE,0x37F60915,0xAAF9E863,0xCF3442A3,0x523BA3D5,0x2E5A860E,0xB3556778, + 0x4E17973A,0xD318764C,0xAF795397,0x3276B2E1,0x57BB1821,0xCAB4F957,0xB6D5DC8C,0x2BDA3DFA, + 0x7D4E890C,0xE041687A,0x9C204DA1,0x012FACD7,0x64E20617,0xF9EDE761,0x858CC2BA,0x188323CC, + 0x28A5AB56,0xB5AA4A20,0xC9CB6FFB,0x54C48E8D,0x3109244D,0xAC06C53B,0xD067E0E0,0x4D680196, + 0x1BFCB560,0x86F35416,0xFA9271CD,0x679D90BB,0x02503A7B,0x9F5FDB0D,0xE33EFED6,0x7E311FA0, + 0xC2CA1813,0x5FC5F965,0x23A4DCBE,0xBEAB3DC8,0xDB669708,0x4669767E,0x3A0853A5,0xA707B2D3, 
+ 0xF1930625,0x6C9CE753,0x10FDC288,0x8DF223FE,0xE83F893E,0x75306848,0x09514D93,0x945EACE5, + 0xA478247F,0x3977C509,0x4516E0D2,0xD81901A4,0xBDD4AB64,0x20DB4A12,0x5CBA6FC9,0xC1B58EBF, + 0x97213A49,0x0A2EDB3F,0x764FFEE4,0xEB401F92,0x8E8DB552,0x13825424,0x6FE371FF,0xF2EC9089, + 0x0FAE60CB,0x92A181BD,0xEEC0A466,0x73CF4510,0x1602EFD0,0x8B0D0EA6,0xF76C2B7D,0x6A63CA0B, + 0x3CF77EFD,0xA1F89F8B,0xDD99BA50,0x40965B26,0x255BF1E6,0xB8541090,0xC435354B,0x593AD43D, + 0x691C5CA7,0xF413BDD1,0x8872980A,0x157D797C,0x70B0D3BC,0xEDBF32CA,0x91DE1711,0x0CD1F667, + 0x5A454291,0xC74AA3E7,0xBB2B863C,0x2624674A,0x43E9CD8A,0xDEE62CFC,0xA2870927,0x3F88E851, + }, + + { + 0x00000000,0xB9FBDBE8,0xA886B191,0x117D6A79,0x8A7C6563,0x3387BE8B,0x22FAD4F2,0x9B010F1A, + 0xCF89CC87,0x7672176F,0x670F7D16,0xDEF4A6FE,0x45F5A9E4,0xFC0E720C,0xED731875,0x5488C39D, + 0x44629F4F,0xFD9944A7,0xECE42EDE,0x551FF536,0xCE1EFA2C,0x77E521C4,0x66984BBD,0xDF639055, + 0x8BEB53C8,0x32108820,0x236DE259,0x9A9639B1,0x019736AB,0xB86CED43,0xA911873A,0x10EA5CD2, + 0x88C53E9E,0x313EE576,0x20438F0F,0x99B854E7,0x02B95BFD,0xBB428015,0xAA3FEA6C,0x13C43184, + 0x474CF219,0xFEB729F1,0xEFCA4388,0x56319860,0xCD30977A,0x74CB4C92,0x65B626EB,0xDC4DFD03, + 0xCCA7A1D1,0x755C7A39,0x64211040,0xDDDACBA8,0x46DBC4B2,0xFF201F5A,0xEE5D7523,0x57A6AECB, + 0x032E6D56,0xBAD5B6BE,0xABA8DCC7,0x1253072F,0x89520835,0x30A9D3DD,0x21D4B9A4,0x982F624C, + 0xCAFB7B7D,0x7300A095,0x627DCAEC,0xDB861104,0x40871E1E,0xF97CC5F6,0xE801AF8F,0x51FA7467, + 0x0572B7FA,0xBC896C12,0xADF4066B,0x140FDD83,0x8F0ED299,0x36F50971,0x27886308,0x9E73B8E0, + 0x8E99E432,0x37623FDA,0x261F55A3,0x9FE48E4B,0x04E58151,0xBD1E5AB9,0xAC6330C0,0x1598EB28, + 0x411028B5,0xF8EBF35D,0xE9969924,0x506D42CC,0xCB6C4DD6,0x7297963E,0x63EAFC47,0xDA1127AF, + 0x423E45E3,0xFBC59E0B,0xEAB8F472,0x53432F9A,0xC8422080,0x71B9FB68,0x60C49111,0xD93F4AF9, + 0x8DB78964,0x344C528C,0x253138F5,0x9CCAE31D,0x07CBEC07,0xBE3037EF,0xAF4D5D96,0x16B6867E, + 
0x065CDAAC,0xBFA70144,0xAEDA6B3D,0x1721B0D5,0x8C20BFCF,0x35DB6427,0x24A60E5E,0x9D5DD5B6, + 0xC9D5162B,0x702ECDC3,0x6153A7BA,0xD8A87C52,0x43A97348,0xFA52A8A0,0xEB2FC2D9,0x52D41931, + 0x4E87F0BB,0xF77C2B53,0xE601412A,0x5FFA9AC2,0xC4FB95D8,0x7D004E30,0x6C7D2449,0xD586FFA1, + 0x810E3C3C,0x38F5E7D4,0x29888DAD,0x90735645,0x0B72595F,0xB28982B7,0xA3F4E8CE,0x1A0F3326, + 0x0AE56FF4,0xB31EB41C,0xA263DE65,0x1B98058D,0x80990A97,0x3962D17F,0x281FBB06,0x91E460EE, + 0xC56CA373,0x7C97789B,0x6DEA12E2,0xD411C90A,0x4F10C610,0xF6EB1DF8,0xE7967781,0x5E6DAC69, + 0xC642CE25,0x7FB915CD,0x6EC47FB4,0xD73FA45C,0x4C3EAB46,0xF5C570AE,0xE4B81AD7,0x5D43C13F, + 0x09CB02A2,0xB030D94A,0xA14DB333,0x18B668DB,0x83B767C1,0x3A4CBC29,0x2B31D650,0x92CA0DB8, + 0x8220516A,0x3BDB8A82,0x2AA6E0FB,0x935D3B13,0x085C3409,0xB1A7EFE1,0xA0DA8598,0x19215E70, + 0x4DA99DED,0xF4524605,0xE52F2C7C,0x5CD4F794,0xC7D5F88E,0x7E2E2366,0x6F53491F,0xD6A892F7, + 0x847C8BC6,0x3D87502E,0x2CFA3A57,0x9501E1BF,0x0E00EEA5,0xB7FB354D,0xA6865F34,0x1F7D84DC, + 0x4BF54741,0xF20E9CA9,0xE373F6D0,0x5A882D38,0xC1892222,0x7872F9CA,0x690F93B3,0xD0F4485B, + 0xC01E1489,0x79E5CF61,0x6898A518,0xD1637EF0,0x4A6271EA,0xF399AA02,0xE2E4C07B,0x5B1F1B93, + 0x0F97D80E,0xB66C03E6,0xA711699F,0x1EEAB277,0x85EBBD6D,0x3C106685,0x2D6D0CFC,0x9496D714, + 0x0CB9B558,0xB5426EB0,0xA43F04C9,0x1DC4DF21,0x86C5D03B,0x3F3E0BD3,0x2E4361AA,0x97B8BA42, + 0xC33079DF,0x7ACBA237,0x6BB6C84E,0xD24D13A6,0x494C1CBC,0xF0B7C754,0xE1CAAD2D,0x583176C5, + 0x48DB2A17,0xF120F1FF,0xE05D9B86,0x59A6406E,0xC2A74F74,0x7B5C949C,0x6A21FEE5,0xD3DA250D, + 0x8752E690,0x3EA93D78,0x2FD45701,0x962F8CE9,0x0D2E83F3,0xB4D5581B,0xA5A83262,0x1C53E98A, + }, + + { + 0x00000000,0xAE689191,0x87A02563,0x29C8B4F2,0xD4314C87,0x7A59DD16,0x539169E4,0xFDF9F875, + 0x73139F4F,0xDD7B0EDE,0xF4B3BA2C,0x5ADB2BBD,0xA722D3C8,0x094A4259,0x2082F6AB,0x8EEA673A, + 0xE6273E9E,0x484FAF0F,0x61871BFD,0xCFEF8A6C,0x32167219,0x9C7EE388,0xB5B6577A,0x1BDEC6EB, + 
0x9534A1D1,0x3B5C3040,0x129484B2,0xBCFC1523,0x4105ED56,0xEF6D7CC7,0xC6A5C835,0x68CD59A4, + 0x173F7B7D,0xB957EAEC,0x909F5E1E,0x3EF7CF8F,0xC30E37FA,0x6D66A66B,0x44AE1299,0xEAC68308, + 0x642CE432,0xCA4475A3,0xE38CC151,0x4DE450C0,0xB01DA8B5,0x1E753924,0x37BD8DD6,0x99D51C47, + 0xF11845E3,0x5F70D472,0x76B86080,0xD8D0F111,0x25290964,0x8B4198F5,0xA2892C07,0x0CE1BD96, + 0x820BDAAC,0x2C634B3D,0x05ABFFCF,0xABC36E5E,0x563A962B,0xF85207BA,0xD19AB348,0x7FF222D9, + 0x2E7EF6FA,0x8016676B,0xA9DED399,0x07B64208,0xFA4FBA7D,0x54272BEC,0x7DEF9F1E,0xD3870E8F, + 0x5D6D69B5,0xF305F824,0xDACD4CD6,0x74A5DD47,0x895C2532,0x2734B4A3,0x0EFC0051,0xA09491C0, + 0xC859C864,0x663159F5,0x4FF9ED07,0xE1917C96,0x1C6884E3,0xB2001572,0x9BC8A180,0x35A03011, + 0xBB4A572B,0x1522C6BA,0x3CEA7248,0x9282E3D9,0x6F7B1BAC,0xC1138A3D,0xE8DB3ECF,0x46B3AF5E, + 0x39418D87,0x97291C16,0xBEE1A8E4,0x10893975,0xED70C100,0x43185091,0x6AD0E463,0xC4B875F2, + 0x4A5212C8,0xE43A8359,0xCDF237AB,0x639AA63A,0x9E635E4F,0x300BCFDE,0x19C37B2C,0xB7ABEABD, + 0xDF66B319,0x710E2288,0x58C6967A,0xF6AE07EB,0x0B57FF9E,0xA53F6E0F,0x8CF7DAFD,0x229F4B6C, + 0xAC752C56,0x021DBDC7,0x2BD50935,0x85BD98A4,0x784460D1,0xD62CF140,0xFFE445B2,0x518CD423, + 0x5CFDEDF4,0xF2957C65,0xDB5DC897,0x75355906,0x88CCA173,0x26A430E2,0x0F6C8410,0xA1041581, + 0x2FEE72BB,0x8186E32A,0xA84E57D8,0x0626C649,0xFBDF3E3C,0x55B7AFAD,0x7C7F1B5F,0xD2178ACE, + 0xBADAD36A,0x14B242FB,0x3D7AF609,0x93126798,0x6EEB9FED,0xC0830E7C,0xE94BBA8E,0x47232B1F, + 0xC9C94C25,0x67A1DDB4,0x4E696946,0xE001F8D7,0x1DF800A2,0xB3909133,0x9A5825C1,0x3430B450, + 0x4BC29689,0xE5AA0718,0xCC62B3EA,0x620A227B,0x9FF3DA0E,0x319B4B9F,0x1853FF6D,0xB63B6EFC, + 0x38D109C6,0x96B99857,0xBF712CA5,0x1119BD34,0xECE04541,0x4288D4D0,0x6B406022,0xC528F1B3, + 0xADE5A817,0x038D3986,0x2A458D74,0x842D1CE5,0x79D4E490,0xD7BC7501,0xFE74C1F3,0x501C5062, + 0xDEF63758,0x709EA6C9,0x5956123B,0xF73E83AA,0x0AC77BDF,0xA4AFEA4E,0x8D675EBC,0x230FCF2D, + 0x72831B0E,0xDCEB8A9F,0xF5233E6D,0x5B4BAFFC,0xA6B25789,0x08DAC618,0x211272EA,0x8F7AE37B, 
+ 0x01908441,0xAFF815D0,0x8630A122,0x285830B3,0xD5A1C8C6,0x7BC95957,0x5201EDA5,0xFC697C34, + 0x94A42590,0x3ACCB401,0x130400F3,0xBD6C9162,0x40956917,0xEEFDF886,0xC7354C74,0x695DDDE5, + 0xE7B7BADF,0x49DF2B4E,0x60179FBC,0xCE7F0E2D,0x3386F658,0x9DEE67C9,0xB426D33B,0x1A4E42AA, + 0x65BC6073,0xCBD4F1E2,0xE21C4510,0x4C74D481,0xB18D2CF4,0x1FE5BD65,0x362D0997,0x98459806, + 0x16AFFF3C,0xB8C76EAD,0x910FDA5F,0x3F674BCE,0xC29EB3BB,0x6CF6222A,0x453E96D8,0xEB560749, + 0x839B5EED,0x2DF3CF7C,0x043B7B8E,0xAA53EA1F,0x57AA126A,0xF9C283FB,0xD00A3709,0x7E62A698, + 0xF088C1A2,0x5EE05033,0x7728E4C1,0xD9407550,0x24B98D25,0x8AD11CB4,0xA319A846,0x0D7139D7, + } +#endif // CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +}; +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/file_adapter.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/file_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..23307abc904d0cd3235cd36559ce4b4c71916ee6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/file_adapter.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include + +#include "caffe2/serialize/istream_adapter.h" +#include "caffe2/serialize/read_adapter_interface.h" + +namespace caffe2 { +namespace serialize { + +class TORCH_API FileAdapter final : public ReadAdapterInterface { + public: + C10_DISABLE_COPY_AND_ASSIGN(FileAdapter); + explicit FileAdapter(const std::string& file_name); + size_t size() const override; + size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const override; + ~FileAdapter() override; + + private: + // An RAII Wrapper for a FILE pointer. Closes on destruction. 
+ struct RAIIFile { + FILE* fp_; + explicit RAIIFile(const std::string& file_name); + ~RAIIFile(); + }; + + RAIIFile file_; + // The size of the opened file in bytes + uint64_t size_; +}; + +} // namespace serialize +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/in_memory_adapter.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/in_memory_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..f9f5619ac3bc54205c79426d02e4cdf683eda09a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/in_memory_adapter.h @@ -0,0 +1,30 @@ +#pragma once +#include +#include + +namespace caffe2 { +namespace serialize { + +class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { + public: + explicit MemoryReadAdapter(const void* data, off_t size) + : data_(data), size_(size) {} + + size_t size() const override { + return size_; + } + + size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const override { + (void)what; + memcpy(buf, (int8_t*)(data_) + pos, n); + return n; + } + + private: + const void* data_; + off_t size_; +}; + +} // namespace serialize +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/inline_container.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/inline_container.h new file mode 100644 index 0000000000000000000000000000000000000000..47bd7886dc93e633a7b73b9f498e1a0ded5d5244 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/inline_container.h @@ -0,0 +1,310 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "caffe2/serialize/istream_adapter.h" +#include "caffe2/serialize/read_adapter_interface.h" +#include "caffe2/serialize/versions.h" + +extern "C" { +typedef struct mz_zip_archive mz_zip_archive; +} + +// 
PyTorch containers are a special zip archive with the following layout +// archive_name.zip contains: +// archive_name/ +// version # a file with a single decimal number written in ascii, +// # used to establish the version of the archive format +// model.json # overall model description, this is a json output of +// # ModelDef from torch.proto +// # the following names are by convention only, model.json will +// # refer to these files by full names +// tensors/ +// 0 # flat storage for tensor data, meta-data about shapes, etc. is +// # in model.json +// 1 +// ... +// # code entries will only exist for modules that have methods attached +// code/ +// archive_name.py # serialized torch script code (python syntax, using +// PythonPrint) archive_name_my_submodule.py # submodules have separate +// files +// +// The PyTorchStreamWriter also ensures additional useful properties for these +// files +// 1. All files are stored uncompressed. +// 2. All files in the archive are aligned to 64 byte boundaries such that +// it is possible to mmap the entire file and get an aligned pointer to +// tensor data. +// 3. We universally write in ZIP64 format for consistency. + +// The PyTorchStreamReader also provides additional properties: +// 1. It can read zip files that are created with common +// zip tools. This means that even though our writer doesn't compress files, +// the reader can still read files that were compressed. +// 2. It provides a getRecordOffset function which returns the offset into the +// raw file where file data lives. If the file was written with +// PyTorchStreamWriter it is guaranteed to be 64 byte aligned. + +// PyTorchReader/Writer handle checking the version number on the archive format +// and ensure that all files are written to a archive_name directory so they +// unzip cleanly. 
+ +// When developing this format we want to pay particular attention to the +// following use cases: +// +// -- Reading -- +// 1) Reading with full random access +// a) Reading with file api's such as fread() +// b) mmaping the file and jumping around the mapped region +// 2) Reading with 1-pass sequential access +// -> A reader will need to build up a data structure of parsed structures +// as it reads +// +// -- Writing -- +// 1) Writing with full random access +// 2) Writing with 1-pass sequential access +// -> We must take care not to require updating values that have already +// been written. We place the variable-length index at the end and do +// not put any indicies into the header to fulfill this constraint. + +// The model.json, which contains all the metadata information, +// should be written as the last file. One reason is that the size of tensor +// data is usually stable. As long as the shape and type of the tensor do not +// change, the size of the data won't change. On the other sied, the size of the +// serialized model is likely to change, so we store it as the last record, and +// we don't need to move previous records when updating the model data. + +// The zip format is sufficiently flexible to handle the above use-case. +// it puts its central directory at the end of the archive and we write +// model.json as the last file when writing after we have accumulated all +// other information. + +namespace caffe2 { +namespace serialize { + +static constexpr const char* kSerializationIdRecordName = + ".data/serialization_id"; + +struct MzZipReaderIterWrapper; + +class TORCH_API ChunkRecordIterator { + public: + ~ChunkRecordIterator(); + + // Read at most `chunkSize` into `buf`. Return the number of actual bytes + // read. 
+ size_t next(void* buf); + size_t recordSize() const { + return recordSize_; + } + + private: + ChunkRecordIterator( + size_t recordSize, + size_t chunkSize, + std::unique_ptr iter); + + const size_t recordSize_; + const size_t chunkSize_; + size_t offset_; + std::unique_ptr iter_; + + friend class PyTorchStreamReader; +}; + +class TORCH_API PyTorchStreamReader final { + public: + explicit PyTorchStreamReader(const std::string& file_name); + explicit PyTorchStreamReader(std::istream* in); + explicit PyTorchStreamReader(std::shared_ptr in); + + // return dataptr, size + // set allocator to override default cpu allocator + std::tuple getRecord( + const std::string& name, + std::optional allocator = std::nullopt); + // multi-thread getRecord + std::tuple getRecord( + const std::string& name, + std::vector>& additionalReaders, + std::optional allocator = std::nullopt); + // inplace memory writing + size_t getRecord(const std::string& name, void* dst, size_t n); + // inplace memory writing, multi-threads. + // When additionalReaders is empty, the default behavior is call + // getRecord(name, dst, n) with default reader This approach can be used for + // reading large tensors. + size_t getRecord( + const std::string& name, + void* dst, + size_t n, + std::vector>& additionalReaders); + size_t getRecord( + const std::string& name, + void* dst, + size_t n, + size_t chunk_size, + void* buf, + const std::function& memcpy_func = + nullptr); + + // Concurrent reading records with multiple readers. + // additionalReaders are additional clients to access the underlying record at + // different offsets and write to different trunks of buffers. If the overall + // size of the tensor is 10, and size of additionalReader is 2. The default + // thread will read [0,4), the additional reader will read [4,8). The default + // reader will read [8,10). 
The default reader will write to buffer[0,4), the + // additional reader will write to buffer[4,8), the additional reader will + // write to buffer[8,10). When additionalReaders is empty, the default + // behavior is call getRecord(name) with default reader This approach can be + // used for reading large tensors. + size_t getRecordMultiReaders( + const std::string& name, + std::vector>& additionalReaders, + void* dst, + size_t n); + + size_t getRecordSize(const std::string& name); + size_t getRecordHeaderOffset(const std::string& name); + size_t getRecordOffset(const std::string& name); + size_t getRecordOffsetNoRead( + size_t cursor, + std::string filename, + size_t size, + uint64_t alignment); + bool hasRecord(const std::string& name); + std::vector getAllRecords(); + + ChunkRecordIterator createChunkReaderIter( + const std::string& name, + const size_t recordSize, + const size_t chunkSize); + + ~PyTorchStreamReader(); + uint64_t version() const { + return version_; + } + const std::string& serializationId() { + return serialization_id_; + } + + void setShouldLoadDebugSymbol(bool should_load_debug_symbol) { + load_debug_symbol_ = should_load_debug_symbol; + } + void setAdditionalReaderSizeThreshold(const size_t& size) { + additional_reader_size_threshold_ = size; + } + + private: + void init(); + size_t read(uint64_t pos, char* buf, size_t n); + void valid(const char* what, const char* info = ""); + size_t getRecordID(const std::string& name); + + friend size_t + istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n); + std::unique_ptr ar_; + std::string archive_name_; + std::string archive_name_plus_slash_; + std::shared_ptr in_; + int64_t version_; + std::mutex reader_lock_; + bool load_debug_symbol_ = true; + std::string serialization_id_; + size_t additional_reader_size_threshold_; +}; + +class TORCH_API PyTorchStreamWriter final { + public: + explicit PyTorchStreamWriter( + const std::string& archive_name, + bool compute_crc32 = true, + 
uint64_t alignment = 64); + explicit PyTorchStreamWriter( + const std::function writer_func, + bool compute_crc32 = true, + uint64_t alignment = 64); + + void setMinVersion(const uint64_t version); + + void writeRecord( + const std::string& name, + const void* data, + size_t size, + bool compress = false); + void writeEndOfFile(); + + const std::unordered_set& getAllWrittenRecords(); + + bool finalized() const { + return finalized_; + } + + const std::string& archiveName() { + return archive_name_; + } + + const std::string& serializationId() { + return serialization_id_; + } + + ~PyTorchStreamWriter(); + + private: + void setup(const std::string& file_name); + void valid(const char* what, const char* info = ""); + void writeSerializationId(); + size_t current_pos_ = 0; + std::unordered_set files_written_; + std::unique_ptr ar_; + std::string archive_name_; + std::string archive_name_plus_slash_; + std::string padding_; + std::ofstream file_stream_; + std::function writer_func_; + uint64_t combined_uncomp_crc32_ = 0; + std::string serialization_id_; + bool compute_crc32_; + uint64_t alignment_; + + // This number will be updated when the model has operators + // that have valid upgraders. + uint64_t version_ = kMinProducedFileFormatVersion; + bool finalized_ = false; + bool err_seen_ = false; + friend size_t ostream_write_func( + void* pOpaque, + uint64_t file_ofs, + const void* pBuf, + size_t n); +}; + +namespace detail { + +// Returns a record to be appended to the local user extra data entry in order +// to make data beginning aligned at kFieldAlignment bytes boundary. 
+size_t getPadding( + size_t cursor, + size_t filename_size, + size_t size, + std::string& padding_buf, + uint64_t alignment); + +std::tuple +getOffset(size_t cursor, size_t filename_size, size_t size, uint64_t alignment); + +} // namespace detail + +} // namespace serialize +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/istream_adapter.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/istream_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..680c288a15f2e89d5a3b8b51d8aaddf44f727d19 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/istream_adapter.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include "c10/macros/Macros.h" +#include "caffe2/serialize/read_adapter_interface.h" + +namespace caffe2 { +namespace serialize { + +// this is a reader implemented by std::istream +class TORCH_API IStreamAdapter final : public ReadAdapterInterface { + public: + C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter); + explicit IStreamAdapter(std::istream* istream); + size_t size() const override; + size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const override; + ~IStreamAdapter() override; + + private: + std::istream* istream_; + void validate(const char* what) const; +}; + +} // namespace serialize +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/read_adapter_interface.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/read_adapter_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..0a6b5b74a762e7c90b64a6f93717802f658dd164 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/read_adapter_interface.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +#include "c10/macros/Macros.h" + +namespace caffe2 { +namespace serialize { + +// this is the interface for the (file/stream/memory) 
reader in +// PyTorchStreamReader. with this interface, we can extend the support +// besides standard istream +class TORCH_API ReadAdapterInterface { + public: + virtual size_t size() const = 0; + virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const = 0; + virtual ~ReadAdapterInterface(); +}; + +} // namespace serialize +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/versions.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/versions.h new file mode 100644 index 0000000000000000000000000000000000000000..6e2c27adc8fae8fdb07dcd9dafe72c503c243c2b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/serialize/versions.h @@ -0,0 +1,133 @@ +#pragma once +#include + +namespace caffe2 { +namespace serialize { + +constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; + +constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL; + +// Versions (i.e. why was the version number bumped?) + +// Note [Dynamic Versions and torch.jit.save vs. torch.save] +// +// Our versioning scheme has a "produced file format version" which +// describes how an archive is to be read. The version written in an archive +// is at least this current produced file format version, but may be greater +// if it includes certain symbols. We refer to these conditional versions +// as "dynamic," since they are identified at runtime. +// +// Dynamic versioning is useful when an operator's semantics are updated. +// When using torch.jit.save we want those semantics to be preserved. If +// we bumped the produced file format version on every change, however, +// then older versions of PyTorch couldn't read even simple archives, like +// a single tensor, from newer versions of PyTorch. Instead, we +// assign dynamic versions to these changes that override the +// produced file format version as needed. 
That is, when the semantics +// of torch.div changed it was assigned dynamic version 4, and when +// torch.jit.saving modules that use torch.div those archives also have +// (at least) version 4. This prevents earlier versions of PyTorch +// from accidentally performing the wrong kind of division. Modules +// that don't use torch.div or other operators with dynamic versions +// can write the produced file format version, and these programs will +// run as expected on earlier versions of PyTorch. +// +// While torch.jit.save attempts to preserve operator semantics, +// torch.save does not. torch.save is analogous to pickling Python, so +// a function that uses torch.div will have different behavior if torch.saved +// and torch.loaded across PyTorch versions. From a technical perspective, +// torch.save ignores dynamic versioning. + +// 1. Initial version +// 2. Removed op_version_set version numbers +// 3. Added type tags to pickle serialization of container types +// 4. (Dynamic) Stopped integer division using torch.div +// (a versioned symbol preserves the historic behavior of versions 1--3) +// 5. (Dynamic) Stops torch.full inferring a floating point dtype +// when given bool or integer fill values. +// 6. Write version string to `./data/version` instead of `version`. + +// [12/15/2021] +// kProducedFileFormatVersion is set to 7 from 3 due to a different +// interpretation of what file format version is. +// Whenever there is new upgrader introduced, +// this number should be bumped. +// The reasons that version is bumped in the past: +// 1. aten::div is changed at version 4 +// 2. aten::full is changed at version 5 +// 3. torch.package uses version 6 +// 4. 
Introduce new upgrader design and set the version number to 7 +// mark this change +// -------------------------------------------------- +// We describe new operator version bump reasons here: +// 1) [01/24/2022] +// We bump the version number to 8 to update aten::linspace +// and aten::linspace.out to error out when steps is not +// provided. (see: https://github.com/pytorch/pytorch/issues/55951) +// 2) [01/30/2022] +// Bump the version number to 9 to update aten::logspace and +// and aten::logspace.out to error out when steps is not +// provided. (see: https://github.com/pytorch/pytorch/issues/55951) +// 3) [02/11/2022] +// Bump the version number to 10 to update aten::gelu and +// and aten::gelu.out to support the new approximate kwarg. +// (see: https://github.com/pytorch/pytorch/pull/61439) +constexpr uint64_t kProducedFileFormatVersion = 0xAL; + +// Absolute minimum version we will write packages. This +// means that every package from now on will always be +// greater than this number. +constexpr uint64_t kMinProducedFileFormatVersion = 0x3L; + +// The version we write when the archive contains bytecode. +// It must be higher or eq to kProducedFileFormatVersion. +// Because torchscript changes is likely introduce bytecode change. +// If kProducedFileFormatVersion is increased, kProducedBytecodeVersion +// should be increased too. The relationship is: +// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion +// >= kProducedFileFormatVersion +// If a format change is forward compatible (still readable by older +// executables), we will not increment the version number, to minimize the +// risk of breaking existing clients. TODO: A better way would be to allow +// the caller that creates a model to specify a maximum version that its +// clients can accept. +// Versions: +// 0x1L: Initial version +// 0x2L: (Comment missing) +// 0x3L: (Comment missing) +// 0x4L: (update) Added schema to function tuple. Forward-compatible change. 
+// 0x5L: (update) Update bytecode is sharing constant tensor files from +// torchscript, and only serialize extra tensors that are not in the +// torchscript constant table. Also update tensor storage schema adapting to +// the unify format, the root key of tensor storage is updated from {index} to +// {the_pointer_value_the_tensor.storage}, for example: +// `140245072983168.storage` Forward-compatibility change. +// 0x6L: Implicit opereator versioning using number of specified argument. +// Refer to the summary of https://github.com/pytorch/pytorch/pull/56845 for +// details. +// 0x7L: Enable support for operators with default arguments plus out +// arguments. Refer. See https://github.com/pytorch/pytorch/pull/63651 for +// details. +// 0x8L: Emit promoted operators as instructions. See +// https://github.com/pytorch/pytorch/pull/71662 for details. +// 0x9L: Change serialization format from pickle to format This version is to +// serve migration. v8 pickle and v9 flatbuffer are the same. Refer to the +// summary of https://github.com/pytorch/pytorch/pull/75201 for more details. +constexpr uint64_t kProducedBytecodeVersion = 0x8L; + +// static_assert( +// kProducedBytecodeVersion >= kProducedFileFormatVersion, +// "kProducedBytecodeVersion must be higher or equal to +// kProducedFileFormatVersion."); + +// Introduce kMinSupportedBytecodeVersion and kMaxSupportedBytecodeVersion +// for limited backward/forward compatibility support of bytecode. If +// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion +// (in loader), we should support this model_version. For example, we provide a +// wrapper to handle an updated operator. 
+constexpr uint64_t kMinSupportedBytecodeVersion = 0x4L; +constexpr uint64_t kMaxSupportedBytecodeVersion = 0x9L; + +} // namespace serialize +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/fixed_divisor.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/fixed_divisor.h new file mode 100644 index 0000000000000000000000000000000000000000..c8607c12c76dff430c6785267dae768a79ba8f4e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/fixed_divisor.h @@ -0,0 +1,132 @@ +#ifndef CAFFE2_UTILS_FIXED_DIVISOR_H_ +#define CAFFE2_UTILS_FIXED_DIVISOR_H_ + +#include +#include +#include + +// See Note [hip-clang differences to hcc] + +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) || defined(__HIP__) || \ + (defined(__clang__) && defined(__CUDA__)) +#define FIXED_DIVISOR_DECL inline __host__ __device__ +#else +#define FIXED_DIVISOR_DECL inline +#endif + +namespace caffe2 { + +// Utility class for quickly calculating quotients and remainders for +// a known integer divisor +template +class FixedDivisor {}; + +// Works for any positive divisor, 1 to INT_MAX. One 64-bit +// multiplication and one 64-bit shift is used to calculate the +// result. +template <> +class FixedDivisor { + public: + FixedDivisor() = default; + + explicit FixedDivisor(const std::int32_t d) : d_(d) { +#if !defined(USE_ROCM) + CalcSignedMagic(); +#endif // USE_ROCM + } + + FIXED_DIVISOR_DECL std::int32_t d() const { + return d_; + } + +#if !defined(USE_ROCM) + FIXED_DIVISOR_DECL std::uint64_t magic() const { + return magic_; + } + + FIXED_DIVISOR_DECL int shift() const { + return shift_; + } +#endif // USE_ROCM + + /// Calculates `q = n / d`. 
+ FIXED_DIVISOR_DECL std::int32_t Div(const std::int32_t n) const { +#if defined(USE_ROCM) + return n / d_; +#else // USE_ROCM + // In lieu of a mulhi instruction being available, perform the + // work in uint64 + return (int32_t)((magic_ * (uint64_t)n) >> shift_); +#endif // USE_ROCM + } + + /// Calculates `r = n % d`. + FIXED_DIVISOR_DECL std::int32_t Mod(const std::int32_t n) const { + return n - d_ * Div(n); + } + + /// Calculates `q = n / d` and `r = n % d` together. + FIXED_DIVISOR_DECL void + DivMod(const std::int32_t n, std::int32_t* q, int32_t* r) const { + *q = Div(n); + *r = n - d_ * *q; + } + + private: +#if !defined(USE_ROCM) + // Calculates magic multiplicative value and shift amount for calculating `q = + // n / d` for signed 32-bit integers. + // Implementation taken from Hacker's Delight section 10. + void CalcSignedMagic() { + if (d_ == 1) { + magic_ = UINT64_C(0x1) << 32; + shift_ = 32; + return; + } + + const std::uint32_t two31 = UINT32_C(0x80000000); + const std::uint32_t ad = std::abs(d_); + const std::uint32_t t = two31 + ((uint32_t)d_ >> 31); + const std::uint32_t anc = t - 1 - t % ad; // Absolute value of nc. + std::uint32_t p = 31; // Init. p. + std::uint32_t q1 = two31 / anc; // Init. q1 = 2**p/|nc|. + std::uint32_t r1 = two31 - q1 * anc; // Init. r1 = rem(2**p, |nc|). + std::uint32_t q2 = two31 / ad; // Init. q2 = 2**p/|d|. + std::uint32_t r2 = two31 - q2 * ad; // Init. r2 = rem(2**p, |d|). + std::uint32_t delta = 0; + do { + ++p; + q1 <<= 1; // Update q1 = 2**p/|nc|. + r1 <<= 1; // Update r1 = rem(2**p, |nc|). + if (r1 >= anc) { // (Must be an unsigned + ++q1; // comparison here). + r1 -= anc; + } + q2 <<= 1; // Update q2 = 2**p/|d|. + r2 <<= 1; // Update r2 = rem(2**p, |d|). + if (r2 >= ad) { // (Must be an unsigned + ++q2; // comparison here). 
+ r2 -= ad; + } + delta = ad - r2; + } while (q1 < delta || (q1 == delta && r1 == 0)); + std::int32_t magic = q2 + 1; + if (d_ < 0) { + magic = -magic; + } + shift_ = p; + magic_ = (std::uint64_t)(std::uint32_t)magic; + } +#endif // USE_ROCM + + std::int32_t d_ = 1; + +#if !defined(USE_ROCM) + std::uint64_t magic_; + int shift_; +#endif // USE_ROCM +}; + +} // namespace caffe2 + +#endif // CAFFE2_UTILS_FIXED_DIVISOR_H_ diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/proto_wrap.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/proto_wrap.h new file mode 100644 index 0000000000000000000000000000000000000000..4ce0ce5d3b8bfb40d3943354730d484ad7ae562a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/proto_wrap.h @@ -0,0 +1,37 @@ +#ifndef CAFFE2_UTILS_PROTO_WRAP_H_ +#define CAFFE2_UTILS_PROTO_WRAP_H_ + +#include + +namespace caffe2 { + +// A wrapper function to shut down protobuf library (this is needed in ASAN +// testing and valgrind cases to avoid protobuf appearing to "leak" memory). +TORCH_API void ShutdownProtobufLibrary(); + +// Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() +// function used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); +} // namespace caffe2 + +namespace ONNX_NAMESPACE { + +// ONNX wrapper functions for protobuf's GetEmptyStringAlreadyInited() function +// used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); + +} // namespace ONNX_NAMESPACE + +namespace torch { + +// Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() +// function used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. 
+TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); + +void ShutdownProtobufLibrary(); + +} // namespace torch +#endif // CAFFE2_UTILS_PROTO_WRAP_H_ diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/string_utils.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/string_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..6bf3b867b5434eb71d7f972fbd5db41eafac5957 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/string_utils.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace caffe2 { + +TORCH_API std::vector +split(char separator, const std::string& string, bool ignore_empty = false); + +TORCH_API std::string trim(const std::string& str); + +TORCH_API size_t editDistance( + const std::string& s1, + const std::string& s2, + size_t max_distance = 0); + +TORCH_API inline bool StartsWith( + const std::string& str, + const std::string& prefix) { + return str.length() >= prefix.length() && + std::mismatch(prefix.begin(), prefix.end(), str.begin()).first == + prefix.end(); +} + +TORCH_API inline bool EndsWith( + const std::string& full, + const std::string& ending) { + if (full.length() >= ending.length()) { + return ( + 0 == + full.compare(full.length() - ending.length(), ending.length(), ending)); + } else { + return false; + } +} + +TORCH_API int32_t editDistanceHelper( + const char* s1, + size_t s1_len, + const char* s2, + size_t s2_len, + std::vector& current, + std::vector& previous, + std::vector& previous1, + size_t max_distance); +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPool.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPool.h new file mode 100644 index 0000000000000000000000000000000000000000..48eb7610eb6cf0d6e717a703bf62d7259a56b5da --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPool.h @@ -0,0 +1,79 @@ +#ifndef CAFFE2_UTILS_THREADPOOL_H_ +#define CAFFE2_UTILS_THREADPOOL_H_ + +#include "ThreadPoolCommon.h" + +#include +#include +#include +#include +#include + +#include "c10/util/Flags.h" +#include "caffe2/core/common.h" + +// +// A work-stealing threadpool loosely based off of pthreadpool +// + +namespace caffe2 { + +struct Task; +class WorkersPool; + +constexpr size_t kCacheLineSize = 64; + +// A threadpool with the given number of threads. +// NOTE: the kCacheLineSize alignment is present only for cache +// performance, and is not strictly enforced (for example, when +// the object is created on the heap). Thus, in order to avoid +// misaligned intrinsics, no SSE instructions shall be involved in +// the ThreadPool implementation. +// Note: alignas is disabled because some compilers do not deal with +// TORCH_API and alignas annotations at the same time. +class TORCH_API /*alignas(kCacheLineSize)*/ ThreadPool { + public: + static ThreadPool* createThreadPool(int numThreads); + static std::unique_ptr defaultThreadPool(); + virtual ~ThreadPool() = default; + // Returns the number of threads currently in use + virtual int getNumThreads() const = 0; + virtual void setNumThreads(size_t numThreads) = 0; + + // Sets the minimum work size (range) for which to invoke the + // threadpool; work sizes smaller than this will just be run on the + // main (calling) thread + void setMinWorkSize(size_t size) { + std::lock_guard guard(executionMutex_); + minWorkSize_ = size; + } + + size_t getMinWorkSize() const { + return minWorkSize_; + } + virtual void run(const std::function& fn, size_t range) = 0; + + // Run an arbitrary function in a thread-safe manner accessing the Workers + // Pool + virtual void withPool(const std::function& fn) = 0; + + protected: + static size_t defaultNumThreads_; + mutable std::mutex executionMutex_; + size_t minWorkSize_; +}; + +size_t 
getDefaultNumThreads(); +} // namespace caffe2 + +C10_DECLARE_bool(caffe2_threadpool_force_inline); + +// Whether or not threadpool caps apply to Android +C10_DECLARE_int(caffe2_threadpool_android_cap); + +// Whether or not threadpool caps apply to iOS and MacOS +C10_DECLARE_int(caffe2_threadpool_ios_cap); +C10_DECLARE_int(caffe2_threadpool_macos_cap); + +C10_DECLARE_int(pthreadpool_size); +#endif // CAFFE2_UTILS_THREADPOOL_H_ diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPoolCommon.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPoolCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..49f3aab2d86853714678bbc27e93ebef1cfdc602 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPoolCommon.h @@ -0,0 +1,20 @@ +#ifndef CAFFE2_UTILS_THREADPOOL_COMMON_H_ +#define CAFFE2_UTILS_THREADPOOL_COMMON_H_ + +#ifdef __APPLE__ +#include +#endif + +// caffe2 depends upon NNPACK, which depends upon this threadpool, so +// unfortunately we can't reference core/common.h here + +// This is copied from core/common.h's definition of C10_MOBILE +// Define enabled when building for iOS or Android devices +#if defined(__ANDROID__) +#define C10_ANDROID 1 +#elif (defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) +#define C10_IOS 1 +#endif // ANDROID / IOS + +#endif // CAFFE2_UTILS_THREADPOOL_COMMON_H_ diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/WorkersPool.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/WorkersPool.h new file mode 100644 index 0000000000000000000000000000000000000000..5de6b1213e843b4425d121fa8f52e58f12a72d5c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/WorkersPool.h @@ -0,0 +1,375 @@ +#pragma once + +#include +#include +#include +#include "c10/util/thread_name.h" 
+#include +#include + +#if defined(_MSC_VER) +#include +#endif + +namespace caffe2 { + +// Uses code derived from gemmlowp, +// https://github.com/google/gemmlowp/blob/6c91e1ed0c2eff1182d804310b92911fe9c18019/internal/multi_thread_gemm.h +// Changes: +// - allocation-free execute() +// - Use RAII where possible. +// - Run the first task on the main thread (since that is the largest task). +// - removed custom allocator. +// - Removed some ifdef's +// - cache-line align Worker. +// - use std::atomic instead of volatile and custom barriers. +// - use std::mutex/std::condition_variable instead of raw pthreads. + +constexpr size_t kGEMMLOWPCacheLineSize = 64; + +template +struct AllocAligned { + // Allocate a T aligned at an `align` byte address + template + static T* alloc(Args&&... args) { + void* p = nullptr; + +#if defined(__ANDROID__) + p = memalign(kGEMMLOWPCacheLineSize, sizeof(T)); +#elif defined(_MSC_VER) + p = _aligned_malloc(sizeof(T), kGEMMLOWPCacheLineSize); +#else + auto res = posix_memalign((void**)&p, kGEMMLOWPCacheLineSize, sizeof(T)); + (void)res; +#endif + + if (p) { + return new (p) T(std::forward(args)...); + } + + return nullptr; + } + + // Free a T previously allocated via AllocAligned::alloc() + static void release(T* p) { + if (p) { + p->~T(); +#if defined(_MSC_VER) + _aligned_free((void*)p); +#else + free((void*)p); +#endif + } + } +}; + +// Deleter object for unique_ptr for an aligned object +template +struct AlignedDeleter { + void operator()(T* p) const { AllocAligned::release(p); } +}; + +// make_unique that guarantees alignment +template +struct MakeAligned { + template + static std::unique_ptr> make(Args&&... 
args) { + return std::unique_ptr>( + AllocAligned::alloc(std::forward(args)...)); + } +}; + +const int kMaxBusyWaitNOPs = 32 * 1000 * 1000; + +#if defined(_MSC_VER) +#define GEMMLOWP_NOP __nop(); +#else +#define GEMMLOWP_NOP "nop\n" +#endif + +#define GEMMLOWP_STRING_CONCAT_4(X) X X X X +#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP) +#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4) +#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16) + +inline int Do256NOPs() { +#if defined(_MSC_VER) + GEMMLOWP_NOP64; +#else + asm volatile(GEMMLOWP_NOP64); +#endif + return 64; +} + +#undef GEMMLOWP_STRING_CONCAT_4 +#undef GEMMLOWP_NOP256 +#undef GEMMLOWP_NOP64 +#undef GEMMLOWP_NOP16 +#undef GEMMLOWP_NOP4 +#undef GEMMLOWP_NOP + +// Waits until *var != initial_value. +// +// Returns the new value of *var. The guarantee here is that +// the return value is different from initial_value, and that that +// new value has been taken by *var at some point during the +// execution of this function. There is no guarantee that this is +// still the value of *var when this function returns, since *var is +// not assumed to be guarded by any lock. +// +// First does some busy-waiting for a fixed number of no-op cycles, +// then falls back to passive waiting for the given condvar, guarded +// by the given mutex. +// +// The idea of doing some initial busy-waiting is to help get +// better and more consistent multithreading benefits for small GEMM sizes. +// Busy-waiting help ensuring that if we need to wake up soon after having +// started waiting, then we can wake up quickly (as opposed to, say, +// having to wait to be scheduled again by the OS). On the other hand, +// we must still eventually revert to passive waiting for longer waits +// (e.g. worker threads having finished a GEMM and waiting until the next GEMM) +// so as to avoid permanently spinning. 
+// +template +T WaitForVariableChange(std::atomic* var, + T initial_value, + std::condition_variable* cond, + std::mutex* mutex) { + // If we are on a platform that supports it, spin for some time. + { + int nops = 0; + // First, trivial case where the variable already changed value. + T new_value = var->load(std::memory_order_relaxed); + if (new_value != initial_value) { + std::atomic_thread_fence(std::memory_order_acquire); + return new_value; + } + // Then try busy-waiting. + while (nops < kMaxBusyWaitNOPs) { + nops += Do256NOPs(); + new_value = var->load(std::memory_order_relaxed); + if (new_value != initial_value) { + std::atomic_thread_fence(std::memory_order_acquire); + return new_value; + } + } + } + + // Finally, do real passive waiting. + { + std::unique_lock g(*mutex); + T new_value = var->load(std::memory_order_relaxed); + // Handle spurious wakeups. + cond->wait(g, [&]() { + new_value = var->load(std::memory_order_relaxed); + return new_value != initial_value; + }); + TORCH_DCHECK_NE(static_cast(new_value), static_cast(initial_value)); + return new_value; + } +} + +// A BlockingCounter lets one thread to wait for N events to occur. +// This is how the master thread waits for all the worker threads +// to have finished working. +class BlockingCounter { + public: + // Sets/resets the counter; initial_count is the number of + // decrementing events that the Wait() call will be waiting for. + void Reset(std::size_t initial_count) { + std::lock_guard g(mutex_); + TORCH_DCHECK_EQ(count_, 0); + count_ = initial_count; + } + + // Decrements the counter; if the counter hits zero, signals + // the thread that was waiting for that, and returns true. + // Otherwise (if the decremented count is still nonzero), + // returns false. 
+ bool DecrementCount() { + const auto count_value = count_.fetch_sub(1, std::memory_order_relaxed) - 1; + if (count_value == 0) { + std::lock_guard g(mutex_); + cond_.notify_one(); + } + bool retval = count_value == 0; + return retval; + } + + // Waits for the N other threads (N having been set by Reset()) + // to hit the BlockingCounter. + void Wait() { + while (size_t count_value = count_.load(std::memory_order_relaxed)) { + WaitForVariableChange(&count_, count_value, &cond_, &mutex_); + } + } + + private: + std::condition_variable cond_; + std::mutex mutex_; + std::atomic count_{0}; +}; + +// A workload for a worker. +struct Task { + Task() = default; + virtual ~Task() = default; + virtual void Run() = 0; +}; + +// A worker thread. +class alignas(kGEMMLOWPCacheLineSize) Worker { + public: + enum class State : uint8_t { + ThreadStartup, // The initial state before the thread main loop runs. + Ready, // Is not working, has not yet received new work to do. + HasWork, // Has work to do. + ExitAsSoonAsPossible // Should exit at earliest convenience. + }; + + explicit Worker(BlockingCounter* counter_to_decrement_when_ready) + : task_(nullptr), + state_(State::ThreadStartup), + counter_to_decrement_when_ready_(counter_to_decrement_when_ready) { + thread_ = std::make_unique([this]() { + c10::setThreadName("pt_thread_pool"); + this->ThreadFunc(); + }); + } + + ~Worker() { + ChangeState(State::ExitAsSoonAsPossible); + thread_->join(); + } + + // Changes State; may be called from either the worker thread + // or the master thread; however, not all state transitions are legal, + // which is guarded by assertions. 
+ void ChangeState(State new_state) { + std::lock_guard g(state_mutex_); + DCHECK(new_state != state_.load(std::memory_order_relaxed)); + switch (state_.load(std::memory_order_relaxed)) { + case State::ThreadStartup: + DCHECK(new_state == State::Ready); + break; + case State::Ready: + DCHECK(new_state == State::HasWork || new_state == State::ExitAsSoonAsPossible); + break; + case State::HasWork: + DCHECK(new_state == State::Ready || new_state == State::ExitAsSoonAsPossible); + break; + default: + abort(); + } + state_.store(new_state, std::memory_order_relaxed); + state_cond_.notify_one(); + if (new_state == State::Ready) { + counter_to_decrement_when_ready_->DecrementCount(); + } + } + + // Thread entry point. + void ThreadFunc() { + c10::setThreadName("CaffeWorkersPool"); + ChangeState(State::Ready); + + // Thread main loop + while (true) { + // Get a state to act on + // In the 'Ready' state, we have nothing to do but to wait until + // we switch to another state. + State state_to_act_upon = + WaitForVariableChange(&state_, State::Ready, &state_cond_, &state_mutex_); + + // We now have a state to act on, so act. + switch (state_to_act_upon) { + case State::HasWork: + // Got work to do! So do it, and then revert to 'Ready' state. + DCHECK(task_.load()); + (*task_).Run(); + task_ = nullptr; + ChangeState(State::Ready); + break; + case State::ExitAsSoonAsPossible: + return; + default: + abort(); + } + } + } + + static void* ThreadFunc(void* arg) { + static_cast(arg)->ThreadFunc(); + return nullptr; + } + + // Called by the master thread to give this worker work to do. + // It is only legal to call this if the worker + void StartWork(Task* task) { + DCHECK(!task_.load()); + task_ = task; + DCHECK(state_.load(std::memory_order_acquire) == State::Ready); + ChangeState(State::HasWork); + } + + private: + // The underlying thread. + std::unique_ptr thread_; + + // The task to be worked on. 
+ std::atomic task_; + + // The condition variable and mutex guarding state changes. + std::condition_variable state_cond_; + std::mutex state_mutex_; + + // The state enum tells if we're currently working, waiting for work, etc. + std::atomic state_; + + // pointer to the master's thread BlockingCounter object, to notify the + // master thread of when this worker switches to the 'Ready' state. + BlockingCounter* const counter_to_decrement_when_ready_; +}; + +class WorkersPool { + public: + WorkersPool() = default; + + void Execute(const std::vector>& tasks) { + CAFFE_ENFORCE_GE(tasks.size(), 1); + // One of the tasks will be run on the current thread. + int workers_count = tasks.size() - 1; + CreateWorkers(workers_count); + TORCH_DCHECK_LE(workers_count, (int)workers_.size()); + counter_to_decrement_when_ready_.Reset(workers_count); + for (const auto task : c10::irange(1, tasks.size())) { + workers_[task - 1]->StartWork(tasks[task].get()); + } + // Execute the remaining workload immediately on the current thread. + auto& task = tasks.front(); + task->Run(); + // Wait for the workers submitted above to finish. + counter_to_decrement_when_ready_.Wait(); + } + + private: + // Ensures that the pool has at least the given count of workers. + // If any new worker has to be created, this function waits for it to + // be ready. + void CreateWorkers(std::size_t workers_count) { + if (workers_.size() >= workers_count) { + return; + } + counter_to_decrement_when_ready_.Reset(workers_count - workers_.size()); + while (workers_.size() < workers_count) { + workers_.push_back(MakeAligned::make(&counter_to_decrement_when_ready_)); + } + counter_to_decrement_when_ready_.Wait(); + } + + C10_DISABLE_COPY_AND_ASSIGN(WorkersPool); + std::vector>> workers_; + // The BlockingCounter used to wait for the workers. 
+ BlockingCounter counter_to_decrement_when_ready_; +}; +} // namespace caffe2 diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool-cpp.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool-cpp.h new file mode 100644 index 0000000000000000000000000000000000000000..f6fc5a2d8243a2d3beea412cc74e50d8d6cfff91 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool-cpp.h @@ -0,0 +1,55 @@ +#pragma once + +#ifdef USE_PTHREADPOOL + +#ifdef USE_INTERNAL_PTHREADPOOL_IMPL +#include +#else +#include +#endif + +#include +#include +#include + +namespace caffe2 { + +class PThreadPool final { + public: + explicit PThreadPool(size_t thread_count); + ~PThreadPool() = default; + + PThreadPool(const PThreadPool&) = delete; + PThreadPool& operator=(const PThreadPool&) = delete; + + PThreadPool(PThreadPool&&) = delete; + PThreadPool& operator=(PThreadPool&&) = delete; + + size_t get_thread_count() const; + void set_thread_count(size_t thread_count); + + // Run, in parallel, function fn(task_id) over task_id in range [0, range). + // This function is blocking. All input is processed by the time it returns. + void run(const std::function& fn, size_t range); + + private: + friend pthreadpool_t pthreadpool_(); + + private: + mutable std::mutex mutex_; + std::unique_ptr threadpool_; +}; + +// Return a singleton instance of PThreadPool for ATen/TH multithreading. +PThreadPool* pthreadpool(); +PThreadPool* pthreadpool(size_t thread_count); + +// Exposes the underlying implementation of PThreadPool. +// Only for use in external libraries so as to unify threading across +// internal (i.e. ATen, etc.) and external (e.g. NNPACK, QNNPACK, XNNPACK) +// use cases. 
+pthreadpool_t pthreadpool_(); + +} // namespace caffe2 + +#endif /* USE_PTHREADPOOL */ diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool.h new file mode 100644 index 0000000000000000000000000000000000000000..914ebf40a69906f6e68fa4f91b6e44a205ae50ec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool.h @@ -0,0 +1,193 @@ +// pthreadpool header from https://github.com/Maratyszcza/pthreadpool +// for NNPACK +#ifndef CAFFE2_UTILS_PTHREADPOOL_H_ +#define CAFFE2_UTILS_PTHREADPOOL_H_ + +#include "ThreadPoolCommon.h" + +#include // for size_t +#include // for uint32_t + +#if defined(USE_PTHREADPOOL) +// This is a hack. +// Mainly introduced here because +// 1. NNPACK can be compiled to use internal legacy threadpool implementation because much of C2 depends on that. +// 2. Then if we want to use NNPACK in PyTorch, which uses new pthreadpool, then we will supply new pthreadpool pointer +// to NNPACK. This will not work if NNPACK is compiled with internal legacy threadpool. Thus this guard +// along with changes in pthreadpool_impl.cc allows us to override that behavior. 
+// It enables us to use NNPACK from pytorch using `caffe2::pthreadpool_()` +namespace caffe2 { +class WithCastToNewThreadPool { + public: + explicit WithCastToNewThreadPool(bool use_new_threadpool); + ~WithCastToNewThreadPool(); + private: + bool use_new_threadpool_; +}; +} +#endif + +typedef struct pthreadpool* legacy_pthreadpool_t; + +typedef void (*legacy_pthreadpool_function_1d_t)(void*, size_t); +typedef void (*legacy_pthreadpool_function_1d_tiled_t)(void*, size_t, size_t); +typedef void (*legacy_pthreadpool_function_2d_t)(void*, size_t, size_t); +typedef void (*legacy_pthreadpool_function_2d_tiled_t)(void*, size_t, size_t, size_t, size_t); +typedef void (*legacy_pthreadpool_function_3d_tiled_t)( + void*, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t); +typedef void (*legacy_pthreadpool_function_4d_tiled_t)( + void*, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t); + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Creates a thread pool with the specified number of threads. + * + * @param[in] threads_count The number of threads in the thread pool. + * A value of 0 has special interpretation: it creates a thread for each + * processor core available in the system. + * + * @returns A pointer to an opaque thread pool object. + * On error the function returns NULL and sets errno accordingly. + */ + +// Returns internal threadpool impl. +legacy_pthreadpool_t legacy_pthreadpool_create(size_t threads_count); + +/** + * Queries the number of threads in a thread pool. + * + * @param[in] threadpool The thread pool to query. + * + * @returns The number of threads in the thread pool. + */ +size_t legacy_pthreadpool_get_threads_count(legacy_pthreadpool_t threadpool); + +/** + * Processes items in parallel using threads from a thread pool. + * + * When the call returns, all items have been processed and the thread pool is + * ready for a new task. 
+ * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param[in] threadpool The thread pool to use for parallelisation. + * @param[in] function The function to call for each item. + * @param[in] argument The first argument passed to the @a function. + * @param[in] items The number of items to process. The @a function + * will be called once for each item. + */ +void legacy_pthreadpool_compute_1d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_t function, + void* argument, + size_t range); + +void legacy_pthreadpool_parallelize_1d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_t function, + void* argument, + size_t range, + uint32_t flags); + +void legacy_pthreadpool_compute_1d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_tiled_t function, + void* argument, + size_t range, + size_t tile); + +void legacy_pthreadpool_compute_2d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_2d_t function, + void* argument, + size_t range_i, + size_t range_j); + +void legacy_pthreadpool_compute_2d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_2d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t tile_i, + size_t tile_j); + +void legacy_pthreadpool_compute_3d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_3d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_i, + size_t tile_j, + size_t tile_k); + +void legacy_pthreadpool_compute_4d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_4d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t tile_i, + size_t tile_j, + size_t tile_k, + size_t tile_l); + +/** + * Terminates threads in the thread pool and releases associated resources. 
+ * + * @warning Accessing the thread pool after a call to this function constitutes + * undefined behaviour and may cause data corruption. + * + * @param[in,out] threadpool The thread pool to destroy. + */ +void legacy_pthreadpool_destroy(legacy_pthreadpool_t threadpool); + +#ifdef USE_INTERNAL_PTHREADPOOL_IMPL + +#define pthreadpool_t legacy_pthreadpool_t +#define pthreadpool_function_1d_t legacy_pthreadpool_function_1d_t +#define pthreadpool_function_1d_tiled_t legacy_pthreadpool_function_1d_tiled_t +#define pthreadpool_function_2d_t legacy_pthreadpool_function_2d_t +#define pthreadpool_function_2d_tiled_t legacy_pthreadpool_function_2d_tiled_t +#define pthreadpool_function_3d_tiled_t legacy_pthreadpool_function_3d_tiled_t +#define pthreadpool_function_4d_tiled_t legacy_pthreadpool_function_4d_tiled_t +#define pthreadpool_create legacy_pthreadpool_create +#define pthreadpool_destroy legacy_pthreadpool_destroy +#define pthreadpool_get_threads_count legacy_pthreadpool_get_threads_count +#define pthreadpool_compute_1d legacy_pthreadpool_compute_1d +#define pthreadpool_parallelize_1d legacy_pthreadpool_parallelize_1d +#define pthreadpool_compute_1d_tiled legacy_pthreadpool_compute_1d_tiled +#define pthreadpool_compute_2d legacy_pthreadpool_compute_2d +#define pthreadpool_compute_2d_tiled legacy_pthreadpool_compute_2d_tiled +#define pthreadpool_compute_3d_tiled legacy_pthreadpool_compute_3d_tiled +#define pthreadpool_compute_4d_tiled legacy_pthreadpool_compute_4d_tiled + +#endif /* USE_INTERNAL_PTHREADPOOL_IMPL */ + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif // CAFFE2_UTILS_PTHREADPOOL_H_ diff --git a/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/thread_pool_guard.h b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/thread_pool_guard.h new file mode 100644 index 0000000000000000000000000000000000000000..de244723dd812e4ef92c6b542a0812867edf7a34 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/thread_pool_guard.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +namespace caffe2 { + +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct TORCH_API _NoPThreadPoolGuard { + static bool is_enabled(); + static void set_enabled(bool enabled); + + _NoPThreadPoolGuard(): prev_mode_(_NoPThreadPoolGuard::is_enabled()) { + _NoPThreadPoolGuard::set_enabled(true); + } + ~_NoPThreadPoolGuard() { + _NoPThreadPoolGuard::set_enabled(prev_mode_); + } + private: + bool prev_mode_; +}; + +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl.h new file mode 100644 index 0000000000000000000000000000000000000000..4865c759f41eb235719ab49c6c74508514cec85b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl.h @@ -0,0 +1,3988 @@ +/******************************************************************************* +* Copyright 2016-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/// @file +/// C API + +#ifndef ONEAPI_DNNL_DNNL_H +#define ONEAPI_DNNL_DNNL_H + +#include "oneapi/dnnl/dnnl_common.h" +#include "oneapi/dnnl/dnnl_config.h" +#include "oneapi/dnnl/dnnl_types.h" +#include "oneapi/dnnl/dnnl_version.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_api_primitives +/// @{ + +/// @addtogroup dnnl_api_primitives_common +/// @{ + +/// Changes the primitive descriptor to point to the next available +/// implementation. +/// +/// @param primitive_desc A primitive descriptor to change. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +/// @returns #dnnl_last_impl_reached if no more implementations available, +/// in which case the primitive descriptor itself is kept unchanged. +dnnl_status_t DNNL_API dnnl_primitive_desc_next_impl( + dnnl_primitive_desc_t primitive_desc); + +/// Clones a primitive descriptor. The resulting primitive descriptor must be +/// destroyed separately. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param existing_primitive_desc Primitive descriptor to clone. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_desc_clone( + dnnl_primitive_desc_t *primitive_desc, + const_dnnl_primitive_desc_t existing_primitive_desc); + +/// Returns a constant reference to the attributes of a primitive descriptor. +/// +/// @warning +/// It is an error to destroy the resulting @p attr. +/// +/// @warning +/// The lifetime of an @p attr is the same as that of a @p +/// primitive_desc, so it is an error to use the @p attr once the @p +/// primitive_desc has been destroyed. +/// +/// @param primitive_desc Primitive descriptor. +/// @param attr Output primitive attributes. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr( + const_dnnl_primitive_desc_t primitive_desc, + const_dnnl_primitive_attr_t *attr); + +/// Destroys a primitive descriptor. +/// +/// @param primitive_desc Primitive descriptor to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_desc_destroy( + dnnl_primitive_desc_t primitive_desc); + +/// Queries a primitive descriptor for various pieces of information. +/// +/// The most common use case is to query a primitive descriptor, created with +/// source, weights, and destination memory descriptors with format tags set +/// to #dnnl_format_tag_any, for the corresponding memory descriptors (in this +/// case the @p what is set to #dnnl_query_src_md, #dnnl_query_weights_md, and +/// #dnnl_query_dst_md respectively) so that it is possible to create memory +/// objects and reorder primitives if necessary. +/// +/// Another typical use case is to query a primitive descriptor for workspace +/// memory descriptor (with @p what set to #dnnl_query_workspace_md). If this +/// query returns #dnnl_not_required status, then workspace memory is not +/// required. +/// +/// @note +/// When querying for a memory descriptor for a scratchpad, a workspace, +/// or an optional parameter, the query will return a pointer to a zero +/// memory descriptor if the parameter is not needed. +/// +/// A few other use cases: +/// - query a primitive descriptor for the implementation information string +/// (#dnnl_query_impl_info_str) +/// - query a primitive descriptor for the number of inputs and outputs +/// (#dnnl_query_num_of_inputs_s32 and #dnnl_query_num_of_outputs_s32 +/// respectively) +/// +/// @sa dnnl_query_t for more options +/// +/// @param primitive_desc Primitive descriptor. +/// @param what Parameter to query. +/// @param index Index of the parameter to query for. +/// @param result Output result. The type depends on the query. 
For example, +/// it must be a @c dnnl_memory_desc_t* if querying for a memory +/// descriptor. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_desc_query( + const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, + int index, void *result); + +/// Queries primitive descriptor for a memory descriptor. +/// +/// @note +/// This function is a convenience version of +/// #dnnl_primitive_desc_query(). +/// +/// @param primitive_desc Primitive descriptor. +/// @param what Kind of memory descriptor parameter to query for. +/// @param index Index of the parameter to query. +/// @returns A pointer to the requested memory descriptor. +/// @returns A pointer to a zero memory descriptor if the parameter is not +/// needed. +/// @returns NULL in case of any error. +/// +const_dnnl_memory_desc_t DNNL_API dnnl_primitive_desc_query_md( + const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, + int index); + +/// Queries primitive descriptor for a signed 32bit int. +/// +/// @note +/// This function is a convenience version of +/// #dnnl_primitive_desc_query(). +/// +/// @param primitive_desc Primitive descriptor. +/// @param what Kind of the value to query for. +/// @param index Index of the parameter to query. +/// @returns The requested value. +/// @returns 0 in case of any error (in particular if the queried entity is +/// not of type int32_t). Note that 0 may also be the actual returned +/// value. +int DNNL_API dnnl_primitive_desc_query_s32( + const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, + int index); + +/// Creates a primitive. +/// +/// @param primitive Output primitive. +/// @param primitive_desc Primitive descriptor used to create the primitive. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive, + const_dnnl_primitive_desc_t primitive_desc); + +/// Creates a primitive from a cache blob. +/// +/// @param primitive Output primitive. +/// @param primitive_desc Primitive descriptor used to create the primitive. +/// @param size Size of the cache blob in bytes. +/// @param cache_blob Cache blob of size @p size. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_create_from_cache_blob( + dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc, + size_t size, const uint8_t *cache_blob); + +/// Executes a primitive. +/// +/// @param primitive Primitive to execute. +/// @param stream Stream to use. +/// @param nargs Number of arguments. +/// @param args Array of arguments. Each argument is an +/// pair. The index is one of the `DNNL_ARG_*` +/// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see +/// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory +/// descriptor as that returned by +/// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. + +/// @note If any argument in @p args is padded (padded_dims > +/// dims), the primitive execution will assume properly zero-padded +/// input arguments, and produce zero-padded output arguments. +dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive, + dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args); + +/// Retrieves a constant reference to the primitive descriptor of a given +/// primitive. +/// +/// @warning +/// It is an error to destroy the returned object. It is owned by the +/// primitive. The @c const qualifier of the returned object prevents +/// such attempts. +/// +/// @param primitive Primitive to query for the primitive descriptor. 
+/// @param primitive_desc Output primitive descriptor. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc( + const_dnnl_primitive_t primitive, + const_dnnl_primitive_desc_t *primitive_desc); + +/// Retrieves a cache blob associated with the given primitive. +/// +/// @param primitive Primitive to query for the cache blob. +/// @param size Size of the cache blob in bytes. +/// @param cache_blob Cache blob of size @p size. If the @p cache_blob is +/// nullptr then the size of the cache blob is returned in @p size. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +/// +/// @note The cache blob can be empty. It's the user's responsibility to check +/// whether it's empty prior to passing it to +/// #dnnl_primitive_create_from_cache_blob(). +dnnl_status_t DNNL_API dnnl_primitive_get_cache_blob( + const_dnnl_primitive_t primitive, size_t *size, uint8_t *cache_blob); + +/// Destroys a primitive. +/// +/// @param primitive The primitive to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive); + +/// @} dnnl_api_primitives_common + +/// @addtogroup dnnl_api_attributes +/// @{ + +/// Creates an empty (default) primitive attributes with all the parameters +/// set to their default values. +/// +/// Empty attributes are implied whenever the respective argument is NULL. +/// +/// @param attr Output primitive attributes. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr); + +/// Clones primitive attributes. +/// +/// @param attr Output primitive attributes. +/// @param existing_attr Primitive attributes to clone. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_primitive_attr_clone( + dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr); + +/// Destroys primitive attributes. +/// +/// @param attr Primitive attributes to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr); + +/// Returns probability for output dropout primitive attribute. +/// +/// @param attr Primitive attributes. +/// @param dropout_desc Output dropout memory descriptor +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_get_dropout( + const_dnnl_primitive_attr_t attr, + const_dnnl_memory_desc_t *dropout_desc); + +/// Sets probability for output dropout primitive attribute. +/// +/// @param attr Primitive attributes. +/// @param dropout_desc Output dropout memory descriptor +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_set_dropout( + dnnl_primitive_attr_t attr, const_dnnl_memory_desc_t dropout_desc); + +/// Returns the floating-point math mode primitive attribute. +/// +/// @param attr Primitive attributes. +/// @param mode Output FP math mode. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode( + const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode); + +/// Sets the floating-point math mode primitive attributes. +/// +/// @param attr Primitive attributes. +/// @param mode FP math mode. The possible values are: +/// #dnnl_fpmath_mode_strict (default), +/// #dnnl_fpmath_mode_bf16, +/// #dnnl_fpmath_mode_f16, +/// #dnnl_fpmath_mode_tf32, +/// #dnnl_fpmath_mode_any. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode( + dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode); + +/// Returns the floating-point math mode primitive attribute. +/// +/// @param attr Primitive attributes. +/// @param mode Output FP math mode. +/// @param apply_to_int Output use floating-point arithmetic for integer primitives. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode_v2( + const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode, + int *apply_to_int); + +/// Sets the floating-point math mode primitive attributes. +/// +/// @param attr Primitive attributes. +/// @param mode FP math mode. The possible values are: +/// #dnnl_fpmath_mode_strict (default), +/// #dnnl_fpmath_mode_bf16, +/// #dnnl_fpmath_mode_f16, +/// #dnnl_fpmath_mode_tf32, +/// #dnnl_fpmath_mode_any. +/// @param apply_to_int Boolean. Use of floating-point arithmetic for integer primitives. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode_v2( + dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode, int apply_to_int); + +/// Returns the deterministic primitive attribute value. +/// +/// @param attr Primitive attributes. +/// @param value Output deterministic attribute value +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_get_deterministic( + const_dnnl_primitive_attr_t attr, int *value); + +/// Sets the deterministic primitive attribute value. +/// +/// @param attr Primitive attributes. +/// @param value Boolean value to set deterministic attribute. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_set_deterministic( + dnnl_primitive_attr_t attr, int value); + +/// Returns the accumulation mode primitive attribute. 
+/// +/// @param attr Primitive attributes. +/// @param mode Output accumulation mode. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_get_accumulation_mode( + const_dnnl_primitive_attr_t attr, dnnl_accumulation_mode_t *mode); + +/// Sets the accumulation mode primitive attribute. +/// +/// @param attr Primitive attributes. +/// @param mode Accumulation mode. The possible values are: +/// #dnnl_accumulation_mode_strict (default), which is s32 for quantized primitives, f32/f64 otherwise +/// #dnnl_accumulation_mode_relaxed, which is same as strict but allows intermediate accumulators to be in src/dst datatype +/// #dnnl_accumulation_mode_any, which allows accumulators to be src/dst datatype or any wider type. +/// #dnnl_accumulation_mode_f32, +/// #dnnl_accumulation_mode_s32, +/// #dnnl_accumulation_mode_f16. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_set_accumulation_mode( + dnnl_primitive_attr_t attr, dnnl_accumulation_mode_t mode); + +/// Returns the primitive attributes scratchpad mode. +/// +/// @param attr Primitive attributes. +/// @param mode Output scratchpad mode. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode( + const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode); + +/// Sets primitive attributes scratchpad mode. +/// +/// @param attr Primitive attributes. +/// @param mode Scratchpad mode. The possible values are: +/// #dnnl_scratchpad_mode_library (default) and +/// #dnnl_scratchpad_mode_user. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode( + dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode); + +/// Sets primitive attributes scaling factors for primitive operations for a +/// given memory argument. The scaling factors must be passed at execution time +/// as an argument with index #DNNL_ARG_ATTR_SCALES | arg. +/// +/// @sa dnnl_primitive_attr_set_scales_mask +/// +/// +/// @param attr Primitive attributes. +/// @param arg Parameter argument index as passed to the +/// dnnl_primitive_execute() call. +/// @param mask Scaling factors correspondence mask that defines the +/// correspondence between the tensor dimensions and the @p scales array. +/// The set i-th bit indicates that a dedicated scaling factor is used for +/// each index along that dimension. Set the mask to 0 to use a common +/// scaling factor for the whole output tensor. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask( + dnnl_primitive_attr_t attr, int arg, int mask); + +/// Sets primitive attributes scaling factors for primitive operations for a +/// given memory argument. The scaling factors must be passed at execution time +/// as an argument with index #DNNL_ARG_ATTR_SCALES | arg. +/// +/// @sa dnnl_primitive_attr_set_scales +/// +/// +/// @param attr Primitive attributes. +/// @param arg Parameter argument index as passed to the +/// dnnl_primitive_execute() call. +/// @param mask Scaling factors correspondence mask that defines the +/// correspondence between the tensor dimensions and the @p scales array. +/// The set i-th bit indicates that a dedicated scaling factor is used for +/// each index along that dimension. Set the mask to 0 to use a common +/// scaling factor for the whole output tensor. +/// @param ndims Number of group dimensions. 
+/// @param group_dims Scaling factors correspondence groups that define the +/// correspondence between the tensor dimensions and the scales array. +/// The group dimensions should only be provided for each logical dimension +/// that has correspondence mask @p mask set. +/// @param data_type Scaling factors data_type. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales( + dnnl_primitive_attr_t attr, int arg, int mask, int ndims, + const dnnl_dims_t group_dims, dnnl_data_type_t data_type); + +/// Sets primitive attributes zero points for primitive operations for a given +/// memory argument. The zero points must be passed at execution time +/// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg. +/// +/// @sa dnnl_primitive_attr_set_zero_points_mask +/// +/// +/// @param attr Primitive attributes. +/// @param arg Parameter argument index as passed to the +/// dnnl_primitive_execute() call. +/// @param mask Zero point correspondence mask that defines the +/// correspondence between the tensor dimensions and the @p +/// zero_points array. The set i-th bit indicates that a dedicated +/// zero point is used for each index along that dimension. Set the +/// mask to 0 to use a common zero point for the whole output tensor. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask( + dnnl_primitive_attr_t attr, int arg, int mask); + +/// Sets primitive attributes zero points for primitive operations for a given +/// memory argument. The zero points must be passed at execution time +/// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg. +/// +/// @sa dnnl_primitive_attr_set_zero_points +/// +/// +/// @param attr Primitive attributes. +/// @param arg Parameter argument index as passed to the +/// dnnl_primitive_execute() call. 
+/// @param mask Zero point correspondence mask that defines the
+/// correspondence between the tensor dimensions and the @p
+/// zero_points array. The set i-th bit indicates that a dedicated
+/// zero point is used for each index along that dimension. Set the
+/// mask to 0 to use a common zero point for the whole output tensor.
+/// @param ndims Number of group dimensions.
+/// @param group_dims Zero point factors correspondence groups that define the
+/// correspondence between the tensor dimensions and the zero_points array.
+/// The group dimensions should only be provided for each logical dimension
+/// that has correspondence mask @p mask set.
+/// @param data_type Zero points factors data_type.
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
+ dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
+ const dnnl_dims_t group_dims, dnnl_data_type_t data_type);
+
+/// Sets the rounding mode attribute value for a given argument
+///
+/// @param attr Primitive attributes.
+/// @param arg Argument for which rounding mode should be set.
+/// @param mode Rounding mode to apply to the argument.
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_rounding(
+ dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t mode);
+
+/// Returns the rounding mode attribute value for a given argument
+///
+/// @param attr Primitive attributes.
+/// @param arg Argument for which rounding mode query applies.
+/// @param mode Output rounding mode.
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+dnnl_status_t DNNL_API dnnl_primitive_attr_get_rounding(
+ dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t *mode);
+
+/// Returns primitive attributes post-ops.
+/// +/// @warning +/// The output @p post_ops points to the internal @p attr field, so it is +/// an error to modify or destroy them. The lifetime of @p post_ops is +/// the same as that of the @p attr it belongs to, so it is an error to +/// use @p post_ops after @p attr has been destroyed. +/// +/// @param attr Primitive attributes. +/// @param post_ops Output post-ops. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops( + const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops); + +/// Sets primitive attributes post-ops. +/// +/// @note +/// There is no way to check whether the post-ops would be supported by +/// the target primitive. Any error will be reported by the +/// dnnl__[propagation kind]_primitive_desc_create() function call. +/// +/// @param attr Primitive attributes. +/// @param post_ops Post-ops to set. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops( + dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops); + +/// Creates empty post-ops sequence. +/// +/// @param post_ops Output post-ops. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops); + +/// Clones post-ops primitive attribute. +/// +/// @param post_ops Output post-ops primitive attribute. +/// @param existing_post_ops Post-ops primitive attribute to clone. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_post_ops_clone( + dnnl_post_ops_t *post_ops, const_dnnl_post_ops_t existing_post_ops); + +/// Destroys post-ops. +/// +/// @param post_ops Post-ops to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops); + +/// Returns the length of post-ops. +/// +/// @param post_ops Post-ops. +/// @returns The number of post-ops entries. +int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops); + +/// Returns the kind of a post-op entry. +/// +/// @param post_ops Post-ops. +/// @param index Post-op entry index. +/// @returns The kind of the post-op with the specified index. +/// @returns #dnnl_undefined_primitive if there is no post-op at the specified +/// index. +dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind( + const_dnnl_post_ops_t post_ops, int index); + +/// Appends an accumulation v3 (sum) to post-ops. Prior to accumulating the +/// result, a zero point is subtracted from the previous value and is +/// multiplied by the scale. +/// +/// The kind of this post-op is #dnnl_sum. +/// +/// This feature may improve performance for cases like dequantize the +/// asymmetrically quantized sum's src1 tensor to f32 domain before performing +/// the sum operation by subtracting the @p zero_point before the scaling. +/// +/// In the simplest case where accumulation is the only post-op, the +/// computations will be: +/// +/// dst[:] <- scale * (dst[:] - zero_point) + op(...) +/// // instead of dst[:] <- op(...) +/// +/// If @p data_type is specified, original dst tensor will be reinterpreted +/// as a tensor with provided data type. Since it is reinterpretation, +/// data_type and dst data type should have the same size. +/// As a result, computations will be: +/// +/// dst[:] <- scale * (as_data_type(dst[:]) - zero_point) + op(...) +/// // instead of dst[:] <- op(...) +/// @note +/// This post-op executes in-place and does not change the +/// destination layout. +/// +/// @param post_ops Post-ops. +/// @param scale Accumulation scaling factor. +/// @param zero_point Single scalar int32_t value of zero point. +/// @param data_type Accumulation data_type. 
+/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops, + float scale, int32_t zero_point, dnnl_data_type_t data_type); + +/// Returns the parameters of an accumulation (sum) post-op with +/// zero point and data type parameter. +/// +/// @param post_ops Post-ops. +/// @param index Index of the sum post-op. +/// @param scale Output accumulation scaling factor. +/// @param zero_point Zero point. +/// @param data_type Data type for accumulation. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum( + const_dnnl_post_ops_t post_ops, int index, float *scale, + int32_t *zero_point, dnnl_data_type_t *data_type); + +/// Appends an elementwise post-op. +/// +/// The kind of this post operation is #dnnl_eltwise. +/// +/// In the simplest case when the elementwise is the only post operation, the +/// computations would be: +/// +/// dst[:] <- eltwise_op (op(...)) // instead of dst[:] <- op(...) +/// +/// where eltwise_op is configured with the given parameters. +/// +/// @param post_ops Post-ops. +/// @param alg_kind Elementwise algorithm for the post-op. +/// @param alpha Alpha parameter for the elementwise algorithm. +/// @param beta Beta parameter for the elementwise algorithm. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops, + dnnl_alg_kind_t alg_kind, float alpha, float beta); + +/// Returns the parameters of an elementwise post-op. +/// +/// @param post_ops Post-ops. +/// @param index Index of the elementwise post-op. +/// @param alg_kind Output elementwise algorithm kind. +/// @param alpha Output alpha parameter for the elementwise algorithm. +/// @param beta Output beta parameter for the elementwise algorithm. 
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+/// @returns #dnnl_invalid_arguments if @p index does not refer to an
+/// elementwise post-op.
+dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
+ const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
+ float *alpha, float *beta);
+
+/// Appends a depthwise post-op convolution.
+///
+/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
+/// weights spatial dimensions equal to 1 i.e., kh=kw=1).
+///
+/// The kind of this post-op is #dnnl_convolution.
+///
+/// The number of outputs for primitive with fusion is one. The output spatial
+/// size can be derived as below:
+///
+/// output_height = ceil(output_height_1x1_convolution, stride)
+/// output_width = ceil(output_width_1x1_convolution, stride)
+///
+/// See @ref dev_guide_attributes_post_ops_depthwise and
+/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
+///
+/// @param post_ops Post-ops.
+/// @param weights_data_type Weights data type of depthwise post-op
+/// @param bias_data_type Bias data type of depthwise post-op
+/// @param dst_data_type Output data type of depthwise post-op
+/// @param kernel_size Size of kernel of depthwise post-op
+/// @param stride_size Size of stride of depthwise post-op
+/// @param padding_l_size Size of left and top paddings of depthwise post-op
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise
+dnnl_status_t DNNL_API dnnl_post_ops_append_dw(dnnl_post_ops_t post_ops,
+ dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
+ dnnl_data_type_t dst_data_type, dnnl_dim_t kernel_size,
+ dnnl_dim_t stride_size, dnnl_dim_t padding_l_size);
+
+/// Returns the parameters of a depthwise post-op.
+///
+/// @param post_ops Post-ops.
+/// @param index Index of the depthwise post-op.
+/// @param weights_data_type Weights data type of depthwise post-op +/// @param bias_data_type Bias data type of depthwise post-op +/// @param dst_data_type Output data type of depthwise post-op +/// @param kernel_size Size of kernel of depthwise post-op +/// @param stride_size Size of stride of depthwise post-op +/// @param padding_l_size Size of left and top paddings of depthwise post-op +/// @returns #dnnl_success on success and a status describing the error +/// otherwise +dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw( + const_dnnl_post_ops_t post_ops, int index, + dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type, + dnnl_data_type_t *dst_data_type, dnnl_dim_t *kernel_size, + dnnl_dim_t *stride_size, dnnl_dim_t *padding_l_size); + +/// Appends a binary post-op. +/// +/// The kind of this post operation is #dnnl_binary. +/// +/// In the simplest case when the binary is the only post operation, the +/// computations would be: +/// +/// dst[:] <- binary_op (dst[:], another_input[:]) +/// +/// where binary_op is configured with the given parameters. binary_op supports +/// broadcast semantics for a second operand. +/// +/// @param post_ops Post-ops. +/// @param alg_kind Binary algorithm for the post-op. +/// @param src1_desc Memory descriptor of a second operand. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src1_desc); + +/// Returns the parameters of a binary post-op. +/// +/// @param post_ops Post-ops. +/// @param index Index of the binary post-op. +/// @param alg_kind Output binary algorithm kind. +/// @param src1_desc Output memory descriptor of a second operand. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +/// @returns #dnnl_invalid_arguments if @p index does not refer to a binary +/// post-op. 
+dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary( + const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind, + const_dnnl_memory_desc_t *src1_desc); + +/// Appends a prelu forward post-op. +/// +/// The kind of this post-op is #dnnl::primitive::kind::prelu. +/// +/// The post-op can be defined as: +/// +/// dst[:] <- prelu(dst[:], weights[:]) +/// prelu: +/// dst[:] <- dst[:] if dst[:] > 0 +/// dst[:] <- dst[:] * weights[:] if dst[:] <= 0 +/// +/// +/// @note +/// The order of dimensions does not depend on how elements are laid +/// out in memory. For example: +/// - for a 2D CNN activations tensor the order is always (n, c) +/// - for a 4D CNN activations tensor the order is always (n, c, h, w) +/// - for a 5D CNN weights tensor the order is always +/// (g, oc, ic, kh, kw) +/// +/// Prelu weights tensor is passed in runtime execution phase. Prelu +/// weights tensor data type is implicitly assumed as f32 using plain +/// layout (a, ab, acb, acdb, acdeb) +/// +/// @param post_ops Post-ops. +/// @param mask Defines the correspondence between the output tensor +/// dimensions and the prelu weights tensor. The set i-th bit indicates +/// that a dedicated weights value is used for each index along that +/// dimension. Set the mask to 0 to use a common weights value +/// for the whole output tensor. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_post_ops_append_prelu( + dnnl_post_ops_t post_ops, int mask); + +/// Returns the parameters of a prelu post-op. +/// +/// @param post_ops Post-ops. +/// @param index Index of the prelu post-op. +/// @param mask Mask of the prelu post-op. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu( + const_dnnl_post_ops_t post_ops, int index, int *mask); + +/// @} dnnl_api_attributes + +/// @} dnnl_api_primitives + +/// @addtogroup dnnl_api_memory +/// @{ + +/// Destroys a memory descriptor. +/// +/// @param memory_desc Memory descriptor to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_destroy(dnnl_memory_desc_t memory_desc); + +/// Clones a memory descriptor. The resulting memory descriptor must be +/// destroyed separately. +/// +/// @param memory_desc Output memory descriptor. +/// @param existing_memory_desc Memory descriptor to clone. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_clone(dnnl_memory_desc_t *memory_desc, + const_dnnl_memory_desc_t existing_memory_desc); + +/// Retrieves a binary blob associated with the given memory descriptor +/// +/// @param blob Output pointer to binary blob. +/// If not nullptr, size bytes of the memory descriptor blob are written. +/// @param size Output pointer to the size of the binary blob in bytes. +/// Size is written if blob is nullptr. +/// @param memory_desc input memory descriptor to serialize +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_get_blob( + uint8_t *blob, size_t *size, const_dnnl_memory_desc_t memory_desc); + +/// Creates a memory descriptor from a memory descriptor binary blob. +/// +/// @param memory_desc Output pointer to a newly allocated memory descriptor. +/// @param blob Pointer to a memory descriptor binary blob. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_create_with_blob( + dnnl_memory_desc_t *memory_desc, const uint8_t *blob); + +/// Creates a memory descriptor using dimensions and strides. 
+/// +/// @note +/// As always, the logical order of dimensions corresponds to the `abc...` +/// format tag, and the physical meaning of the dimensions depends on both +/// the primitive that consumes the memory and the context of that +/// consumption. +/// +/// @param memory_desc Output memory descriptor. +/// @param ndims Number of dimensions +/// @param dims Array of dimensions. +/// @param data_type Elements data type. +/// @param strides Strides in each dimension. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_create_with_strides( + dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, + dnnl_data_type_t data_type, const dnnl_dims_t strides); + +/// Creates a memory descriptor using dimensions and memory format tag. +/// +/// @note +/// As always, the logical order of dimensions corresponds to the `abc...` +/// format tag, and the physical meaning of the dimensions depends on both +/// the primitive that consumes the memory and the context of that +/// consumption. +/// +/// @param memory_desc Output memory descriptor. +/// @param ndims Number of dimensions +/// @param dims Array of dimensions. +/// @param data_type Elements data type. +/// @param tag Memory format tag. Can be #dnnl_format_tag_any which would +/// allow a primitive to chose the final memory format. In this case the +/// format_kind field of the memory descriptor would be set to +/// #dnnl_format_kind_any. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_create_with_tag( + dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, + dnnl_data_type_t data_type, dnnl_format_tag_t tag); + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Creates a memory descriptor for CSR encoding. +/// +/// @param memory_desc Output memory descriptor. +/// @param ndims Number of dimensions +/// @param dims Array of dimensions. 
+/// @param data_type Elements data type. +/// @param nnz Number of non-zero entries. +/// @param indices_dt Data type of indices. +/// @param pointers_dt Data type of pointers. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_create_with_csr_encoding( + dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, + dnnl_data_type_t data_type, dnnl_dim_t nnz, dnnl_data_type_t indices_dt, + dnnl_data_type_t pointers_dt); + +/// Creates a memory descriptor for COO encoding. +/// +/// The created memory descriptor will describe a memory object that +/// contains n+1 buffers for an n-dimensional tensor. +/// The buffers have the following meaning and assigned numbers (index): +/// - 0: values +/// - 1: indices for dimension 0 +/// - 2: indices for dimension 1 ... +/// - n: indices for dimension n-1 +/// +/// @param memory_desc Output memory descriptor. +/// @param ndims Number of dimensions. +/// @param dims Array of dimensions. +/// @param data_type Elements data type. +/// @param nnz Number of non-zero entries. +/// @param indices_dt Data type of indices. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_create_with_coo_encoding( + dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, + dnnl_data_type_t data_type, dnnl_dim_t nnz, + dnnl_data_type_t indices_dt); + +/// Creates a memory descriptor for packed sparse encoding. +/// +/// The created memory descriptor cannot be used to create a memory +/// object. It can only be used to create a primitive descriptor to +/// query the actual memory descriptor (similar to the format tag +/// `any`). +/// +/// @warning +/// The meaning and content of the handles of the memory object that +/// is created using the queried memory descriptor are unspecified +/// therefore using the content is an undefined behavior. 
+///
+/// @param memory_desc Output memory descriptor.
+/// @param ndims Number of dimensions
+/// @param dims Array of dimensions.
+/// @param data_type Elements data type.
+/// @param nnz Number of non-zero entries.
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+dnnl_status_t DNNL_API dnnl_memory_desc_create_with_packed_encoding(
+ dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
+ dnnl_data_type_t data_type, dnnl_dim_t nnz);
+#endif
+
+/// Creates a memory descriptor for a region inside an area
+/// described by an existing memory descriptor.
+///
+/// @warning
+/// Some combinations of physical memory layout and/or offsets or dims may
+/// result in a failure to create a submemory.
+///
+/// @param memory_desc Output memory descriptor.
+/// @param parent_memory_desc An existing memory descriptor.
+/// @param dims Sizes of the region.
+/// @param offsets Offsets to the region from the encompassing
+/// memory object in each dimension
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+dnnl_status_t DNNL_API dnnl_memory_desc_create_submemory(
+ dnnl_memory_desc_t *memory_desc,
+ const_dnnl_memory_desc_t parent_memory_desc, const dnnl_dims_t dims,
+ const dnnl_dims_t offsets);
+
+/// Creates a memory descriptor by reshaping an existing one. The new
+/// memory descriptor inherits the data type. This operation is valid only for
+/// memory descriptors that have format_kind #dnnl_blocked or
+/// #dnnl_format_kind_any.
+///
+/// The resulting memory descriptor must be destroyed separately.
+///
+/// The operation ensures the transformation of the physical memory format
+/// corresponds to the transformation of the logical dimensions. If such
+/// transformation is impossible, the function returns #dnnl_invalid_arguments.
+///
+/// The reshape operation can be described as a combination of the following
+/// basic operations:
+/// 1. Add a dimension of size `1`.
This is always possible. +/// 2. Remove a dimension of size `1`. This is possible only if the dimension +/// has no padding (i.e. `padded_dims[dim] == dims[dim] && dims[dim] == 1`). +/// 3. Split a dimension into multiple ones. This is possible only if the size +/// of the dimension is exactly equal to the product of the split ones and +/// the dimension does not have padding (i.e. +/// `padded_dims[dim] = dims[dim]`). +/// 4. Joining multiple consecutive dimensions into a single one. As in the +/// cases above, this requires that the dimensions do not have padding and +/// that the memory format is such that in physical memory these dimensions +/// are dense and have the same order as their logical counterparts. This +/// also assumes that these dimensions are not blocked. +/// - Here, dense means: +/// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`; +/// - And same order means: +/// `i < j` if and only if `stride for dim[j] <= stride for dim[i]`. +/// +/// @warning +/// Some combinations of physical memory layout and/or offsets or +/// dimensions may result in a failure to make a reshape. +/// +/// @param out_memory_desc Output memory descriptor. +/// @param in_memory_desc An existing memory descriptor. Must have format_kind +/// set to #dnnl_blocked or #dnnl_format_kind_any. +/// @param ndims Number of dimensions for the output memory descriptor. +/// @param dims Dimensions for the output memory descriptor. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_reshape( + dnnl_memory_desc_t *out_memory_desc, + const_dnnl_memory_desc_t in_memory_desc, int ndims, + const dnnl_dims_t dims); + +/// Creates a memory descriptor by permuting axes in an existing one. +/// +/// The physical memory layout representation is adjusted accordingly to +/// maintain the consistency between the logical and physical parts of the +/// memory descriptor. 
+/// +/// The resulting memory descriptor must be destroyed separately. +/// +/// The new memory descriptor inherits the data type. This operation is valid +/// only for memory descriptors that have format_kind set to #dnnl_blocked or +/// #dnnl_format_kind_any. +/// +/// The logical axes will be permuted in the following manner: +/// ``` +/// for (i: 0 .. in_memory_desc->ndims) +/// out_memory_desc->dims[permutation[i]] = in_memory_desc->dims[i]; +/// ``` +/// +/// Example: +/// @code +/// dnnl_memory_desc_t in_md, out_md, expect_out_md; +/// +/// const int permutation[] = {1, 0}; // swap the first and the second axes +/// +/// dnnl_dims_t in_dims = {2, 3}, out_dims = {3, 2}; +/// dnnl_format_tag_t in_tag = dnnl_ab, out_tag = dnnl_ba; +/// +/// dnnl_memory_desc_create_with_tag( +/// &in_md, 2, in_dims, data_type, in_tag); +/// dnnl_memory_desc_create_with_tag( +/// &expect_out_md, 2, out_dims, data_type, out_tag); +/// +/// dnnl_memory_desc_permute_axes(&out_md, in_md, permutation); +/// assert(dnnl_memory_desc_equal(out_md, expect_out_md)); +/// +/// dnnl_memory_desc_destroy(in_md); +/// dnnl_memory_desc_destroy(out_md); +/// dnnl_memory_desc_destroy(expect_out_md); +/// @endcode +/// +/// @param out_memory_desc Output memory descriptor. +/// @param in_memory_desc An existing memory descriptor. Must have format_kind +/// set to #dnnl_blocked or #dnnl_format_kind_any. +/// @param permutation Axes permutation (of size `in_memory_desc->ndims`). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes( + dnnl_memory_desc_t *out_memory_desc, + const_dnnl_memory_desc_t in_memory_desc, const int *permutation); + +/// Queries a memory descriptor for various pieces of information. 
+///
+/// The following information can be queried:
+/// - Number of dimensions (#dnnl_query_ndims_s32)
+/// - Dimensions (#dnnl_query_dims) in the following order:
+/// - CNN data tensors: mini-batch, channel, spatial
+/// ({N, C, [[D,] H,] W})
+/// - CNN weight tensors: group (optional), output channel, input channel,
+/// spatial ({[G,] O, I, [[D,] H,] W})
+/// - RNN data tensors: time, mini-batch, channels ({T, N, C})
+/// or layers, directions, states, mini-batch, channels
+/// ({L, D, S, N, C})
+/// - RNN weight tensor: layers, directions, input channel, gates, output
+/// channels ({L, D, I, G, O})
+/// - Data type of the tensor elements (#dnnl_query_data_type)
+/// - Padded dimensions (#dnnl_query_padded_dims) - size of the data including
+/// padding in each dimension
+/// - Padded offsets (#dnnl_query_padded_offsets) - per-dimension offset from
+/// the padding to actual data, the top-level tensor with offsets applied
+/// must lie within the padding area.
+/// - Submemory offset (#dnnl_query_submemory_offset_s64) - offset from memory
+/// origin to the current block, non-zero only in a description of a memory
+/// sub-block.
+/// - Format kind (#dnnl_query_format_kind) - memory format kind
+///
+/// @note
+/// The order of dimensions does not depend on the memory format, so
+/// whether the data is laid out in #dnnl_nchw or #dnnl_nhwc
+/// the dims for 4D CNN data tensor would be {N, C, H, W}.
+///
+/// The following queries are applicable only to format kind #dnnl_blocked.
+/// - Strides (#dnnl_query_strides) between the outermost blocks or in case
+/// of plain (non-blocked) formats the strides between dimensions
+/// - Number of innermost blocks (#dnnl_query_inner_nblks_s32), e.g.
+/// 3 in case of `OIhw_4i16o4i`
+/// - Size of the innermost blocks (#dnnl_query_inner_blks), e.g.
+/// `{4, 16, 4}` in case of `OIhw_4i16o4i`
+/// - Logical indices of the blocks (#dnnl_query_inner_idxs), e.g.
`{1, 0, 1}`
+/// in case of `4i16o4i`, because `i` is the 1st dim and `o` is the 0th dim
+///
+/// @param memory_desc Memory descriptor.
+/// @param what Parameter to query.
+/// @param result Output result. The type depends on the query. For example,
+/// it must be a @c dnnl_dims_t** if querying for strides.
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+dnnl_status_t DNNL_API dnnl_memory_desc_query(
+ const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, void *result);
+
+#ifdef DNNL_EXPERIMENTAL_SPARSE
+/// Queries a memory descriptor for various pieces of information. This version
+/// supports additional queries #dnnl_query_sparse_encoding, #dnnl_query_nnz_s64,
+/// #dnnl_query_num_handles_s32 and #dnnl_query_data_type for a particular
+/// buffer.
+///
+/// The following information can be queried:
+/// - Number of dimensions (#dnnl_query_ndims_s32)
+/// - Dimensions (#dnnl_query_dims) in the following order:
+/// - CNN data tensors: mini-batch, channel, spatial
+/// ({N, C, [[D,] H,] W})
+/// - CNN weight tensors: group (optional), output channel, input channel,
+/// spatial ({[G,] O, I, [[D,] H,] W})
+/// - RNN data tensors: time, mini-batch, channels ({T, N, C})
+/// or layers, directions, states, mini-batch, channels
+/// ({L, D, S, N, C})
+/// - RNN weight tensor: layers, directions, input channel, gates, output
+/// channels ({L, D, I, G, O})
+/// - Data type of the tensor elements (#dnnl_query_data_type)
+/// - Padded dimensions (#dnnl_query_padded_dims) - size of the data including
+/// padding in each dimension
+/// - Padded offsets (#dnnl_query_padded_offsets) - per-dimension offset from
+/// the padding to actual data, the top-level tensor with offsets applied
+/// must lie within the padding area.
+/// - Submemory offset (#dnnl_query_submemory_offset_s64) - offset from memory
+/// origin to the current block, non-zero only in a description of a memory
+/// sub-block.
+/// - Format kind (#dnnl_query_format_kind) - memory format kind
+///
+/// @note
+/// The order of dimensions does not depend on the memory format, so
+/// whether the data is laid out in #dnnl_nchw or #dnnl_nhwc
+/// the dims for 4D CN data tensor would be {N, C, H, W}.
+///
+/// The following queries are applicable only to format kind #dnnl_blocked.
+/// - Strides (#dnnl_query_strides) between the outermost blocks or in case
+/// of plain (non-blocked) formats the strides between dimensions
+/// - Number of innermost blocks (#dnnl_query_inner_nblks_s32), e.g.
+/// 3 in case of `OIhw_4i16o4i`
+/// - Size of the innermost blocks (#dnnl_query_inner_blks), e.g.
+/// `{4, 16, 4}` in case of `OIhw_4i16o4i_`
+/// - Logical indices of the blocks (#dnnl_query_inner_idxs), e.g. `{1, 0, 1}`
+/// in case of `4i16o4i`, because `i` is the 1st dim and `o` is the 0th dim
+///
+/// @param memory_desc Memory descriptor.
+/// @param what Parameter to query.
+/// @param index Index of the parameter to query for. It is mostly used with
+/// #dnnl_query_data_type to specify which data type is being queried.
+/// The main data type (data type of values) always has index 0. For other
+/// indices please refer to the API for creating a memory descriptor for
+/// sparse encoding.
+/// @param result Output result. The type depends on the query. For example,
+/// it must be a @c dnnl_dims_t** if querying for strides.
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+dnnl_status_t DNNL_API dnnl_memory_desc_query_v2(
+ const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, int index,
+ void *result);
+#endif
+
+/// Compares two memory descriptors.
+///
+/// Use this function to identify whether a reorder is required between the
+/// two memories.
+///
+/// @param lhs Left-hand side of the comparison.
+/// @param rhs Right-hand side of the comparison.
+/// @returns 1 if the descriptors are the same.
+/// @returns 0 if the descriptors are different.
+int DNNL_API dnnl_memory_desc_equal( + const_dnnl_memory_desc_t lhs, const_dnnl_memory_desc_t rhs); + +/// Returns the size of a memory descriptor. +/// +/// @param memory_desc Memory descriptor. +/// @returns The number of bytes required for memory described by a memory +/// descriptor. +size_t DNNL_API dnnl_memory_desc_get_size(const_dnnl_memory_desc_t memory_desc); + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Returns the size of the data that corresponds to the given index. +/// +/// @param memory_desc Memory descriptor. +/// @param index Index of the buffer. +/// +/// @returns The number of bytes required for the requested data. +size_t DNNL_API dnnl_memory_desc_get_size_v2( + const_dnnl_memory_desc_t memory_desc, int index); +#endif + +/// Returns the size of data type. +/// +/// @param data_type Data type. +/// @returns The number of bytes occupied by data type. +size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type); + +/// Creates a memory object. +/// +/// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory +/// object will have the underlying buffer set. In this case, the buffer will +/// be initialized as if dnnl_memory_set_data_handle() had been called. +/// +/// @sa dnnl_memory_set_data_handle() +/// +/// @param memory Output memory object. +/// @param memory_desc Memory descriptor. +/// @param engine Engine to use. +/// @param handle Handle of the memory buffer to use as an underlying storage. +/// - A pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer for the memory object. In this case the library +/// owns the buffer. +/// - DNNL_MEMORY_NONE to create dnnl_memory without an underlying buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory, + const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine, + void *handle); + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Creates a memory object with multiple handles. +/// +/// @param memory Output memory object. +/// @param memory_desc Memory descriptor. +/// @param engine Engine to use. +/// @param nhandles Number of handles. +/// @param handles Handles of the memory buffers to use as underlying storages. +/// For each element of the @p handles array the following applies: +/// - A pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer for the memory object. In this case the library +/// owns the buffer. +/// - DNNL_MEMORY_NONE Instructs the library to skip allocation of the +/// memory buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_create_v2(dnnl_memory_t *memory, + const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine, + int nhandles, void **handles); +#endif + +/// Returns the memory descriptor for a memory object. +/// +/// @param memory Memory object. +/// @param memory_desc Output memory descriptor (a copy). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_get_memory_desc( + const_dnnl_memory_t memory, const_dnnl_memory_desc_t *memory_desc); + +/// Returns the engine of a memory object. +/// +/// @param memory Memory object. +/// @param engine Output engine on which the memory is located. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_get_engine( + const_dnnl_memory_t memory, dnnl_engine_t *engine); + +/// Maps a memory object and returns a host-side pointer to a memory buffer +/// with a copy of its contents. 
+/// +/// Mapping enables explicit direct access to memory contents for the engines +/// that do not support it implicitly. +/// +/// Mapping is an exclusive operation - a memory object cannot be used in +/// other operations until this memory object is unmapped. +/// +/// @note +/// Any primitives working with @p memory should be completed before +/// the memory is mapped. Use dnnl_stream_wait to synchronize the +/// corresponding execution stream. +/// +/// @note +/// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are +/// mainly provided for debug and testing purposes, and their performance +/// may be suboptimal. +/// +/// @param memory Memory object. +/// @param mapped_ptr Output pointer to the mapped buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_map_data( + const_dnnl_memory_t memory, void **mapped_ptr); + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Maps a memory object and returns a host-side pointer to a memory buffer +/// with a copy of its contents. The memory buffer corresponds to the given +/// index. +/// +/// Mapping enables explicit direct access to memory contents for the engines +/// that do not support it implicitly. +/// +/// Mapping is an exclusive operation - a memory object cannot be used in +/// other operations until this memory object is unmapped. +/// +/// @note +/// Any primitives working with @p memory should be completed before +/// the memory is mapped. Use dnnl_stream_wait to synchronize the +/// corresponding execution stream. +/// +/// @note +/// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are +/// mainly provided for debug and testing purposes, and their performance +/// may be suboptimal. +/// +/// @param memory Memory object. +/// @param mapped_ptr Output pointer to the mapped buffer. +/// @param index Index of the buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_memory_map_data_v2( + const_dnnl_memory_t memory, void **mapped_ptr, int index); +#endif + +/// Unmaps a memory object and writes back any changes made to the previously +/// mapped memory buffer. The pointer to the mapped buffer must be obtained +/// via the dnnl_memory_map_data() call. +/// +/// @note +/// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are +/// mainly provided for debug and testing purposes, and their performance +/// may be suboptimal. +/// +/// @param memory Memory object. +/// @param mapped_ptr Pointer to the mapped buffer that must have been +/// obtained using the dnnl_memory_map_data() function. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_unmap_data( + const_dnnl_memory_t memory, void *mapped_ptr); + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Unmaps a memory object and writes back any changes made to the previously +/// mapped memory buffer. The pointer to the mapped buffer must be obtained +/// via the dnnl_memory_map_data() call. The buffer corresponds to the given +/// index. +/// +/// @note +/// The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are +/// mainly provided for debug and testing purposes, and their performance +/// may be suboptimal. +/// +/// @param memory Memory object. +/// @param mapped_ptr Pointer to the mapped buffer that must have been +/// obtained using the dnnl_memory_map_data() function. +/// @param index Index of the buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_unmap_data_v2( + const_dnnl_memory_t memory, void *mapped_ptr, int index); +#endif + +/// Returns memory object's data handle. +/// +/// @param memory Memory object. +/// @param handle Output data handle. For the CPU engine, the data handle is a +/// pointer to the actual data. For OpenCL it is a cl_mem. 
+/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_get_data_handle( + const_dnnl_memory_t memory, void **handle); + +/// Sets the underlying memory buffer. +/// +/// @param memory Memory object. +/// @param handle Data handle. For the CPU engine or when USM is used, the +/// memory buffer is a pointer to the actual data. For OpenCL it is a +/// `cl_mem`. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_set_data_handle( + dnnl_memory_t memory, void *handle); + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Returns an underlying memory buffer that corresponds to the given index. +/// +/// @param memory Memory object. +/// @param handle Data handle. For the CPU engine or when USM is used, the +/// memory buffer is a pointer to the actual data. For OpenCL it is a +/// `cl_mem`. +/// @param index Index of the buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_get_data_handle_v2( + const_dnnl_memory_t memory, void **handle, int index); + +/// Sets an underlying memory buffer that corresponds to the given index. +/// +/// @param memory Memory object. +/// @param handle Data handle. For the CPU engine or when USM is used, the +/// memory buffer is a pointer to the actual data. For OpenCL it is a +/// `cl_mem`. +/// @param index Index of the buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2( + dnnl_memory_t memory, void *handle, int index); +#endif + +/// Destroys a memory object. +/// +/// @param memory Memory object to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory); + +/// @} dnnl_api_memory + +/// @addtogroup dnnl_api_primitives +/// @{ + +/// @addtogroup dnnl_api_reorder +/// @{ + +/// Creates a primitive descriptor for a reorder primitive. +/// +/// @param reorder_primitive_desc Output primitive descriptor. +/// @param src_desc Source memory descriptor. +/// @param src_engine Engine on which the source memory object will be +/// located. +/// @param dst_desc Destination memory descriptor. +/// @param dst_engine Engine on which the destination memory object +/// will be located. +/// @param attr Primitive attributes to use (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create( + dnnl_primitive_desc_t *reorder_primitive_desc, + const_dnnl_memory_desc_t src_desc, dnnl_engine_t src_engine, + const_dnnl_memory_desc_t dst_desc, dnnl_engine_t dst_engine, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_reorder + +/// @addtogroup dnnl_api_concat +/// @{ + +/// Creates a primitive descriptor for an out-of-place concatenation +/// primitive. +/// +/// @param concat_primitive_desc Output primitive descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param n Number of source parameters. +/// @param concat_dimension Source tensors will be concatenated over +/// dimension with this index. Note that order of dimensions does +/// not depend on memory format. +/// @param src_descs Array of source memory descriptors with @p n elements. +/// @param attr Primitive attributes to use (can be NULL). +/// @param engine Engine to use. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create( + dnnl_primitive_desc_t *concat_primitive_desc, dnnl_engine_t engine, + const_dnnl_memory_desc_t dst_desc, int n, int concat_dimension, + const_dnnl_memory_desc_t const *src_descs, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_concat + +/// @addtogroup dnnl_api_sum +/// @{ + +/// Creates a primitive descriptor for an (out-of-place) sum primitive. +/// +/// @param sum_primitive_desc Output primitive descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param n Number of source parameters. +/// @param scales Vector of scales to multiply data in each source +/// memory by. +/// @param src_descs Array of source memory descriptors having @p n elements. +/// @param attr Primitive attributes to use (can be NULL). +/// @param engine Engine to use. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create( + dnnl_primitive_desc_t *sum_primitive_desc, dnnl_engine_t engine, + const_dnnl_memory_desc_t dst_desc, int n, const float *scales, + const_dnnl_memory_desc_t const *src_descs, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_sum + +/// @addtogroup dnnl_api_binary +/// @{ + +/// Creates a primitive descriptor for a binary primitive. +/// +/// @note +/// Memory descriptors @p src1_desc and @p dst_desc are allowed to be +/// initialized with #dnnl_format_tag_any or with format_kind set to +/// #dnnl_format_kind_any. +/// +/// @note +/// Both memory descriptors must have the same number of dimensions. +/// Element broadcasting is supported for memory descriptor @p src1_desc +/// and are applied to @p src1_desc dimensions that have size equal to 1. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Algorithm kind. 
Valid values are #dnnl_binary_add, +/// #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min, #dnnl_binary_div, +/// #dnnl_binary_sub, #dnnl_binary_ge, #dnnl_binary_gt, #dnnl_binary_le, +/// #dnnl_binary_lt, #dnnl_binary_eq and #dnnl_binary_ne. +/// @param src0_desc Source 0 memory descriptor. +/// @param src1_desc Source 1 memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc, + const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t dst_desc, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a binary primitive with support of +/// ternary operators. +/// +/// @note +/// Memory descriptors @p src1_desc, @p src2_desc and @p dst_desc are +/// allowed to be initialized with #dnnl_format_tag_any or with format_kind +/// set to #dnnl_format_kind_any. +/// +/// @note +/// All memory descriptors must have the same number of dimensions. +/// Element broadcasting is supported for memory descriptor @p src1_desc +/// and is applied to @p src1_desc dimensions that have a size equal to 1. +/// There is no broadcasting support for @p src2_desc. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Algorithm kind. +/// @param src0_desc Source 0 memory descriptor. +/// @param src1_desc Source 1 memory descriptor. +/// @param src2_desc Source memory descriptor for ternary operations. Might +/// be empty. +/// @param dst_desc Destination memory descriptor. +/// @param attr Primitive attributes. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create_v2( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc, + const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t src2_desc, + const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_binary + +/// @addtogroup dnnl_api_convolution +/// @{ + +/// Creates a primitive descriptor for a convolution forward propagation +/// primitive. +/// +/// @note +/// Memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain +/// values for spatial dimensions only and hence must have the same number of +/// elements as there are spatial dimensions. The order of values is the same +/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors), +/// and width. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param alg_kind Convolution algorithm. Possible values are +/// #dnnl_convolution_direct, #dnnl_convolution_winograd, +/// #dnnl_convolution_auto. +/// @param src_desc Source memory descriptor. +/// @param weights_desc Weights memory descriptor. +/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory +/// descriptor, or a memory descriptor with format_kind set to +/// #dnnl_format_kind_undef disables the bias term. +/// @param dst_desc Destination memory descriptor. +/// @param strides Array of strides for spatial dimension. +/// @param dilates Array of dilations for spatial dimension. A zero value +/// means no dilation in the corresponding dimension. +/// @param padding_l Array of padding values for low indices for each spatial +/// dimension `([[front,] top,] left)`. 
+/// @param padding_r Array of padding values for high indices for each spatial +/// dimension `([[back,] bottom,] right)`. Can be NULL in which case +/// padding is considered to be symmetrical. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_convolution_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, + const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t weights_desc, + const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc, + const dnnl_dims_t strides, const dnnl_dims_t dilates, + const dnnl_dims_t padding_l, const dnnl_dims_t padding_r, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a convolution backward propagation +/// primitive. +/// +/// @note +/// Memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain +/// values for spatial dimensions only and hence must have the same number of +/// elements as there are spatial dimensions. The order of values is the same +/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors), +/// and width. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Convolution algorithm. Possible values are +/// #dnnl_convolution_direct, #dnnl_convolution_winograd, +/// #dnnl_convolution_auto. +/// @param diff_src_desc Diff source memory descriptor. +/// @param weights_desc Weights memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param strides Array of strides for spatial dimension. +/// @param dilates Array of dilations for spatial dimension. 
A zero value +/// means no dilation in the corresponding dimension. +/// @param padding_l Array of padding values for low indices for each spatial +/// dimension `([[front,] top,] left)`. +/// @param padding_r Array of padding values for high indices for each spatial +/// dimension `([[back,] bottom,] right)`. Can be NULL in which case +/// padding is considered to be symmetrical. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_convolution_backward_data_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t weights_desc, + const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides, + const dnnl_dims_t dilates, const dnnl_dims_t padding_l, + const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a convolution weights gradient primitive. +/// +/// @note +/// Memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain +/// values for spatial dimensions only and hence must have the same number of +/// elements as there are spatial dimensions. The order of values is the same +/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors), +/// and width. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Convolution algorithm. Possible values are +/// #dnnl_convolution_direct, #dnnl_convolution_winograd, +/// #dnnl_convolution_auto. +/// @param src_desc Source memory descriptor. 
+/// @param diff_weights_desc Diff weights memory descriptor. +/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero +/// memory descriptor, or a memory descriptor with format_kind set to +/// #dnnl_format_kind_undef disables the bias term. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param strides Array of strides for spatial dimension. +/// @param dilates Array of dilations for spatial dimension. A zero value +/// means no dilation in the corresponding dimension. +/// @param padding_l Array of padding values for low indices for each spatial +/// dimension `([[front,] top,] left)`. +/// @param padding_r Array of padding values for high indices for each spatial +/// dimension `([[back,] bottom,] right)`. Can be NULL in which case +/// padding is considered to be symmetrical. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_convolution_backward_weights_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t diff_weights_desc, + const_dnnl_memory_desc_t diff_bias_desc, + const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides, + const dnnl_dims_t dilates, const dnnl_dims_t padding_l, + const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_convolution + +/// @addtogroup dnnl_api_deconvolution +/// @{ + +/// Creates a primitive descriptor for a deconvolution forward propagation +/// primitive. +/// +/// @note +/// Memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. 
+/// +/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain +/// values for spatial dimensions only and hence must have the same number of +/// elements as there are spatial dimensions. The order of values is the same +/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors), +/// and width. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param alg_kind Deconvolution algorithm. Possible values are +/// #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd. +/// @param src_desc Source memory descriptor. +/// @param weights_desc Weights memory descriptor. +/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory +/// descriptor, or a memory descriptor with format_kind set to +/// #dnnl_format_kind_undef disables the bias term. +/// @param dst_desc Destination memory descriptor. +/// @param strides Array of strides for spatial dimension. +/// @param dilates Array of dilations for spatial dimension. A zero value +/// means no dilation in the corresponding dimension. +/// @param padding_l Array of padding values for low indices for each spatial +/// dimension `([[front,] top,] left)`. +/// @param padding_r Array of padding values for high indices for each spatial +/// dimension `([[back,] bottom,] right)`. Can be NULL in which case +/// padding is considered to be symmetrical. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_deconvolution_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, + const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t weights_desc, + const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc, + const dnnl_dims_t strides, const dnnl_dims_t dilates, + const dnnl_dims_t padding_l, const dnnl_dims_t padding_r, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a deconvolution backward propagation +/// primitive. +/// +/// @note +/// Memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain +/// values for spatial dimensions only and hence must have the same number of +/// elements as there are spatial dimensions. The order of values is the same +/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors), +/// and width. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Deconvolution algorithm. Possible values are +/// #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd. +/// @param diff_src_desc Diff source memory descriptor. +/// @param weights_desc Weights memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param strides Array of strides for spatial dimension. +/// @param dilates Array of dilations for spatial dimension. A zero value +/// means no dilation in the corresponding dimension. +/// @param padding_l Array of padding values for low indices for each spatial +/// dimension `([[front,] top,] left)`. +/// @param padding_r Array of padding values for high indices for each spatial +/// dimension `([[back,] bottom,] right)`. Can be NULL in which case +/// padding is considered to be symmetrical. 
+/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t weights_desc, + const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides, + const dnnl_dims_t dilates, const dnnl_dims_t padding_l, + const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a deconvolution weights gradient +/// primitive. +/// +/// @note +/// Memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain +/// values for spatial dimensions only and hence must have the same number of +/// elements as there are spatial dimensions. The order of values is the same +/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors), +/// and width. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Deconvolution algorithm. Possible values are +/// #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd. +/// @param src_desc Source memory descriptor. +/// @param diff_weights_desc Diff weights memory descriptor. +/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero +/// memory descriptor, or a memory descriptor with format_kind set to +/// #dnnl_format_kind_undef disables the bias term. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param strides Array of strides for spatial dimension. 
+/// @param dilates Array of dilations for spatial dimension. A zero value +/// means no dilation in the corresponding dimension. +/// @param padding_l Array of padding values for low indices for each spatial +/// dimension `([[front,] top,] left)`. +/// @param padding_r Array of padding values for high indices for each spatial +/// dimension `([[back,] bottom,] right)`. Can be NULL in which case +/// padding is considered to be symmetrical. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API +dnnl_deconvolution_backward_weights_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t diff_weights_desc, + const_dnnl_memory_desc_t diff_bias_desc, + const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides, + const dnnl_dims_t dilates, const dnnl_dims_t padding_l, + const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_deconvolution + +/// @addtogroup dnnl_api_shuffle +/// @{ + +/// Creates a primitive descriptor for a shuffle forward propagation primitive +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param axis The axis along which the data is shuffled. +/// @param group_size Shuffle group size. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_shuffle_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t dst_desc, int axis, dnnl_dim_t group_size, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a shuffle backward propagation primitive +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param axis The axis along which the data is shuffled. +/// @param group_size Shuffle group size. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_shuffle_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, int axis, dnnl_dim_t group_size, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_shuffle + +/// @addtogroup dnnl_api_eltwise +/// @{ + +/// Creates a primitive descriptor for an eltwise forward propagation primitive. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param alg_kind Elementwise algorithm kind. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param alpha The alpha parameter for the elementwise operation. Specific +/// meaning depends on the algorithm. +/// @param beta The beta parameter for the elementwise operation. 
Specific +/// meaning depends on the algorithm. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_eltwise_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, + const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc, + float alpha, float beta, const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for an eltwise backward propagation +/// primitive. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Elementwise algorithm kind. +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param data_desc Destination memory descriptor if one of the +/// "use_dst_for_bwd" algorithms are used (such as +/// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor otherwise. +/// @param alpha The alpha parameter for the elementwise operation. Specific +/// meaning depends on the algorithm. +/// @param beta The beta parameter for the elementwise operation. Specific +/// meaning depends on the algorithm. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_eltwise_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t data_desc, float alpha, float beta, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_eltwise + +/// @addtogroup dnnl_api_softmax +/// @{ + +/// Creates a primitive descriptor for a softmax forward propagation primitive. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or +/// #dnnl_softmax_log. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param softmax_axis Axis over which softmax is computed. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_softmax_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, + const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc, + int softmax_axis, const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a softmax backward propagation primitive. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or +/// #dnnl_softmax_log. +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param softmax_axis Axis over which softmax is computed. 
+/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_softmax_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t dst_desc, int softmax_axis, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_softmax + +/// @addtogroup dnnl_api_pooling +/// @{ + +/// Creates a primitive descriptor for a pooling forward propagation +/// primitive. +/// +/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r +/// contain values for spatial dimensions only and hence must have the same +/// number of elements as there are spatial dimensions. The order of values +/// is the same as in the tensor: depth (for 3D tensors), +/// height (for 3D and 2D tensors), and width. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max, +/// #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param strides Array of strides for spatial dimension. +/// @param kernel Array of kernel spatial dimensions. +/// @param dilation Array of dilations for spatial dimension. +/// @param padding_l Array of padding values for low indices for each spatial +/// dimension `([[front,] top,] left)`. 
+/// @param padding_r Array of padding values for high indices for each spatial +/// dimension `([[back,] bottom,] right)`. Can be NULL in which case +/// padding is considered to be symmetrical. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_pooling_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, + const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc, + const dnnl_dims_t strides, const dnnl_dims_t kernel, + const dnnl_dims_t dilation, const dnnl_dims_t padding_l, + const dnnl_dims_t padding_r, const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a pooling backward propagation +/// primitive. +/// +/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r +/// contain values for spatial dimensions only and hence must have the same +/// number of elements as there are spatial dimensions. The order of values +/// is the same as in the tensor: depth (for 3D tensors), +/// height (for 3D and 2D tensors), and width. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max, +/// #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding. +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param strides Array of strides for spatial dimension. +/// @param kernel Array of kernel spatial dimensions. +/// @param dilation Array of dilations for spatial dimension. +/// @param padding_l Array of padding values for low indices for each spatial +/// dimension `([[front,] top,] left)`. +/// @param padding_r Array of padding values for high indices for each spatial +/// dimension `([[back,] bottom,] right)`. 
Can be NULL in which case +/// padding is considered to be symmetrical. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_pooling_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides, + const dnnl_dims_t kernel, const dnnl_dims_t dilation, + const dnnl_dims_t padding_l, const dnnl_dims_t padding_r, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_pooling + +/// @addtogroup dnnl_api_prelu +/// @{ + +/// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable +/// alpha parameter) forward propagation primitive. +/// +/// @note +/// weights descriptor is allowed to be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param src_desc Source memory descriptor. +/// @param weights_desc Alpha parameters memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_prelu_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t weights_desc, + const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable +/// alpha parameter) backward propagation primitive. +/// +/// @note +/// weights descriptor and diff_weights descriptor are allowed +/// to be initialized with #dnnl_format_tag_any or with format_kind +/// set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param src_desc Source memory descriptor. +/// @param weights_desc Alpha parameters memory descriptor. +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_weights_desc Diff alpha parameters memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_prelu_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t weights_desc, + const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_weights_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_prelu + +/// @addtogroup dnnl_api_lrn +/// @{ + +/// Creates a primitive descriptor for an LRN forward propagation primitive. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. 
Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or +/// #dnnl_lrn_within_channel. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param local_size Regularization local size. +/// @param alpha The alpha regularization parameter. +/// @param beta The beta regularization parameter. +/// @param k The k regularization parameter. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_lrn_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, + const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc, + dnnl_dim_t local_size, float alpha, float beta, float k, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for an LRN backward propagation primitive. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or +/// #dnnl_lrn_within_channel. +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param src_desc Source memory descriptor. +/// @param local_size Regularization local size. +/// @param alpha The alpha regularization parameter. +/// @param beta The beta regularization parameter. +/// @param k The k regularization parameter. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_lrn_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t src_desc, dnnl_dim_t local_size, float alpha, + float beta, float k, const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_lrn + +/// @addtogroup dnnl_api_batch_normalization +/// @{ + +/// Creates a primitive descriptor for a batch normalization forward propagation +/// primitive. +/// +/// @note +/// In-place operation is supported: the dst can refer to the same memory +/// as the src. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param epsilon Batch normalization epsilon parameter. +/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t). +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_batch_normalization_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t dst_desc, float epsilon, unsigned flags, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a batch normalization backward +/// propagation primitive. +/// +/// @note +/// In-place operation is supported: the diff_dst can refer to the same +/// memory as the diff_src. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. 
Possible values are +/// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are +/// computed in this case). +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param src_desc Source memory descriptor. +/// @param epsilon Batch normalization epsilon parameter. +/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t). +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_batch_normalization_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t src_desc, float epsilon, unsigned flags, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_batch_normalization + +/// @addtogroup dnnl_api_group_normalization +/// @{ + +/// Creates a primitive descriptor for a group normalization forward propagation +/// primitive. +/// +/// @note +/// In-place operation is supported: the dst can refer to the same memory +/// as the src. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param groups Group normalization groups parameter. +/// @param epsilon Group normalization epsilon parameter. +/// @param flags Group normalization flags (@ref dnnl_normalization_flags_t). +/// @param attr Primitive attributes (can be NULL). 
+/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_group_normalization_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t dst_desc, dnnl_dim_t groups, float epsilon, + unsigned flags, const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a group normalization backward +/// propagation primitive. +/// +/// @note +/// In-place operation is supported: the diff_dst can refer to the same +/// memory as the diff_src. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are +/// computed in this case). +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param src_desc Source memory descriptor. +/// @param groups Group normalization groups parameter. +/// @param epsilon Group normalization epsilon parameter. +/// @param flags Group normalization flags (@ref dnnl_normalization_flags_t). +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_group_normalization_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t src_desc, dnnl_dim_t groups, float epsilon, + unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_group_normalization + +/// @addtogroup dnnl_api_layer_normalization +/// @{ + +/// Creates a primitive descriptor for a layer normalization forward propagation +/// primitive. +/// +/// @note +/// In-place operation is supported: the dst can refer to the same memory +/// as the src. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param stat_desc Memory descriptor for mean and variance. If this +/// parameter is NULL, a zero memory descriptor, or a memory descriptor +/// with format_kind set to #dnnl_format_kind_undef, then the memory +/// descriptor for stats is derived from @p src_desc by removing the last +/// dimension. +/// @param epsilon Layer normalization epsilon parameter. +/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t). +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_layer_normalization_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc, + float epsilon, unsigned flags, const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a layer normalization backward +/// propagation primitive. +/// +/// @note +/// In-place operation is supported: the diff_dst can refer to the same +/// memory as the diff_src. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are +/// computed in this case). +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param src_desc Source memory descriptor. +/// @param stat_desc Memory descriptor for mean and variance. If this +/// parameter is NULL, a zero memory descriptor, or a memory descriptor +/// with format_kind set to #dnnl_format_kind_undef, then the memory +/// descriptor for stats is derived from @p src_desc by removing the last +/// dimension. +/// @param epsilon Layer normalization epsilon parameter. +/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t). +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_layer_normalization_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc, + float epsilon, unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a layer normalization forward propagation +/// primitive with a user-provided data type for the scale and shift +/// memory objects. +/// +/// @note +/// In-place operation is supported: the dst can refer to the same memory +/// as the src. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param stat_desc Memory descriptor for mean and variance. If this +/// parameter is NULL, a zero memory descriptor, or a memory descriptor +/// with format_kind set to #dnnl_format_kind_undef, then the memory +/// descriptor for stats is derived from @p src_desc by removing the last +/// dimension. +/// @param scale_shift_data_type Data type of scale and shift memory. If neither scale +/// nor shift flag are specified the parameter is ignored. +/// @param epsilon Layer normalization epsilon parameter. +/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t). +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API +dnnl_layer_normalization_forward_primitive_desc_create_v2( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc, + dnnl_data_type_t scale_shift_data_type, float epsilon, unsigned flags, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for a layer normalization backward +/// propagation primitive with a user-provided data type for the +/// scale and shift memory objects. +/// +/// @note +/// In-place operation is supported: the diff_dst can refer to the same +/// memory as the diff_src. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are +/// computed in this case). +/// @param diff_src_desc Diff source memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param src_desc Source memory descriptor. +/// @param stat_desc Memory descriptor for mean and variance. If this +/// parameter is NULL, a zero memory descriptor, or a memory descriptor +/// with format_kind set to #dnnl_format_kind_undef, then the memory +/// descriptor for stats is derived from @p src_desc by removing the last +/// dimension. +/// @param diff_scale_shift_data_type Data type of diff scale and shift memory. If neither scale +/// nor shift flag are specified the parameter is ignored. +/// @param scale_shift_data_type Data type of scale and shift memory. If neither scale +/// nor shift flag are specified the parameter is ignored. +/// @param epsilon Layer normalization epsilon parameter. +/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t). +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. 
+/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API +dnnl_layer_normalization_backward_primitive_desc_create_v2( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc, + dnnl_data_type_t diff_scale_shift_data_type, + dnnl_data_type_t scale_shift_data_type, float epsilon, unsigned flags, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_layer_normalization + +/// @addtogroup dnnl_api_inner_product +/// @{ + +/// Creates a primitive descriptor for an inner product forward propagation +/// primitive. +/// +/// @note +/// Memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param src_desc Source memory descriptor. +/// @param weights_desc Weights memory descriptor. +/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory +/// descriptor, or a memory descriptor with format_kind set to +/// #dnnl_format_kind_undef disables the bias term. +/// @param dst_desc Destination memory descriptor. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_inner_product_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t weights_desc, + const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for an inner product backward propagation +/// primitive. +/// +/// @note +/// Memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param diff_src_desc Diff source memory descriptor. +/// @param weights_desc Weights memory descriptor. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_inner_product_backward_data_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t weights_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for an inner product weights gradient +/// primitive. +/// +/// @note +/// Memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive_descriptor. +/// @param engine Engine to use. +/// @param src_desc Source memory descriptor. +/// @param diff_weights_desc Diff weights memory descriptor. +/// @param diff_bias_desc Diff bias memory descriptor. 
Passing NULL, a zero +/// memory descriptor, or a memory descriptor with format_kind set to +/// #dnnl_format_kind_undef disables the bias term. +/// @param diff_dst_desc Diff destination memory descriptor. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API +dnnl_inner_product_backward_weights_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + const_dnnl_memory_desc_t src_desc, + const_dnnl_memory_desc_t diff_weights_desc, + const_dnnl_memory_desc_t diff_bias_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_inner_product + +/// @addtogroup dnnl_api_attributes +/// @{ + +/// Set quantization scale and shift parameters for RNN data tensors. +/// +/// For performance reasons, the low-precision configuration of the RNN +/// primitives expects input activations to have the unsigned 8-bit integer +/// data type. The scale and shift parameters are used to quantize +/// floating-point data to unsigned integer and must be passed to the RNN +/// primitive using attributes. +/// +/// The quantization formula is `scale * data + shift`. +/// +/// @note +/// Quantization scale and shift are common for src_layer, src_iter, +/// dst_iter, and dst_layer. 
+///
+/// Example usage:
+/// @code
+///     // RNN parameters
+///     int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
+///     // Activations quantization parameters
+///     float scale = 63.f, shift = 64.f;
+///
+///     dnnl_primitive_attr_t rnn_attr;
+///     // Create default attributes
+///     dnnl_primitive_attr_create(&rnn_attr);
+///
+///     // Set scale and shift for int8 quantization of activation
+///     dnnl_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
+///
+///     // Create an RNN primitive descriptor.
+///     dnnl_primitive_desc_t rnn_pd;
+///     dnnl_vanilla_rnn_forward_primitive_desc_create(&rnn_pd,
+///             engine, /* arguments */, rnn_attr);
+/// @endcode
+///
+/// @param attr Primitive attributes.
+/// @param scale The value to scale the data by.
+/// @param shift The value to shift the data by.
+/// @returns #dnnl_success on success and a status describing the error
+///     otherwise.
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(
+        dnnl_primitive_attr_t attr, const float scale, const float shift);
+
+/// Returns the quantization scale and shift parameters for RNN data tensors.
+///
+/// @note
+///     Quantization scale and shift are common for src_layer, src_iter,
+///     dst_iter, and dst_layer.
+///
+/// @param attr Primitive attributes.
+/// @param scale The value to scale the data by.
+/// @param shift The value to shift the data by.
+/// @returns #dnnl_success on success and a status describing the error
+///     otherwise.
+dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(
+        const_dnnl_primitive_attr_t attr, float *scale, float *shift);
+
+/// Sets quantization scaling factors for RNN weights tensors. The
+/// low-precision configuration of the RNN primitives expects input weights to
+/// use the signed 8-bit integer data type. The scaling factors are used to
+/// quantize floating-point data to signed integer and must be passed to RNN
+/// primitives using attributes.
+/// +/// @note +/// The dimension order is always native and does not depend on the actual +/// layout used. For example, five-dimensional weights always have (l, d, +/// i, g, o) logical dimension ordering. +/// +/// @note +/// Quantization scales are common for weights_layer and weights_iteration +/// +/// @param attr Primitive attributes. +/// @param count Number of elements in the @p scales array. +/// @param mask Scaling factors correspondence mask that defines the +/// correspondence between the output tensor dimensions and the @p +/// scales vector. The set i-th bit indicates that a dedicated scaling +/// factor should be used for each index along that dimension. Set the +/// mask to 0 to use a common scaling factor for the whole output +/// tensor. +/// @param scales Array of output scaling factors that must contain @p count +/// values and the following equality must hold: +/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f] +/// Violations can only be detected when the attributes are used to create +/// a primitive descriptor. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams( + dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, + const float *scales); + +/// Returns the quantization scaling factors for RNN weights tensors. +/// +/// @param attr Primitive attributes. +/// @param count Number of elements in the @p scales array. +/// @param mask Scaling factors correspondence mask that defines the +/// correspondence between the output tensor dimensions and the @p +/// scales vector. The set i-th bit indicates that a dedicated scaling +/// factor should be used for each index along that dimension. Set the +/// mask to 0 to use a common scaling factor for the whole output +/// tensor. 
+/// @param scales Array of output scaling factors that contain @p count +/// values and the following equality must hold: +/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f] +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams( + const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, + const float **scales); + +/// Sets quantization scaling factors for RNN projection weights tensors. The +/// low-precision configuration of the RNN primitives expects input weights to +/// use the signed 8-bit integer data type. The scaling factors are used to +/// quantize floating-point data to signed integer and must be passed to RNN +/// primitives using attributes. +/// +/// @note +/// The dimension order is always native and does not depend on the actual +/// layout used. For example, five-dimensional weights always have (l, d, +/// i, g, o) logical dimension ordering. +/// +/// @param attr Primitive attributes. +/// @param count Number of elements in the @p scales array. +/// @param mask Scaling factors correspondence mask that defines the +/// correspondence between the output tensor dimensions and the @p +/// scales vector. The set i-th bit indicates that a dedicated scaling +/// factor should be used for each index along that dimension. Set the +/// mask to 0 to use a common scaling factor for the whole output +/// tensor. +/// @param scales Array of output scaling factors that must contain @p count +/// values and the following equality must hold: +/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f] +/// Violations can only be detected when the attributes are used to create +/// a primitive descriptor. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams( + dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, + const float *scales); + +/// Returns the quantization scaling factors for RNN projection weights tensors. +/// +/// @param attr Primitive attributes. +/// @param count Number of elements in the @p scales array. +/// @param mask Scaling factors correspondence mask that defines the +/// correspondence between the output tensor dimensions and the @p +/// scales vector. The set i-th bit indicates that a dedicated scaling +/// factor should be used for each index along that dimension. Set the +/// mask to 0 to use a common scaling factor for the whole output +/// tensor. +/// @param scales Array of output scaling factors that contain @p count +/// values and the following equality must hold: +/// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f] +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams( + const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, + const float **scales); + +/// @} dnnl_api_attributes + +/// @addtogroup dnnl_api_rnn +/// @{ + +/// Creates a primitive descriptor for vanilla RNN forward propagation +/// primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc, +/// - @p bias_desc, +/// - @p dst_iter_desc. +/// +/// This would then indicate that the RNN forward propagation primitive should +/// not use them and should default to zero values instead. +/// +/// @note +/// All memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. 
+/// @param activation Activation kind. Possible values are #dnnl_eltwise_relu, +/// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic. +/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param flags Unused. +/// @param alpha Negative slope if activation is #dnnl_eltwise_relu. +/// @param beta Unused. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, + const dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, float alpha, + float beta, const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for vanilla RNN backward propagation +/// primitive. 
+/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc together with @p diff_src_iter_desc, +/// - @p bias_desc together with @p diff_bias_desc, +/// - @p dst_iter_desc together with @p diff_dst_iter_desc. +/// +/// This would then indicate that the RNN backward propagation primitive should +/// not use the respective data and should use zero values instead. +/// +/// @note +/// All memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Must be #dnnl_backward. +/// @param activation Activation kind. Possible values are #dnnl_eltwise_relu, +/// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic. +/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param diff_src_layer_desc Memory descriptor for the diff of input vector. +/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent +/// hidden state vector. +/// @param diff_weights_layer_desc Memory descriptor for the diff of weights +/// applied to the layer input. +/// @param diff_weights_iter_desc Memory descriptor for the diff of weights +/// applied to the recurrent input. +/// @param diff_bias_desc Diff bias memory descriptor. 
+/// @param diff_dst_layer_desc Memory descriptor for the diff of output +/// vector. +/// @param diff_dst_iter_desc Memory descriptor for the diff of output +/// recurrent hidden state vector. +/// @param flags Unused. +/// @param alpha Negative slope if activation is #dnnl_eltwise_relu. +/// @param beta Unused. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, + const dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, + const_dnnl_memory_desc_t diff_src_layer_desc, + const_dnnl_memory_desc_t diff_src_iter_desc, + const_dnnl_memory_desc_t diff_weights_layer_desc, + const_dnnl_memory_desc_t diff_weights_iter_desc, + const_dnnl_memory_desc_t diff_bias_desc, + const_dnnl_memory_desc_t diff_dst_layer_desc, + const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags, + float alpha, float beta, const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for an LSTM forward propagation primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc together with @p src_iter_c_desc, +/// - @p weights_peephole_desc, +/// - @p bias_desc, +/// - @p dst_iter_desc together with @p dst_iter_c_desc. 
+/// +/// This would then indicate that the LSTM forward propagation primitive should +/// not use them and should default to zero values instead. +/// +/// The @p weights_projection_desc could either be @c NULL or point to a zero +/// memory descriptor. This would then indicate that the LSTM doesn't have +/// recurrent projection layer. +/// +/// @note +/// All memory descriptors can be initialized with #dnnl_format_tag_any or +/// with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param src_iter_c_desc Memory descriptor for the input recurrent cell +/// state vector. +/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param weights_peephole_desc Memory descriptor for the weights applied to +/// the cell states (according to the Peephole LSTM formula). +/// @param weights_projection_desc Memory descriptor for the weights applied to +/// the hidden states to get the recurrent projection (according to the +/// Projection LSTM formula). +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell +/// state vector. +/// @param flags Unused. +/// @param attr Primitive attributes (can be NULL). 
+/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t src_iter_c_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t weights_peephole_desc, + const_dnnl_memory_desc_t weights_projection_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, + const_dnnl_memory_desc_t dst_iter_c_desc, unsigned flags, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for an LSTM backward propagation primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc, +/// and @p diff_src_iter_c_desc, +/// - @p weights_peephole_desc together with @p diff_weights_peephole_desc, +/// - @p bias_desc together with @p diff_bias_desc, +/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc, +/// and @p diff_dst_iter_c_desc. +/// +/// This would then indicate that the LSTM backward propagation primitive +/// should not use them and should default to zero values instead. +/// +/// The @p weights_projection_desc together with @p +/// diff_weights_projection_desc could either be @c NULL or point to a zero +/// memory descriptor. This would then indicate that the LSTM doesn't have +/// recurrent projection layer. +/// +/// @note +/// All memory descriptors can be initialized with #dnnl_format_tag_any or +/// with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. 
+/// @param prop_kind Propagation kind. Must be #dnnl_backward. +/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param src_iter_c_desc Memory descriptor for the input recurrent cell +/// state vector. +/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param weights_peephole_desc Memory descriptor for the weights applied to +/// the cell states (according to the Peephole LSTM formula). +/// @param weights_projection_desc Memory descriptor for the weights applied to +/// the hidden states to get the recurrent projection (according to the +/// Projection LSTM formula). +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell +/// state vector. +/// @param diff_src_layer_desc Memory descriptor for the diff of input vector. +/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent +/// hidden state vector. +/// @param diff_src_iter_c_desc Memory descriptor for the diff of input +/// recurrent cell state vector. +/// @param diff_weights_layer_desc Memory descriptor for the diff of weights +/// applied to the layer input. +/// @param diff_weights_iter_desc Memory descriptor for the diff of weights +/// applied to the recurrent input. +/// @param diff_weights_peephole_desc Memory descriptor for the diff of weights +/// applied to the cell states (according to the Peephole LSTM formula). 
+/// @param diff_weights_projection_desc Memory descriptor for the diff of +/// weights applied to the hidden states to get the recurrent projection +/// (according to the Projection LSTM formula). +/// @param diff_bias_desc Diff bias memory descriptor. +/// @param diff_dst_layer_desc Memory descriptor for the diff of output +/// vector. +/// @param diff_dst_iter_desc Memory descriptor for the diff of output +/// recurrent hidden state vector. +/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output +/// recurrent cell state vector. +/// @param flags Unused. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t src_iter_c_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t weights_peephole_desc, + const_dnnl_memory_desc_t weights_projection_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, + const_dnnl_memory_desc_t dst_iter_c_desc, + const_dnnl_memory_desc_t diff_src_layer_desc, + const_dnnl_memory_desc_t diff_src_iter_desc, + const_dnnl_memory_desc_t diff_src_iter_c_desc, + const_dnnl_memory_desc_t diff_weights_layer_desc, + const_dnnl_memory_desc_t diff_weights_iter_desc, + const_dnnl_memory_desc_t diff_weights_peephole_desc, + const_dnnl_memory_desc_t diff_weights_projection_desc, + const_dnnl_memory_desc_t diff_bias_desc, + const_dnnl_memory_desc_t diff_dst_layer_desc, + const_dnnl_memory_desc_t 
diff_dst_iter_desc, + const_dnnl_memory_desc_t diff_dst_iter_c_desc, unsigned flags, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for GRU forward propagation primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc, +/// - @p bias_desc, +/// - @p dst_iter_desc. +/// +/// This would then indicate that the GRU forward propagation primitive should +/// not use them and should default to zero values instead. +/// +/// @note +/// All memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param flags Unused. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for GRU backward propagation primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc together with @p diff_src_iter_desc, +/// - @p bias_desc together with @p diff_bias_desc, +/// - @p dst_iter_desc together with @p diff_dst_iter_desc. +/// +/// This would then indicate that the GRU backward propagation primitive +/// should not use them and should default to zero values instead. +/// +/// @note +/// All memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Must be #dnnl_backward. +/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. 
+/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param diff_src_layer_desc Memory descriptor for the diff of input vector. +/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent +/// hidden state vector. +/// @param diff_weights_layer_desc Memory descriptor for the diff of weights +/// applied to the layer input. +/// @param diff_weights_iter_desc Memory descriptor for the diff of weights +/// applied to the recurrent input. +/// @param diff_bias_desc Diff bias memory descriptor. +/// @param diff_dst_layer_desc Memory descriptor for the diff of output +/// vector. +/// @param diff_dst_iter_desc Memory descriptor for the diff of output +/// recurrent hidden state vector. +/// @param flags Unused. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
+        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
+        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
+        const_dnnl_memory_desc_t src_layer_desc,
+        const_dnnl_memory_desc_t src_iter_desc,
+        const_dnnl_memory_desc_t weights_layer_desc,
+        const_dnnl_memory_desc_t weights_iter_desc,
+        const_dnnl_memory_desc_t bias_desc,
+        const_dnnl_memory_desc_t dst_layer_desc,
+        const_dnnl_memory_desc_t dst_iter_desc,
+        const_dnnl_memory_desc_t diff_src_layer_desc,
+        const_dnnl_memory_desc_t diff_src_iter_desc,
+        const_dnnl_memory_desc_t diff_weights_layer_desc,
+        const_dnnl_memory_desc_t diff_weights_iter_desc,
+        const_dnnl_memory_desc_t diff_bias_desc,
+        const_dnnl_memory_desc_t diff_dst_layer_desc,
+        const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
+        const_dnnl_primitive_desc_t hint_fwd_pd,
+        const_dnnl_primitive_attr_t attr);
+
+/// Creates a primitive descriptor for LBR GRU forward propagation primitive.
+///
+/// The following arguments may either be @c NULL or point to a zero memory
+/// descriptor:
+/// - @p src_iter_desc,
+/// - @p bias_desc,
+/// - @p dst_iter_desc.
+///
+/// This would then indicate that the LBR GRU forward propagation primitive
+/// should not use them and should default to zero values instead.
+///
+/// @param primitive_desc Output primitive descriptor.
+/// @param engine Engine to use.
+/// @param prop_kind Propagation kind. Possible values are
+///     #dnnl_forward_training and #dnnl_forward_inference.
+/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
+///     info.
+/// @param src_layer_desc Memory descriptor for the input vector.
+/// @param src_iter_desc Memory descriptor for the input recurrent hidden
+///     state vector.
+/// @param weights_layer_desc Memory descriptor for the weights applied to the
+///     layer input.
+/// @param weights_iter_desc Memory descriptor for the weights applied to the
+///     recurrent input.
+/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param flags Unused. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for LBR GRU backward propagation primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc together with @p diff_src_iter_desc, +/// - @p bias_desc together with @p diff_bias_desc, +/// - @p dst_iter_desc together with @p diff_dst_iter_desc. +/// +/// This would then indicate that the LBR GRU backward propagation primitive +/// should not use them and should default to zero values instead. +/// +/// @note +/// All memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Must be #dnnl_backward. +/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. 
+/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param diff_src_layer_desc Memory descriptor for the diff of input vector. +/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent +/// hidden state vector. +/// @param diff_weights_layer_desc Memory descriptor for the diff of weights +/// applied to the layer input. +/// @param diff_weights_iter_desc Memory descriptor for the diff of weights +/// applied to the recurrent input. +/// @param diff_bias_desc Diff bias memory descriptor. +/// @param diff_dst_layer_desc Memory descriptor for the diff of output +/// vector. +/// @param diff_dst_iter_desc Memory descriptor for the diff of output +/// recurrent hidden state vector. +/// @param flags Unused. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, + const_dnnl_memory_desc_t diff_src_layer_desc, + const_dnnl_memory_desc_t diff_src_iter_desc, + const_dnnl_memory_desc_t diff_weights_layer_desc, + const_dnnl_memory_desc_t diff_weights_iter_desc, + const_dnnl_memory_desc_t diff_bias_desc, + const_dnnl_memory_desc_t diff_dst_layer_desc, + const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for AUGRU forward propagation primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc, +/// - @p bias_desc, +/// - @p dst_iter_desc. +/// +/// This would then indicate that the AUGRU forward propagation primitive should +/// not use them and should default to zero values instead. +/// +/// @note +/// All memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param attention_desc Memory descriptor for the attention vector. 
+/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param flags Unused. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t attention_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for AUGRU backward propagation primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc together with @p diff_src_iter_desc, +/// - @p bias_desc together with @p diff_bias_desc, +/// - @p dst_iter_desc together with @p diff_dst_iter_desc. +/// +/// This would then indicate that the AUGRU backward propagation primitive +/// should not use them and should default to zero values instead. +/// +/// @note +/// All memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Must be #dnnl_backward. 
+/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param attention_desc Memory descriptor for the attention vector. +/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param diff_src_layer_desc Memory descriptor for the diff of input vector. +/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent +/// hidden state vector. +/// @param diff_attention_desc Memory descriptor for the diff of attention vector. +/// @param diff_weights_layer_desc Memory descriptor for the diff of weights +/// applied to the layer input. +/// @param diff_weights_iter_desc Memory descriptor for the diff of weights +/// applied to the recurrent input. +/// @param diff_bias_desc Diff bias memory descriptor. +/// @param diff_dst_layer_desc Memory descriptor for the diff of output +/// vector. +/// @param diff_dst_iter_desc Memory descriptor for the diff of output +/// recurrent hidden state vector. +/// @param flags Unused. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t attention_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, + const_dnnl_memory_desc_t diff_src_layer_desc, + const_dnnl_memory_desc_t diff_src_iter_desc, + const_dnnl_memory_desc_t diff_attention_desc, + const_dnnl_memory_desc_t diff_weights_layer_desc, + const_dnnl_memory_desc_t diff_weights_iter_desc, + const_dnnl_memory_desc_t diff_bias_desc, + const_dnnl_memory_desc_t diff_dst_layer_desc, + const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for LBR AUGRU forward propagation primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc, +/// - @p bias_desc, +/// - @p dst_iter_desc. +/// +/// This would then indicate that the LBR AUGRU forward propagation primitive +/// should not use them and should default to zero values instead. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Possible values are +/// #dnnl_forward_training and #dnnl_forward_inference. +/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param attention_desc Memory descriptor for the attention vector. 
+/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param flags Unused. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t attention_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, + const_dnnl_primitive_attr_t attr); + +/// Creates a primitive descriptor for LBR AUGRU backward propagation primitive. +/// +/// The following arguments may either be @c NULL or point to a zero memory +/// descriptor: +/// - @p src_iter_desc together with @p diff_src_iter_desc, +/// - @p bias_desc together with @p diff_bias_desc, +/// - @p dst_iter_desc together with @p diff_dst_iter_desc. +/// +/// This would then indicate that the LBR AUGRU backward propagation primitive +/// should not use them and should default to zero values instead. +/// +/// @note +/// All memory descriptors can be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param prop_kind Propagation kind. Must be #dnnl_backward. 
+/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more +/// info. +/// @param src_layer_desc Memory descriptor for the input vector. +/// @param src_iter_desc Memory descriptor for the input recurrent hidden +/// state vector. +/// @param attention_desc Memory descriptor for the attention vector. +/// @param weights_layer_desc Memory descriptor for the weights applied to the +/// layer input. +/// @param weights_iter_desc Memory descriptor for the weights applied to the +/// recurrent input. +/// @param bias_desc Bias memory descriptor. +/// @param dst_layer_desc Memory descriptor for the output vector. +/// @param dst_iter_desc Memory descriptor for the output recurrent hidden +/// state vector. +/// @param diff_src_layer_desc Memory descriptor for the diff of input vector. +/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent +/// hidden state vector. +/// @param diff_attention_desc Memory descriptor for the diff of attention vector. +/// @param diff_weights_layer_desc Memory descriptor for the diff of weights +/// applied to the layer input. +/// @param diff_weights_iter_desc Memory descriptor for the diff of weights +/// applied to the recurrent input. +/// @param diff_bias_desc Diff bias memory descriptor. +/// @param diff_dst_layer_desc Memory descriptor for the diff of output +/// vector. +/// @param diff_dst_iter_desc Memory descriptor for the diff of output +/// recurrent hidden state vector. +/// @param flags Unused. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, + const_dnnl_memory_desc_t src_layer_desc, + const_dnnl_memory_desc_t src_iter_desc, + const_dnnl_memory_desc_t attention_desc, + const_dnnl_memory_desc_t weights_layer_desc, + const_dnnl_memory_desc_t weights_iter_desc, + const_dnnl_memory_desc_t bias_desc, + const_dnnl_memory_desc_t dst_layer_desc, + const_dnnl_memory_desc_t dst_iter_desc, + const_dnnl_memory_desc_t diff_src_layer_desc, + const_dnnl_memory_desc_t diff_src_iter_desc, + const_dnnl_memory_desc_t diff_attention_desc, + const_dnnl_memory_desc_t diff_weights_layer_desc, + const_dnnl_memory_desc_t diff_weights_iter_desc, + const_dnnl_memory_desc_t diff_bias_desc, + const_dnnl_memory_desc_t diff_dst_layer_desc, + const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_rnn + +/// @addtogroup dnnl_api_matmul +/// @{ + +/// Creates a primitive descriptor for a matrix multiplication primitive. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param src_desc Source memory descriptor (matrix A) +/// @param weights_desc Weights memory descriptor (matrix B) +/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory +/// descriptor, or a memory descriptor with format_kind set to +/// #dnnl_format_kind_undef disables the bias term. +/// @param dst_desc Destination memory descriptor (matrix C). +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_matmul_primitive_desc_create(
+        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
+        const_dnnl_memory_desc_t src_desc,
+        const_dnnl_memory_desc_t weights_desc,
+        const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
+        const_dnnl_primitive_attr_t attr);
+
+/// @} dnnl_api_matmul
+
+/// @addtogroup dnnl_api_resampling Resampling
+/// @{
+
+/// Creates a primitive descriptor for a resampling forward propagation
+/// primitive.
+///
+/// @note
+///     Destination memory descriptor is allowed to be initialized with
+///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
+///
+/// @param primitive_desc Output primitive descriptor.
+/// @param engine Engine to use.
+/// @param prop_kind Propagation kind. Possible values are
+///     #dnnl_forward_training and #dnnl_forward_inference.
+/// @param alg_kind resampling algorithm kind: either #dnnl_resampling_nearest,
+///     or #dnnl_resampling_linear.
+/// @param factors Array of scaling factors for spatial dimension.
+/// @param src_desc Source memory descriptor.
+/// @param dst_desc Destination memory descriptor.
+/// @param attr Primitive attributes (can be NULL).
+/// @returns #dnnl_success on success and a status describing the error
+///     otherwise.
+dnnl_status_t DNNL_API dnnl_resampling_forward_primitive_desc_create(
+        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
+        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
+        const float *factors, const_dnnl_memory_desc_t src_desc,
+        const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
+
+/// Creates a primitive descriptor for a resampling backward propagation
+/// primitive.
+///
+/// @param primitive_desc Output primitive descriptor.
+/// @param engine Engine to use.
+/// @param alg_kind resampling algorithm kind: either
+///     #dnnl_resampling_nearest, or #dnnl_resampling_linear.
+/// @param diff_src_desc Diff source memory descriptor.
+/// @param diff_dst_desc Diff destination memory descriptor. +/// @param factors Array of scaling factors for spatial dimension. +/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation +/// primitive. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +/// +dnnl_status_t DNNL_API dnnl_resampling_backward_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine, + dnnl_alg_kind_t alg_kind, const float *factors, + const_dnnl_memory_desc_t diff_src_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_primitive_desc_t hint_fwd_pd, + const_dnnl_primitive_attr_t attr); + +/// @} dnnl_api_resampling + +/// @addtogroup dnnl_api_reduction Reduction +/// @{ + +/// Creates a primitive descriptor for a reduction primitive. +/// +/// @note +/// Destination memory descriptor is allowed to be initialized with +/// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any. +/// +/// @param primitive_desc Output primitive descriptor. +/// @param engine Engine to use. +/// @param alg_kind reduction algorithm kind. Possible values: +/// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum, +/// #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max, +/// #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max, +/// #dnnl_reduction_norm_lp_power_p_sum. +/// @param p Algorithm specific parameter. +/// @param eps Algorithm specific parameter. +/// @param src_desc Source memory descriptor. +/// @param dst_desc Destination memory descriptor. +/// @param attr Primitive attributes (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_reduction_primitive_desc_create(
+        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
+        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
+        const_dnnl_memory_desc_t dst_desc, float p, float eps,
+        const_dnnl_primitive_attr_t attr);
+
+/// @} dnnl_api_reduction
+
+/// @} dnnl_api_primitives
+
+/// @addtogroup dnnl_api_primitive_cache
+/// @{
+
+/// Returns the number of primitives that can be held in the primitive cache
+/// at the same time.
+///
+/// @param capacity Primitive cache capacity to query. Concurrently
+///     accessing @p capacity is safe.
+/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
+///     @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
+///     success.
+dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity);
+
+/// Sets a number of primitives that can be held in the primitive cache
+/// at a time.
+///
+/// @param capacity Primitive cache capacity to set. If a new @p capacity is
+///     less than a number of primitives that the primitive cache already has
+///     then the excess entries will be evicted. Setting the @p capacity to 0
+///     clears the primitive cache and disables it. Concurrently modifying
+///     @p capacity is safe.
+/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
+///     @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
+///     success.
+dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity);
+
+/// @} dnnl_api_primitive_cache
+
+/// @addtogroup dnnl_api_service
+/// @{
+
+/// Configures dumping of JIT-generated code.
+///
+/// @note
+///     This setting overrides the DNNL_JIT_DUMP environment variable.
+///
+/// @param enable Flag value. Set to 0 to disable and set to 1 to enable.
+/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
+///     @p enable value is invalid, and #dnnl_success/#dnnl::status::success on
+///     success.
+dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable); + +/// Sets library profiling flags. The flags define which profilers are +/// supported. +/// +/// @note +/// This setting overrides DNNL_JIT_PROFILE environment variable. +/// +/// @sa @ref dev_guide_profilers +/// +/// @param flags Profiling flags that can contain the following bits: +/// - @ref DNNL_JIT_PROFILE_VTUNE -- integration with VTune Profiler +/// (on by default) +/// - @ref DNNL_JIT_PROFILE_LINUX_JITDUMP -- produce Linux-specific +/// jit-pid.dump output (off by default). The location of the output +/// is controlled via JITDUMPDIR environment variable or via +/// dnnl_set_jit_profiling_jitdumpdir() function. +/// - @ref DNNL_JIT_PROFILE_LINUX_PERFMAP -- produce Linux-specific +/// perf-pid.map output (off by default). The output is always placed +/// into /tmp. +/// +/// Passing @ref DNNL_JIT_PROFILE_NONE disables profiling completely. +/// +/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the +/// @p flags value is invalid, and #dnnl_success/#dnnl::status::success on +/// success. +dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags); + +/// Sets JIT dump output path. Only applicable to Linux and is only +/// used when profiling flags have DNNL_JIT_PROFILE_LINUX_PERF bit set. +/// +/// After the first JIT kernel is generated, the jitdump output will be placed +/// into temporary directory created using the mkdtemp template +/// 'dir/.debug/jit/dnnl.XXXXXX'. +/// +/// @sa @ref dev_guide_profilers +/// +/// @note +/// This setting overrides JITDUMPDIR environment variable. If +/// JITDUMPDIR is not set, and this function is never called, the path +/// defaults to HOME. Passing NULL reverts the value to default. +/// +/// @note +/// The directory is accessed only when the first JIT kernel is being +/// created. JIT profiling will be disabled in case of any errors +/// accessing or creating this directory. +/// +/// @param dir JIT dump output path. 
+/// @returns #dnnl_success/#dnnl::status::success if the +/// output directory was set correctly and an error status otherwise. +/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented on Windows. +dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir); + +/// Sets the maximal ISA the library can dispatch to on the CPU. See +/// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values accepted by +/// the C and C++ API functions respectively. +/// +/// This function has effect only once, and returns an error on subsequent +/// calls. It should also be invoked before any other oneDNN API call, otherwise +/// it may return an error. +/// +/// This function overrides the DNNL_MAX_CPU_ISA environment variable. The +/// environment variable can be set to the desired maximal ISA name in upper +/// case and with dnnl_cpu_isa prefix removed. For example: +/// `DNNL_MAX_CPU_ISA=AVX2`. +/// +/// @note +/// The ISAs are only partially ordered: +/// - SSE41 < AVX < AVX2 < AVX2_VNNI < AVX2_VNNI_2, +/// - AVX2 < AVX512_CORE < AVX512_CORE_VNNI < AVX512_CORE_BF16 +/// < AVX10_1_512 < AVX10_1_512_AMX < AVX10_1_512_AMX_FP16, +/// - AVX2_VNNI < AVX10_1_512. +/// Aliases: +/// - AVX512_CORE_FP16 = AVX10_1_512 +/// - AVX512_CORE_AMX = AVX10_1_512_AMX +/// - AVX512_CORE_AMX_FP16 = AVX10_1_512_AMX_FP16 +/// +/// @sa @ref dev_guide_cpu_dispatcher_control for more details +/// +/// @param isa Maximal ISA the library should dispatch to. Pass +/// #dnnl_cpu_isa_default/#dnnl::cpu_isa::isa_default to remove ISA restrictions +/// (except for ISAs with initial support in the library). +/// @returns #dnnl_success/#dnnl::status::success on success and a +/// #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the @p isa +/// parameter is invalid or the ISA cannot be changed at this time. +/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature +/// was disabled at build time (see @ref dev_guide_build_options for more +/// details). 
+dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa); + +/// Gets the maximal ISA the library can dispatch to on the CPU. See +/// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values returned by +/// the C and C++ API functions respectively. +/// +/// @sa @ref dev_guide_cpu_dispatcher_control for more details +/// +/// @returns #dnnl_cpu_isa_t value reflecting the maximal ISA the library may +/// dispatch to. +dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void); + +/// Sets the hints flag for the CPU ISA. See #dnnl_cpu_isa_hints_t and +/// #dnnl::cpu_isa_hints for the list of the values accepted by the C and C++ +/// API functions respectively. +/// +/// This function has effect only once, and returns an error on subsequent +/// calls. It should also be invoked before any other oneDNN API call, otherwise +/// it may return an error. +/// +/// This function overrides the DNNL_CPU_ISA_HINTS environment variable. +/// @sa @ref dev_guide_cpu_isa_hints for more details +/// +/// @param isa_hints CPU ISA hints to be passed over to the implementation. +/// Pass #dnnl_cpu_isa_no_hints/#dnnl::cpu_isa_hints::no_hints to use +/// default features i.e. no hints. +/// @returns #dnnl_success/#dnnl::status::success on success and a +/// #dnnl_runtime_error/#dnnl::status::runtime_error if the ISA hints cannot +/// be specified at the current time. +/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature +/// was disabled at build time (see @ref dev_guide_build_options for more +/// details). +dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints); + +/// Gets the ISA specific hints that library can follow. See +/// #dnnl_cpu_isa_hints_t and #dnnl::cpu_isa_hints for the list of the values +/// returned by the C and C++ API functions respectively. 
+/// +/// @sa @ref dev_guide_cpu_isa_hints for more details +/// +/// @returns #dnnl_cpu_isa_hints_t value reflecting the ISA specific hints the +/// library can follow. +dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void); + +/// @} dnnl_api_service + +#ifdef DNNL_EXPERIMENTAL_PROFILING + +/// @addtogroup dnnl_api_profiling Profiling +/// @{ + +/// Resets a profiler's state. +/// +/// @param stream Stream associated with the profiler. +/// +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_reset_profiling(dnnl_stream_t stream); + +/// Queries profiling data. The profiling data accumulates for each primitive +/// execution. The @p num_entries will be equal to the number of executions +/// since the last `dnnl_reset_profiling` call. In order to query the +/// @p num_entries the @p data parameter should be NULL. When @p data is NULL +/// then the @p data_kind parameter is ignored. +/// +/// The profiling data can be reset by calling #dnnl_reset_profiling. +/// +/// @note +/// It is required to wait for all submitted primitives to complete +/// using #dnnl_stream_wait prior to querying profiling data. +/// +/// @param stream Stream that was used for executing a primitive that +/// is being profiled. +/// @param data_kind Profiling data kind to query. +/// @param num_entries Number of profiling data entries. +/// @param data Profiling data. +/// +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_query_profiling_data(dnnl_stream_t stream, + dnnl_profiling_data_kind_t data_kind, int *num_entries, uint64_t *data); + +/// @} dnnl_api_profiling +#endif + +/// @addtogroup dnnl_api_blas +/// @{ + +/// Performs single-precision matrix-matrix multiply. 
+/// +/// The operation is defined as: +/// +/// `C := alpha * op( A ) * op( B ) + beta * C` +/// +/// where +/// - `op( X ) = X` or `op( X ) = X**T`, +/// - `alpha` and `beta` are scalars, and +/// - `A`, `B`, and `C` are matrices: +/// - `op( A )` is an `MxK` matrix, +/// - `op( B )` is an `KxN` matrix, +/// - `C` is an `MxN` matrix. +/// +/// The matrices are assumed to be stored in row-major order (the elements in +/// each of the matrix rows are contiguous in memory). +/// +/// @note +/// This API does not support XERBLA. Instead, unlike the standard BLAS +/// functions, this one returns a dnnl_status_t value to allow error +/// handling. +/// +/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not +/// transposed, and 'T' or 't' means that A is transposed. +/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not +/// transposed, and 'T' or 't' means that B is transposed. +/// @param M The M dimension. +/// @param N The N dimension. +/// @param K The K dimension. +/// @param alpha The alpha parameter that is used to scale the product of +/// matrices A and B. +/// @param A A pointer to the A matrix data. +/// @param lda The leading dimension for the matrix A. +/// @param B A pointer to the B matrix data. +/// @param ldb The leading dimension for the matrix B. +/// @param beta The beta parameter that is used to scale the matrix C. +/// @param C A pointer to the C matrix data. +/// @param ldc The leading dimension for the matrix C. +/// @returns #dnnl_success/#dnnl::status::success on success and a status +/// describing the error otherwise. +dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M, + dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, + const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc); + +/// Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit +/// signed matrix B, and 32-bit signed resulting matrix C. 
+/// +/// The operation is defined as: +/// +/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset` +/// +/// where +/// - `op( X ) = X` or `op( X ) = X**T`, +/// - `alpha` and `beta` are scalars, and +/// - `A`, `B`, and `C` are matrices: +/// - `op( A )` is an `MxK` matrix, +/// - `op( B )` is an `KxN` matrix, +/// - `C` is an `MxN` matrix. +/// - `A_offset` is an `MxK` matrix with every element equal the `ao` value, +/// - `B_offset` is an `KxN` matrix with every element equal the `bo` value, +/// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`: +/// - if `offsetc = F`: the `len` must be at least `1`, +/// - if `offsetc = C`: the `len` must be at least `max(1, m)`, +/// - if `offsetc = R`: the `len` must be at least `max(1, n)`, +/// +/// The matrices are assumed to be stored in row-major order (the elements in +/// each of the matrix rows are contiguous in memory). +/// +/// @note +/// This API does not support XERBLA. Instead, unlike the standard BLAS +/// functions, this one returns a dnnl_status_t value to allow error +/// handling. +/// +/// @warning +/// On some architectures saturation may happen during intermediate +/// computations, which would lead to unexpected results. For more +/// details, refer to @ref dev_guide_int8_computations. +/// +/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not +/// transposed, and 'T' or 't' means that A is transposed. +/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not +/// transposed, and 'T' or 't' means that B is transposed. +/// @param offsetc Flag specifying how offsets should be applied to matrix C: +/// - 'F' means that the same offset will be applied to each element of +/// the matrix C, +/// - 'C' means that individual offset will be applied to each element +/// within each column, +/// - 'R' means that individual offset will be applied to each element +/// within each row. 
+/// @param M The M dimension. +/// @param N The N dimension. +/// @param K The K dimension. +/// @param alpha The alpha parameter that is used to scale the product of +/// matrices A and B. +/// @param A A pointer to the A matrix data. +/// @param lda The leading dimension for the matrix A. +/// @param ao The offset value for the matrix A. +/// @param B A pointer to the B matrix data. +/// @param ldb The leading dimension for the matrix B. +/// @param bo The offset value for the matrix B. +/// @param beta The beta parameter that is used to scale the matrix C. +/// @param C A pointer to the C matrix data. +/// @param ldc The leading dimension for the matrix C. +/// @param co An array of offset values for the matrix C. The number of +/// elements in the array depends on the value of @p offsetc. +/// @returns #dnnl_success/#dnnl::status::success on success and a status +/// describing the error otherwise. +dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc, + dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, + dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, + float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co); + +/// Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit +/// signed matrix B, and 32-bit signed resulting matrix C. +/// +/// The operation is defined as: +/// +/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset` +/// +/// where +/// - `op( X ) = X` or `op( X ) = X**T`, +/// - `alpha` and `beta` are scalars, and +/// - `A`, `B`, and `C` are matrices: +/// - `op( A )` is an `MxK` matrix, +/// - `op( B )` is an `KxN` matrix, +/// - `C` is an `MxN` matrix. 
+/// - `A_offset` is an `MxK` matrix with every element equal the `ao` value, +/// - `B_offset` is an `KxN` matrix with every element equal the `bo` value, +/// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`: +/// - if `offsetc = F`: the `len` must be at least `1`, +/// - if `offsetc = C`: the `len` must be at least `max(1, m)`, +/// - if `offsetc = R`: the `len` must be at least `max(1, n)`, +/// +/// The matrices are assumed to be stored in row-major order (the elements in +/// each of the matrix rows are contiguous in memory). +/// +/// @note +/// This API does not support XERBLA. Instead, unlike the standard BLAS +/// functions, this one returns a dnnl_status_t value to allow error +/// handling. +/// +/// @warning +/// On some architectures saturation may happen during intermediate +/// computations, which would lead to unexpected results. For more +/// details, refer to @ref dev_guide_int8_computations. +/// +/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not +/// transposed, and 'T' or 't' means that A is transposed. +/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not +/// transposed, and 'T' or 't' means that B is transposed. +/// @param offsetc Flag specifying how offsets should be applied to matrix C: +/// - 'F' means that the same offset will be applied to each element of +/// the matrix C, +/// - 'C' means that individual offset will be applied to each element +/// within each column, +/// - 'R' means that individual offset will be applied to each element +/// within each row. +/// @param M The M dimension. +/// @param N The N dimension. +/// @param K The K dimension. +/// @param alpha The alpha parameter that is used to scale the product of +/// matrices A and B. +/// @param A A pointer to the A matrix data. +/// @param lda The leading dimension for the matrix A. +/// @param ao The offset value for the matrix A. +/// @param B A pointer to the B matrix data. 
+/// @param ldb The leading dimension for the matrix B. +/// @param bo The offset value for the matrix B. +/// @param beta The beta parameter that is used to scale the matrix C. +/// @param C A pointer to the C matrix data. +/// @param ldc The leading dimension for the matrix C. +/// @param co An array of offset values for the matrix C. The number of +/// elements in the array depends on the value of @p offsetc. +/// @returns #dnnl_success/#dnnl::status::success on success and a status +/// describing the error otherwise. +dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc, + dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, + dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, + float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co); + +/// @} dnnl_api_blas + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif /* ONEAPI_DNNL_DNNL_H */ diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl.hpp b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a3290fa87c67d6051273f2ae2887d0f9c1b7e96d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl.hpp @@ -0,0 +1,14066 @@ +/******************************************************************************* +* Copyright 2016-2025 Intel Corporation +* Copyright 2024 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @file +/// C++ API + +#ifndef ONEAPI_DNNL_DNNL_HPP +#define ONEAPI_DNNL_DNNL_HPP + +#include "oneapi/dnnl/dnnl_config.h" + +/// @cond DO_NOT_DOCUMENT_THIS +#include +#include +#include +#include +#include +#include +#include + +#include "oneapi/dnnl/dnnl.h" +#include "oneapi/dnnl/dnnl_common.hpp" + +/// @endcond + +/// @addtogroup dnnl_api oneDNN API +/// @{ + +/// oneDNN namespace +namespace dnnl { + +/// @addtogroup dnnl_api_utils Utilities +/// Utility types and definitions. +/// @{ + +/// @cond DO_NOT_DOCUMENT_THIS +template +void validate_container_size(const T &v, const char *error_message, + int min_size = 1, int max_size = -1) { + const int size = (int)v.size(); + if (size < min_size || (max_size >= 0 && size > max_size)) + DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message); +} +/// @endcond + +/// @cond DO_NOT_DOCUMENT_THIS +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_memory_desc_t p) { + return dnnl_memory_desc_destroy(p); + } +}; + +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_memory_t p) { + return dnnl_memory_destroy(p); + } +}; + +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_primitive_desc_t p) { + return dnnl_primitive_desc_destroy(p); + } +}; + +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_primitive_t p) { + return dnnl_primitive_destroy(p); + } +}; + +/// @endcond + +/// @} dnnl_api_utils + +struct stream; +struct memory; +struct primitive_desc; + +/// @addtogroup dnnl_api_primitives Primitives +/// Compute primitives +/// @sa @ref dev_guide_basic_concepts +/// @{ + +/// @addtogroup dnnl_api_primitives_common Common +/// Common operations to create, destroy and inspect primitives +/// @{ + +/// Base class for all computational 
primitives. +struct primitive : public handle { + /// Kinds of primitives supported by the library. + enum class kind { + /// Undefined primitive + undef = dnnl_undefined_primitive, + /// A reorder primitive. + reorder = dnnl_reorder, + /// A shuffle primitive. + shuffle = dnnl_shuffle, + /// A (out-of-place) tensor concatenation primitive. + concat = dnnl_concat, + /// A summation primitive. + sum = dnnl_sum, + /// A convolution primitive. + convolution = dnnl_convolution, + /// A deconvolution primitive. + deconvolution = dnnl_deconvolution, + /// An element-wise primitive. + eltwise = dnnl_eltwise, + /// An LRN primitive. + lrn = dnnl_lrn, + /// A batch normalization primitive. + batch_normalization = dnnl_batch_normalization, + /// An inner product primitive. + inner_product = dnnl_inner_product, + /// An RNN primitive. + rnn = dnnl_rnn, + /// A binary primitive. + binary = dnnl_binary, + /// A matmul (matrix multiplication) primitive. + matmul = dnnl_matmul, + /// A resampling primitive. + resampling = dnnl_resampling, + /// A pooling primitive. + pooling = dnnl_pooling, + /// A reduction primitive. + reduction = dnnl_reduction, + /// A PReLU primitive. + prelu = dnnl_prelu, + /// A softmax primitive. + softmax = dnnl_softmax, + /// A layer normalization primitive. + layer_normalization = dnnl_layer_normalization, + /// A group normalization primitive + group_normalization = dnnl_group_normalization, + }; + + using handle::handle; + + /// Default constructor. Constructs an empty object. + primitive() = default; + + /// Constructs a primitive from a C API primitive descriptor. + /// + /// @param c_pd C API primitive descriptor. + primitive(const_dnnl_primitive_desc_t c_pd); + + /// Constructs a primitive from a C API primitive descriptor and a cache blob. + /// + /// @param c_pd C API primitive descriptor. + /// @param cache_blob Cache blob. 
+ primitive(const_dnnl_primitive_desc_t c_pd, + const std::vector &cache_blob); + + /// Constructs a primitive from a primitive descriptor. + /// + /// @param pd Primitive descriptor. + primitive(const primitive_desc &pd); + + /// Constructs a primitive from a primitive descriptor and a cache blob. + /// + /// @param pd Primitive descriptor. + /// @param cache_blob Cache blob. + primitive(const primitive_desc &pd, const std::vector &cache_blob); + + /// Returns the C API primitive descriptor of the underlying C API + /// primitive. + /// + /// @returns The underlying C API primitive descriptor. + inline const_dnnl_primitive_desc_t get_primitive_desc() const; + + /// Returns the kind of the primitive. + /// + /// @returns The primitive kind. + inline kind get_kind() const; + + /// Returns a cache blob for the primitive. + /// + /// @returns Vector containing the cache blob. + /// + /// @note The cache blob can be empty. It's the user's responsibility to + /// check whether it's empty prior to passing it to the primitive + /// constructor. + inline std::vector get_cache_blob() const; + + /// Executes computations specified by the primitive in a specified stream. + /// + /// Arguments are passed via an arguments map containing pairs. The index must be one of the `DNNL_ARG_*` values + /// such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor + /// matching the one returned by + /// primitive_desc::query_md(#query::exec_arg_md, index) unless using + /// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL). + /// + /// @param astream Stream object. The stream must belong to the same engine + /// as the primitive. + /// @param args Arguments map. + void execute(const stream &astream, + const std::unordered_map &args) const; +}; + +/// Converts primitive kind enum value from C++ API to C API type. +/// +/// @param akind C++ API primitive kind enum value. +/// @returns Corresponding C API primitive kind enum value. 
+inline dnnl_primitive_kind_t convert_to_c(primitive::kind akind) { + return static_cast(akind); +} + +const_dnnl_primitive_desc_t primitive::get_primitive_desc() const { + const_dnnl_primitive_desc_t pd; + error::wrap_c_api(dnnl_primitive_get_primitive_desc(get(), &pd), + "could not get a primitive descriptor from a primitive"); + return pd; +} + +dnnl::primitive::kind primitive::get_kind() const { + const_dnnl_primitive_desc_t pd = get_primitive_desc(); + // TODO (Roma): the code below is only needed because get_primitive_desc + // returns a C type. + dnnl_primitive_kind_t kind; + error::wrap_c_api(dnnl_primitive_desc_query( + pd, dnnl_query_primitive_kind, 0, (void *)&kind), + "could not get a primitive kind from a primitive descriptor"); + return static_cast(kind); +} + +std::vector primitive::get_cache_blob() const { + size_t size; + error::wrap_c_api(dnnl_primitive_get_cache_blob(get(), &size, nullptr), + "could not get cache blob size from a primitive"); + + std::vector cache_blob(size); + error::wrap_c_api( + dnnl_primitive_get_cache_blob(get(), &size, cache_blob.data()), + "could not get a cache blob from a primitive"); + return cache_blob; +} + +/// @} dnnl_api_primitives_common + +/// @addtogroup dnnl_api_attributes +/// +/// A container for parameters that extend primitives behavior. +/// +/// Attributes can also contain Post-ops, which are computations executed +/// after the primitive. +/// +/// @sa @ref dev_guide_attributes +/// @sa @ref dev_guide_attributes_post_ops +/// +/// @{ + +/// Scratchpad mode +enum class scratchpad_mode { + /// The library manages the scratchpad allocation according to the policy + /// specified by the `DNNL_ENABLE_CONCURRENT_EXEC` + /// [build option](@ref dev_guide_build_options) (default). + /// + /// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library + /// scratchpad is common to all primitives to reduce the memory footprint. 
+ /// This configuration comes with limited thread-safety properties, namely + /// primitives can be created and executed in parallel but cannot migrate + /// between threads (in other words, each primitive should be executed in + /// the same thread it was created in). + /// + /// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is + /// private to each primitive. The memory footprint is larger than when + /// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF` but different primitives can be + /// created and run concurrently (the same primitive cannot be run + /// concurrently from two different threads though). + library = dnnl_scratchpad_mode_library, + /// The user manages the scratchpad allocation by querying and providing + /// the scratchpad memory to primitives. This mode is thread-safe as long + /// as the scratchpad buffers are not used concurrently by two primitive + /// executions. + user = dnnl_scratchpad_mode_user, +}; + +/// Converts a scratchpad mode enum value from C++ API to C API type. +/// +/// @param mode C++ API scratchpad mode enum value. +/// @returns Corresponding C API scratchpad mode enum value. +inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) { + return static_cast(mode); +} + +/// Rounding mode +enum class rounding_mode { + /// rounding mode dictated by the floating-point environment + environment = dnnl_rounding_mode_environment, + /// stochastic rounding mode where a random bias is added to the + /// trailing mantissa bits before conversion. + stochastic = dnnl_rounding_mode_stochastic +}; + +/// Converts a rounding mode enum value from C++ API to C API type. +/// +/// @param mode C++ API rounding mode enum value. +/// @returns Corresponding C API rounding mode enum value. +inline dnnl_rounding_mode_t convert_to_c(rounding_mode mode) { + return static_cast(mode); +} + +/// Propagation kind. +enum class prop_kind { + /// Undefined propagation kind. 
+ undef = dnnl_prop_kind_undef, + /// Forward data propagation (training mode). In this mode, primitives + /// perform computations necessary for subsequent backward propagation. + forward_training = dnnl_forward_training, + /// Forward data propagation (inference mode). In this mode, primitives + /// perform only computations that are necessary for inference and omit + /// computations that are necessary only for backward propagation. + forward_inference = dnnl_forward_inference, + /// Forward data propagation, + /// alias for #dnnl::prop_kind::forward_training. + forward = dnnl_forward, + /// Backward propagation (with respect to all parameters). + backward = dnnl_backward, + /// Backward data propagation. + backward_data = dnnl_backward_data, + /// Backward weights propagation. + backward_weights = dnnl_backward_weights, + /// Backward bias propagation. + backward_bias = dnnl_backward_bias +}; + +/// Converts propagation kind enum value from C++ API to C API type. +/// +/// @param akind C++ API propagation kind enum value. +/// @returns Corresponding C API propagation kind enum value. +inline dnnl_prop_kind_t convert_to_c(prop_kind akind) { + return static_cast(akind); +} + +/// Kinds of algorithms. 
+enum class algorithm { + /// Undefined algorithm + undef = dnnl_alg_kind_undef, + /// Convolution algorithm that is chosen to be either direct or Winograd + /// automatically + convolution_auto = dnnl_convolution_auto, + /// Direct convolution + convolution_direct = dnnl_convolution_direct, + /// Winograd convolution + convolution_winograd = dnnl_convolution_winograd, + /// Direct deconvolution + deconvolution_direct = dnnl_deconvolution_direct, + /// Winograd deconvolution + deconvolution_winograd = dnnl_deconvolution_winograd, + /// Elementwise: rectified linear unit (ReLU) + eltwise_relu = dnnl_eltwise_relu, + /// Elementwise: hyperbolic tangent non-linearity (tanh) + eltwise_tanh = dnnl_eltwise_tanh, + /// Elementwise: exponential linear unit (ELU) + eltwise_elu = dnnl_eltwise_elu, + /// Elementwise: square + eltwise_square = dnnl_eltwise_square, + /// Elementwise: abs + eltwise_abs = dnnl_eltwise_abs, + /// Elementwise: square root + eltwise_sqrt = dnnl_eltwise_sqrt, + /// Elementwise: swish (\f$x \cdot sigmoid(a \cdot x)\f$) + eltwise_swish = dnnl_eltwise_swish, + /// Elementwise: linear + eltwise_linear = dnnl_eltwise_linear, + /// Elementwise: soft_relu + eltwise_soft_relu = dnnl_eltwise_soft_relu, + /// Elementwise: mish + eltwise_mish = dnnl_eltwise_mish, + /// Elementwise: logistic + eltwise_logistic = dnnl_eltwise_logistic, + /// Elementwise: exponent + eltwise_exp = dnnl_eltwise_exp, + /// Elementwise: tanh-based gelu + eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh, + /// Elementwise: erf-based gelu + eltwise_gelu_erf = dnnl_eltwise_gelu_erf, + /// Elementwise: natural logarithm + eltwise_log = dnnl_eltwise_log, + /// Elementwise: clip + eltwise_clip = dnnl_eltwise_clip, + /// Eltwise: clip version 2 + eltwise_clip_v2 = dnnl_eltwise_clip_v2, + /// Elementwise: pow + eltwise_pow = dnnl_eltwise_pow, + /// Elementwise: round + eltwise_round = dnnl_eltwise_round, + /// Elementwise: hardswish + eltwise_hardswish = dnnl_eltwise_hardswish, + /// Elementwise: 
hardsigmoid + eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid, + /// Elementwise: rectified linar unit (ReLU) (dst for backward) + eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd, + /// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward) + eltwise_tanh_use_dst_for_bwd = dnnl_eltwise_tanh_use_dst_for_bwd, + /// Elementwise: exponential linear unit (ELU) (dst for backward) + eltwise_elu_use_dst_for_bwd = dnnl_eltwise_elu_use_dst_for_bwd, + /// Elementwise: square root (dst for backward) + eltwise_sqrt_use_dst_for_bwd = dnnl_eltwise_sqrt_use_dst_for_bwd, + /// Elementwise: logistic (dst for backward) + eltwise_logistic_use_dst_for_bwd = dnnl_eltwise_logistic_use_dst_for_bwd, + /// Elementwise: exponent (dst for backward) + eltwise_exp_use_dst_for_bwd = dnnl_eltwise_exp_use_dst_for_bwd, + /// Elementwise: clip version 2 (dst for backward) + eltwise_clip_v2_use_dst_for_bwd = dnnl_eltwise_clip_v2_use_dst_for_bwd, + /// Local response normalization (LRN) across multiple channels + lrn_across_channels = dnnl_lrn_across_channels, + /// LRN within a single channel + lrn_within_channel = dnnl_lrn_within_channel, + /// Max pooling + pooling_max = dnnl_pooling_max, + /// Average pooling include padding + pooling_avg_include_padding = dnnl_pooling_avg_include_padding, + /// Average pooling exclude padding + pooling_avg_exclude_padding = dnnl_pooling_avg_exclude_padding, + /// RNN cell + vanilla_rnn = dnnl_vanilla_rnn, + /// LSTM cell + vanilla_lstm = dnnl_vanilla_lstm, + /// GRU cell + vanilla_gru = dnnl_vanilla_gru, + /// GRU cell with linear before reset. 
Differs from the vanilla GRU + /// in how the new memory gate is calculated: + /// \f$c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f$ + /// LRB GRU expects 4 bias tensors on input: + /// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$ + lbr_gru = dnnl_lbr_gru, + /// AUGRU cell + vanilla_augru = dnnl_vanilla_augru, + /// AUGRU cell with linear before reset + lbr_augru = dnnl_lbr_augru, + /// Binary add + binary_add = dnnl_binary_add, + /// Binary mul + binary_mul = dnnl_binary_mul, + /// Binary max + binary_max = dnnl_binary_max, + /// Binary min + binary_min = dnnl_binary_min, + /// Binary div + binary_div = dnnl_binary_div, + /// Binary sub + binary_sub = dnnl_binary_sub, + /// Binary greater than or equal + binary_ge = dnnl_binary_ge, + /// Binary greater than + binary_gt = dnnl_binary_gt, + /// Binary less than or equal + binary_le = dnnl_binary_le, + /// Binary less than + binary_lt = dnnl_binary_lt, + /// Binary equal + binary_eq = dnnl_binary_eq, + /// Binary not equal + binary_ne = dnnl_binary_ne, + /// Binary select + binary_select = dnnl_binary_select, + /// Nearest Neighbor resampling method + resampling_nearest = dnnl_resampling_nearest, + /// Linear (Bilinear, Trilinear) resampling method + resampling_linear = dnnl_resampling_linear, + /// Reduction using max operation + reduction_max = dnnl_reduction_max, + /// Reduction using min operation + reduction_min = dnnl_reduction_min, + /// Reduction using sum operation + reduction_sum = dnnl_reduction_sum, + /// Reduction using mul operation + reduction_mul = dnnl_reduction_mul, + /// Reduction using mean operation + reduction_mean = dnnl_reduction_mean, + /// Reduction using norm_lp_max operation + reduction_norm_lp_max = dnnl_reduction_norm_lp_max, + /// Reduction using norm_lp_sum operation + reduction_norm_lp_sum = dnnl_reduction_norm_lp_sum, + /// Reduction using norm_lp_power_p_max operation + reduction_norm_lp_power_p_max = dnnl_reduction_norm_lp_power_p_max, + /// Reduction using 
norm_lp_power_p_sum operation + reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum, + /// Softmax, numerically stable + softmax_accurate = dnnl_softmax_accurate, + /// LogSoftmax, numerically stable + softmax_log = dnnl_softmax_log, +}; + +/// Converts algorithm kind enum value from C++ API to C API type. +/// @param aalgorithm C++ API algorithm kind enum value. +/// @returns Corresponding C API algorithm kind enum value. +inline dnnl_alg_kind_t convert_to_c(algorithm aalgorithm) { + return static_cast(aalgorithm); +} + +/// @} dnnl_api_attributes + +/// @addtogroup dnnl_api_primitives_common +/// @{ + +/// Flags for normalization primitives. +enum class normalization_flags : unsigned { + /// Use no normalization flags. If specified, the library computes mean and + /// variance on forward propagation for training and inference, outputs them + /// on forward propagation for training, and computes the respective + /// derivatives on backward propagation. + none = dnnl_normalization_flags_none, + + /// Use global statistics. If specified, the library uses mean and + /// variance provided by the user as an input on forward propagation and + /// does not compute their derivatives on backward propagation. Otherwise, + /// the library computes mean and variance on forward propagation for + /// training and inference, outputs them on forward propagation for + /// training, and computes the respective derivatives on backward + /// propagation. + use_global_stats = dnnl_use_global_stats, + + /// Use scale parameter. If specified, the user is expected to pass scale as + /// input on forward propagation. On backward propagation of type + /// #dnnl::prop_kind::backward, the library computes its derivative. + use_scale = dnnl_use_scale, + + /// Use shift parameter. If specified, the user is expected to pass shift as + /// input on forward propagation. On backward propagation of type + /// #dnnl::prop_kind::backward, the library computes its derivative. 
+ use_shift = dnnl_use_shift, + + /// Fuse normalization with ReLU. On training, normalization will require + /// the workspace to implement backward propagation. On inference, the + /// workspace is not required and behavior is the same as when normalization + /// is fused with ReLU using the post-ops API. + fuse_norm_relu = dnnl_fuse_norm_relu, + + /// Fuse normalization with elementwise binary Add and then fuse with ReLU. + /// On training, normalization will require the workspace to implement + /// backward propagation. On inference, the workspace is not required. + fuse_norm_add_relu = dnnl_fuse_norm_add_relu, +}; + +/// Converts normalization flags enum value from C++ API to C API type. +/// @param flags C++ API normalization flags enum value. +/// @returns Corresponding C API normalization flags enum value. +inline dnnl_normalization_flags_t convert_to_c(normalization_flags flags) { + return static_cast(flags); +} + +/// @} dnnl_api_primitives_common + +/// @addtogroup dnnl_api_rnn +/// @{ + +/// RNN cell flags. +enum class rnn_flags : unsigned { + /// Undefined RNN flags + undef = dnnl_rnn_flags_undef, + /// Do not add weights gradient to existing diff_weights memory + diff_weights_overwrite = dnnl_rnn_flags_diff_weights_overwrite, +}; + +/// Converts RNN cell flags enum value from C++ API to C API type. +/// @param flags C++ API RNN cell flags enum value. +/// @returns Corresponding C API RNN cell flags enum value. +inline dnnl_rnn_flags_t convert_to_c(rnn_flags flags) { + return static_cast(flags); +} + +DNNL_DEFINE_BITMASK_OPS(normalization_flags) +DNNL_DEFINE_BITMASK_OPS(rnn_flags) + +/// A direction of RNN primitive execution +enum class rnn_direction { + /// Undefined RNN direction. + undef = dnnl_rnn_direction_undef, + /// Unidirectional execution of RNN primitive from left to right. + unidirectional_left2right = dnnl_unidirectional_left2right, + /// Unidirectional execution of RNN primitive from right to left. 
+ unidirectional_right2left = dnnl_unidirectional_right2left, + /// Bidirectional execution of RNN primitive with concatenation of the + /// results. + bidirectional_concat = dnnl_bidirectional_concat, + /// Bidirectional execution of RNN primitive with summation of the + /// results. + bidirectional_sum = dnnl_bidirectional_sum, +}; + +/// Converts RNN direction enum value from C++ API to C API type. +/// @param dir C++ API RNN direction enum value. +/// @returns Corresponding C API RNN direction enum value. +inline dnnl_rnn_direction_t convert_to_c(rnn_direction dir) { + return static_cast(dir); +} + +/// @} dnnl_api_rnn + +/// @addtogroup dnnl_api_primitives_common +/// @{ + +/// Primitive descriptor query specification. +/// +/// In general, queries are not used with the C++ API because most queries are +/// implemented as class members. +/// +/// See @ref dnnl_query_t for more information. +enum class query { + /// no query + undef = dnnl_query_undef, + + /// execution engine + engine = dnnl_query_engine, + /// primitive kind + primitive_kind = dnnl_query_primitive_kind, + + /// number of inputs expected + num_of_inputs_s32 = dnnl_query_num_of_inputs_s32, + /// number of outputs expected + num_of_outputs_s32 = dnnl_query_num_of_outputs_s32, + + /// runtime estimation (seconds), unimplemented + time_estimate_f64 = dnnl_query_time_estimate_f64, + /// memory required for scratchpad (bytes) + /// + /// @sa @ref dev_guide_attributes_scratchpad + memory_consumption_s64 = dnnl_query_memory_consumption_s64, + + /// scratchpad engine + /// + /// engine to be used for creating scratchpad memory + scratchpad_engine = dnnl_query_scratchpad_engine, + + /// reorder source engine + reorder_src_engine = dnnl_query_reorder_src_engine, + /// reorder destination engine + reorder_dst_engine = dnnl_query_reorder_dst_engine, + + /// implementation name + impl_info_str = dnnl_query_impl_info_str, + + /// propagation kind + prop_kind = dnnl_query_prop_kind, + + /// size of cache blob 
ID in bytes + cache_blob_id_size_s64 = dnnl_query_cache_blob_id_size_s64, + + /// cache blob ID (pointer to array) + cache_blob_id = dnnl_query_cache_blob_id, + + /// strides + strides = dnnl_query_strides, + /// dilations + dilations = dnnl_query_dilations, + /// left padding + padding_l = dnnl_query_padding_l, + /// right padding + padding_r = dnnl_query_padding_r, + /// epsilon + epsilon_f32 = dnnl_query_epsilon_f32, + /// flags + flags = dnnl_query_flags, + /// algorithm kind + alg_kind = dnnl_query_alg_kind, + /// alpha + alpha_f32 = dnnl_query_alpha_f32, + /// beta + beta_f32 = dnnl_query_beta_f32, + /// axis + axis_s32 = dnnl_query_axis_s32, + /// LRN parameter local size + local_size_s64 = dnnl_query_local_size_s64, + /// LRN parameter K + k_f32 = dnnl_query_k_f32, + /// Reduction parameter P + p_f32 = dnnl_query_p_f32, + /// Resampling parameter factors + factors = dnnl_query_factors, + /// RNN parameter cell kind + cell_kind = dnnl_query_cell_kind, + /// RNN parameter direction + direction = dnnl_query_direction, + /// RNN parameter activation kind + activation_kind = dnnl_query_activation_kind, + /// Pooling parameter kernel + kernel = dnnl_query_kernel, + /// Shuffle parameter group size + group_size_s64 = dnnl_query_group_size_s64, + + /// source memory desc + src_md = dnnl_query_src_md, + /// source gradient (diff) memory desc + diff_src_md = dnnl_query_diff_src_md, + /// weights memory descriptor desc + weights_md = dnnl_query_weights_md, + /// weights gradient (diff) memory desc + diff_weights_md = dnnl_query_diff_weights_md, + /// destination memory desc + dst_md = dnnl_query_dst_md, + /// destination gradient (diff) memory desc + diff_dst_md = dnnl_query_diff_dst_md, + /// workspace memory desc + workspace_md = dnnl_query_workspace_md, + /// scratchpad memory desc + scratchpad_md = dnnl_query_scratchpad_md, + /// memory desc of an execute argument + exec_arg_md = dnnl_query_exec_arg_md, + + /// number of dimensions + ndims_s32 = 
dnnl_query_ndims_s32, + /// vector of dimensions + dims = dnnl_query_dims, + /// data type + data_type = dnnl_query_data_type, + /// submemory offset + submemory_offset_s64 = dnnl_query_submemory_offset_s64, + /// vector of padded dimensions + padded_dims = dnnl_query_padded_dims, + /// vector of padded offsets + padded_offsets = dnnl_query_padded_offsets, + /// format kind + format_kind = dnnl_query_format_kind, + /// number of innermost blocks + inner_nblks_s32 = dnnl_query_inner_nblks_s32, + /// vector of sizes of the innermost blocks + inner_blks = dnnl_query_inner_blks, + /// vector of logical indices of the blocks + inner_idxs = dnnl_query_inner_idxs, +#ifdef DNNL_EXPERIMENTAL_SPARSE + /// Sparse encoding + sparse_encoding = dnnl_query_sparse_encoding, + /// Number of non-zero entries + nnz_s64 = dnnl_query_nnz_s64, + /// Number of buffers required for a memory descriptor + num_handles_s32 = dnnl_query_num_handles_s32, +#endif +}; + +/// Converts query enum value from C++ API to C API type. +/// @param aquery C++ API query enum value. +/// @returns Corresponding C API query enum value. +inline dnnl_query_t convert_to_c(query aquery) { + return static_cast(aquery); +} + +/// @} dnnl_api_primitives_common + +/// @} dnnl_api_primitives + +/// @addtogroup dnnl_api_memory Memory +/// +/// A container that describes and stores data. Memory objects can contain +/// data of various types and formats. There are two levels of abstraction: +/// +/// 1. **Memory descriptor** -- engine-agnostic logical description of data +/// (number of dimensions, dimension sizes, and data type), and, +/// optionally, the information about the physical format of data in +/// memory. If this information is not known yet, a memory descriptor can +/// be created with #dnnl::memory::format_tag::any. This allows +/// compute-intensive primitives to choose the best format for +/// computation. 
The user is responsible for reordering the data into the +/// chosen format when formats do not match. +/// +/// A memory descriptor can be initialized either by specifying dimensions +/// and a memory format tag or strides for each of them, or by +/// manipulating the dnnl_memory_desc_t structure directly. +/// +/// @warning +/// The latter approach requires understanding how the physical data +/// representation is mapped to the structure and is discouraged. This +/// topic is discussed in @ref dev_guide_understanding_memory_formats. +/// +/// The user can query the amount of memory required by a memory +/// descriptor using the #dnnl::memory::desc::get_size() function. The +/// size of data in general cannot be computed as the product of +/// dimensions multiplied by the size of the data type. So users are +/// required to use this function for better code portability. +/// +/// Two memory descriptors can be compared using the equality and +/// inequality operators. The comparison is especially useful when +/// checking whether it is necessary to reorder data from the user's data +/// format to a primitive's format. +/// +/// 2. **Memory object** -- an engine-specific object that handles the memory +/// buffer and its description (a memory descriptor). For the CPU engine or +/// with USM, the memory buffer handle is simply a pointer to @c void. The +/// memory buffer can be queried using #dnnl::memory::get_data_handle() and +/// set using #dnnl::memory::set_data_handle(). The underlying SYCL buffer, +/// when used, can be queried using #dnnl::sycl_interop::get_buffer and set +/// using #dnnl::sycl_interop::set_buffer. A memory object can also be +/// queried for the underlying memory descriptor and for its engine using +/// #dnnl::memory::get_desc() and dnnl::memory::get_engine(). 
+/// +/// Along with ordinary memory descriptors with all dimensions being positive, +/// the library supports *zero-volume* memory descriptors with one or more +/// dimensions set to zero. This is used to support the NumPy\* convention. +/// If a zero-volume memory is passed to a primitive, the primitive typically +/// does not perform any computations with this memory. For example: +/// +/// - A concatenation primitive would ignore all memory object with zeroes in +/// the concat dimension / axis. +/// +/// - A forward convolution with a source memory object with zero in the +/// minibatch dimension would always produce a destination memory object +/// with a zero in the minibatch dimension and perform no computations. +/// +/// - However, a forward convolution with a zero in one of the weights +/// dimensions is ill-defined and is considered to be an error by the +/// library because there is no clear definition of what the output values +/// should be. +/// +/// Memory buffer of a zero-volume memory is never accessed. +/// +/// @{ + +/// Memory object. +/// +/// A memory object encapsulates a handle to a memory buffer allocated on a +/// specific engine, tensor dimensions, data type, and memory format, which is +/// the way tensor indices map to offsets in linear memory space. Memory +/// objects are passed to primitives during execution. +struct memory : public handle { + using handle::handle; + + /// Integer type for representing dimension sizes and indices. + typedef dnnl_dim_t dim; + /// Vector of dimensions. Implementations are free to force a limit on the + /// vector's length. + typedef std::vector dims; + + /// Helper function that validates that an `std::vector` of dimensions can + /// be safely converted to the C API array ::dnnl_dims_t. Throws if + /// validation fails. + /// + /// @param v Vector of dimensions. + /// @param min_size Minimum expected size of the vector. 
+ template + static void validate_dims(const std::vector &v, int min_size = 0) { + validate_container_size( + v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS); + } + + /// Data type specification. + enum class data_type { + /// Undefined data type (used for empty memory descriptors). + undef = dnnl_data_type_undef, + /// 4-bit float data type with 3-bit exponent and 0 bit mantissa. + f4_e3m0 = dnnl_f4_e3m0, + /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa. + f4_e2m1 = dnnl_f4_e2m1, + /// [MX-compliant 8-bit compliant scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent. + e8m0 = dnnl_e8m0, + /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) + /// with a 5-bit exponent and a 2-bit mantissa. + f8_e5m2 = dnnl_f8_e5m2, + /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) + /// with a 4-bit exponent and a 3-bit mantissa. + f8_e4m3 = dnnl_f8_e4m3, + /// [16-bit/half-precision floating point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format). + f16 = dnnl_f16, + /// non-standard + /// [16-bit floating point with 7-bit mantissa](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format). + bf16 = dnnl_bf16, + /// [32-bit/single-precision floating point](https://en.wikipedia.org/wiki/Single-precision_floating-point_format). + f32 = dnnl_f32, + //// [64-bit/double-precision floating point](https://en.wikipedia.org/wiki/Double-precision_floating-point_format). + f64 = dnnl_f64, + /// 32-bit signed integer. + s32 = dnnl_s32, + /// 8-bit signed integer. + s8 = dnnl_s8, + /// 8-bit unsigned integer. + u8 = dnnl_u8, + /// 4-bit signed integer. 
+ s4 = dnnl_s4, + /// 4-bit unsigned integer. + u4 = dnnl_u4, + }; + + /// Returns size of data type in bytes. + /// @returns The number of bytes occupied by data type. + static size_t data_type_size(data_type adata_type) { + return dnnl_data_type_size(convert_to_c(adata_type)); + } + + /// Memory format kind + enum class format_kind { + /// Undefined memory format kind, used for empty memory descriptors. + undef = dnnl_format_kind_undef, + /// A special format kind that indicates that the actual format will be + /// selected by a primitive automatically. + any = dnnl_format_kind_any, + /// A tensor in a generic format described by the stride and blocking + /// values in each dimension. + blocked = dnnl_blocked, +#ifdef DNNL_EXPERIMENTAL_SPARSE + /// Format kind for sparse tensors. + sparse = dnnl_format_kind_sparse, +#endif + /// A special format kind that indicates that tensor format is opaque. + opaque = dnnl_format_kind_opaque, + }; + +#ifdef DNNL_EXPERIMENTAL_SPARSE + /// Sparse encodings. + enum class sparse_encoding { + /// Undefined sparse encoding kind, used for empty memory descriptors. + undef = dnnl_sparse_encoding_undef, + /// Compressed Sparse Row (CSR) encoding. + csr = dnnl_csr, + /// An encoding that is used for an opaque storage schema for + /// tensors with unstructured sparsity. A memory descriptor with the + /// packed encoding cannot be used to create a memory object. It can + /// only be used to create a primitive descriptor to query the + /// actual memory descriptor (similar to the format tag `any`). + packed = dnnl_packed, + /// Coordinate Sparse (COO) encoding. + coo = dnnl_coo, + }; +#endif + + /// Memory format tag specification. + /// + /// Memory format tags can be further divided into two categories: + /// + /// - Domain-agnostic names, i.e. names that do not depend on the tensor + /// usage in the specific primitive. 
These names use letters from `a` + /// to `f` to denote logical dimensions and form the order in which the + /// dimensions are laid in memory. For example, + /// #dnnl::memory::format_tag::ab is used to denote a 2D tensor where the + /// second logical dimension (denoted as `b`) is the innermost, i.e. + /// has stride = 1, and the first logical dimension (`a`) is laid out in + /// memory with stride equal to the size of the second dimension. On the + /// other hand, #dnnl::memory::format_tag::ba is the transposed version + /// of the same tensor: the outermost dimension (`a`) becomes the + /// innermost one. + /// + /// - Domain-specific names, i.e. names that make sense only in the + /// context of a certain domain, such as CNN. These names are + /// aliases to the corresponding domain-agnostic tags and used mostly + /// for convenience. For example, #dnnl::memory::format_tag::nc + /// is used to denote 2D CNN activations tensor memory format, where + /// the channels dimension is the innermost one and the batch dimension + /// is the outermost one. Moreover, #dnnl::memory::format_tag::nc is + /// an alias for #dnnl::memory::format_tag::ab, because for + /// CNN primitives the logical dimensions of activations tensors come + /// in order: batch, channels, spatial. In other words, batch + /// corresponds to the first logical dimension (`a`), and channels + /// correspond to the second one (`b`). + /// + /// The following domain-specific notation applies to memory format tags: + /// - @c 'n' denotes the mini-batch dimension + /// - @c 'c' denotes a channels dimension + /// - When there are multiple channel dimensions (for example, + /// in convolution weights tensor), @c 'i' and @c 'o' denote dimensions + /// of input and output channels + /// - @c 'g' denotes a groups dimension for convolution weights + /// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width + /// respectively + /// + /// See @ref dnnl_format_tag_t for a detailed description. 
+ enum class format_tag { + /// Undefined memory format tag + undef = dnnl_format_tag_undef, + /// Placeholder memory format tag. Used to instruct the primitive to + /// select a format automatically. + any = dnnl_format_tag_any, + + /// plain 1D tensor + a = dnnl_a, + + /// plain 2D tensor + ab = dnnl_ab, + /// permuted 2D tensor + ba = dnnl_ba, + + /// plain 3D tensor + abc = dnnl_abc, + /// permuted 3D tensor + acb = dnnl_acb, + /// permuted 3D tensor + bac = dnnl_bac, + /// permuted 3D tensor + bca = dnnl_bca, + /// permuted 3D tensor + cba = dnnl_cba, + + /// plain 4D tensor + abcd = dnnl_abcd, + /// permuted 4D tensor + abdc = dnnl_abdc, + /// permuted 4D tensor + acbd = dnnl_acbd, + /// permuted 4D tensor + acdb = dnnl_acdb, + /// permuted 4D tensor + adbc = dnnl_adbc, + /// permuted 4D tensor + bacd = dnnl_bacd, + /// permuted 4D tensor + bcda = dnnl_bcda, + /// permuted 4D tensor + cdba = dnnl_cdba, + /// permuted 4D tensor + dcab = dnnl_dcab, + + /// plain 5D tensor + abcde = dnnl_abcde, + /// permuted 5D tensor + abdec = dnnl_abdec, + /// permuted 5D tensor + acbde = dnnl_acbde, + /// permuted 5D tensor + acdeb = dnnl_acdeb, + /// permuted 5D tensor + bacde = dnnl_bacde, + /// permuted 5D tensor + bcdea = dnnl_bcdea, + /// permuted 5D tensor + cdeba = dnnl_cdeba, + /// permuted 5D tensor + decab = dnnl_decab, + /// permuted 5D tensor + abced = dnnl_abced, + + /// plain 6D tensor + abcdef = dnnl_abcdef, + /// permuted 6D tensor + abdfce = dnnl_abdfce, + /// permuted 6D tensor + acbdef = dnnl_acbdef, + /// permuted 6D tensor + abdefc = dnnl_abdefc, + /// permuted 6D tensor + defcab = dnnl_defcab, + /// permuted 6D tensor + abcdfe = dnnl_abcdfe, + + /// plain 7D tensor + abcdefg = dnnl_abcdefg, + /// permuted 7D tensor + abcdegf = dnnl_abcdegf, + + /// plain 8D tensor + abcdefgh = dnnl_abcdefgh, + /// permuted 8D tensor + abcdefhg = dnnl_abcdefhg, + + /// plain 9D tensor + abcdefghi = dnnl_abcdefghi, + /// permuted 9D tensor + abcdefgih = dnnl_abcdefgih, + 
+ /// plain 10D tensor + abcdefghij = dnnl_abcdefghij, + /// permuted 10D tensor + abcdefghji = dnnl_abcdefghji, + + /// plain 11D tensor + abcdefghijk = dnnl_abcdefghijk, + /// permuted 11D tensor + abcdefghikj = dnnl_abcdefghikj, + + /// plain 12D tensor + abcdefghijkl = dnnl_abcdefghijkl, + /// permuted 12D tensor + abcdefghijlk = dnnl_abcdefghijlk, + + /// 1D tensor; an alias for #dnnl::memory::format_tag::a + x = a, + /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ab + nc = ab, + /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ba + cn = ba, + /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ab + tn = ab, + /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ba + nt = ba, + /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::abc + ncw = abc, + /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::acb + nwc = acb, + /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcd + nchw = abcd, + /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdb + nhwc = acdb, + /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::bcda + chwn = bcda, + /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcde + ncdhw = abcde, + /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdeb + ndhwc = acdeb, + + /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ab + oi = ab, + /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ba + io = ba, + /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::abc + oiw = abc, + /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::acb + owi = acb, + /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::cba + wio = cba, + /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::bca + iwo = bca, + /// 4D CNN weights tensor; an alias for 
#dnnl::memory::format_tag::abcd + oihw = abcd, + /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdba + hwio = cdba, + /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdb + ohwi = acdb, + /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcda + ihwo = bcda, + /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacd + iohw = bacd, + /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcde + oidhw = abcde, + /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdeba + dhwio = cdeba, + /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdeb + odhwi = acdeb, + /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacde + iodhw = bacde, + /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcdea + idhwo = bcdea, + + /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcd + goiw = abcd, + /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdc + gowi = abdc, + /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::dcab + wigo = dcab, + /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdec + gohwi = abdec, + /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcde + goihw = abcde, + /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::decab + hwigo = decab, + /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbde + giohw = acbde, + /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcdef + goidhw = abcdef, + /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbdef + giodhw = acbdef, + /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdefc + godhwi = abdefc, + /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::defcab + 
dhwigo = defcab, + + /// 3D RNN data tensor in the format (seq_length, batch, input + /// channels); an alias for #dnnl::memory::format_tag::abc. + tnc = abc, + /// 3D RNN data tensor in the format (batch, seq_length, input + /// channels); an alias for #dnnl::memory::format_tag::bac. + ntc = bac, + /// 4D RNN states tensor in the format (num_layers, num_directions, + /// batch, state channels); an alias for #dnnl::memory::format_tag::abcd. + ldnc = abcd, + /// 5D RNN weights tensor in the format (num_layers, num_directions, + /// input_channels, num_gates, output_channels); + /// an alias for #dnnl::memory::format_tag::abcde. + /// + /// - For LSTM cells, the gates order is input, forget, candidate + /// and output gate. + /// - For GRU cells, the gates order is update, reset and output gate. + ldigo = abcde, + /// 5D RNN weights tensor in the format (num_layers, num_directions, + /// num_gates, output_channels, input_channels); + /// an alias for #dnnl::memory::format_tag::abdec. + /// + /// - For LSTM cells, the gates order is input, forget, candidate + /// and output gate. + /// - For GRU cells, the gates order is update, reset and output gate. + ldgoi = abdec, + /// 4D LSTM projection tensor in the format (num_layers, num_directions, + /// num_channels_in_hidden_state, num_channels_in_recurrent_projection); + /// an alias for #dnnl::memory::format_tag::abcd. + ldio = abcd, + /// 4D LSTM projection tensor in the format (num_layers, num_directions, + /// num_channels_in_recurrent_projection, num_channels_in_hidden_state); + /// an alias for #dnnl::memory::format_tag::abdc. + ldoi = abdc, + /// 4D RNN bias tensor in the format (num_layers, num_directions, + /// num_gates, output_channels); + /// an alias for #dnnl::memory::format_tag::abcd. + /// + /// - For LSTM cells, the gates order is input, forget, candidate + /// and output gate. + /// - For GRU cells, the gates order is update, reset and output gate. 
+ ldgo = abcd, + + // Opaque blocked formats + + AB16b16a = dnnl_AB16b16a, + AB16b32a = dnnl_AB16b32a, + AB16b48a = dnnl_AB16b48a, + AB16b64a = dnnl_AB16b64a, + AB8b16a2b = dnnl_AB8b16a2b, + AB8b32a2b = dnnl_AB8b32a2b, + AB8b64a2b = dnnl_AB8b64a2b, + AB4b16a4b = dnnl_AB4b16a4b, + AB4b32a4b = dnnl_AB4b32a4b, + AB4b64a4b = dnnl_AB4b64a4b, + AB16b16a4b = dnnl_AB16b16a4b, + AB16b32a4b = dnnl_AB16b32a4b, + AB16b48a4b = dnnl_AB16b48a4b, + AB16b64a4b = dnnl_AB16b64a4b, + AB16b16a2b = dnnl_AB16b16a2b, + AB16b32a2b = dnnl_AB16b32a2b, + AB16b48a2b = dnnl_AB16b48a2b, + AB16b64a2b = dnnl_AB16b64a2b, + Ab4a = dnnl_Ab4a, + Ab8a = dnnl_Ab8a, + Ab32a = dnnl_Ab32a, + Abc16a = dnnl_Abc16a, + ABc16a16b = dnnl_ABc16a16b, + ABc4a4b = dnnl_ABc4a4b, + aBc16b = dnnl_aBc16b, + aBc32b = dnnl_aBc32b, + ABc16b16a = dnnl_ABc16b16a, + AcB16b16a = dnnl_AcB16b16a, + ABc16b32a = dnnl_ABc16b32a, + AcB16b32a = dnnl_AcB16b32a, + ABc16b48a = dnnl_ABc16b48a, + AcB16b48a = dnnl_AcB16b48a, + ABc16b64a = dnnl_ABc16b64a, + AcB16b64a = dnnl_AcB16b64a, + Abc4a = dnnl_Abc4a, + aBc4b = dnnl_aBc4b, + ABc4b16a4b = dnnl_ABc4b16a4b, + AcB4b16a4b = dnnl_AcB4b16a4b, + ABc4b32a4b = dnnl_ABc4b32a4b, + AcB4b32a4b = dnnl_AcB4b32a4b, + ABc4b64a4b = dnnl_ABc4b64a4b, + AcB4b64a4b = dnnl_AcB4b64a4b, + ABc2b8a4b = dnnl_ABc2b8a4b, + ABc16a16b2a = dnnl_ABc16a16b2a, + ABc16b16a4b = dnnl_ABc16b16a4b, + ABc16b32a4b = dnnl_ABc16b32a4b, + ABc16b48a4b = dnnl_ABc16b48a4b, + ABc16b64a4b = dnnl_ABc16b64a4b, + ABc16b16a2b = dnnl_ABc16b16a2b, + ABc16b32a2b = dnnl_ABc16b32a2b, + ABc16b48a2b = dnnl_ABc16b48a2b, + ABc16b64a2b = dnnl_ABc16b64a2b, + ABc4b4a = dnnl_ABc4b4a, + ABc8a16b2a = dnnl_ABc8a16b2a, + ABc8a8b = dnnl_ABc8a8b, + ABc8a4b = dnnl_ABc8a4b, + aBc8b = dnnl_aBc8b, + ABc8b16a2b = dnnl_ABc8b16a2b, + AcB8b16a2b = dnnl_AcB8b16a2b, + ABc8b32a2b = dnnl_ABc8b32a2b, + AcB8b32a2b = dnnl_AcB8b32a2b, + ABc8b64a2b = dnnl_ABc8b64a2b, + AcB8b64a2b = dnnl_AcB8b64a2b, + ABc8b8a = dnnl_ABc8b8a, + AcB8b8a = dnnl_AcB8b8a, + Abcd8a = dnnl_Abcd8a, + 
Abcd16a = dnnl_Abcd16a, + Abcd32a = dnnl_Abcd32a, + ABcd16a16b = dnnl_ABcd16a16b, + aBcd16b = dnnl_aBcd16b, + aBcd32b = dnnl_aBcd32b, + ABcd16b16a = dnnl_ABcd16b16a, + AcdB16b16a = dnnl_AcdB16b16a, + ABcd16b32a = dnnl_ABcd16b32a, + AcdB16b32a = dnnl_AcdB16b32a, + ABcd16b48a = dnnl_ABcd16b48a, + AcdB16b48a = dnnl_AcdB16b48a, + ABcd16b64a = dnnl_ABcd16b64a, + AcdB16b64a = dnnl_AcdB16b64a, + aBCd16b16c = dnnl_aBCd16b16c, + aBCd16c16b = dnnl_aBCd16c16b, + Abcd4a = dnnl_Abcd4a, + aBcd4b = dnnl_aBcd4b, + ABcd4b16a4b = dnnl_ABcd4b16a4b, + AcdB4b16a4b = dnnl_AcdB4b16a4b, + ABcd4b32a4b = dnnl_ABcd4b32a4b, + AcdB4b32a4b = dnnl_AcdB4b32a4b, + ABcd4b64a4b = dnnl_ABcd4b64a4b, + AcdB4b64a4b = dnnl_AcdB4b64a4b, + ABcd2b8a4b = dnnl_ABcd2b8a4b, + ABcd4b4a = dnnl_ABcd4b4a, + ABcd4a4b = dnnl_ABcd4a4b, + aBCd4c16b4c = dnnl_aBCd4c16b4c, + aBCd2c8b4c = dnnl_aBCd2c8b4c, + ABcd16a16b2a = dnnl_ABcd16a16b2a, + ABcd16b16a4b = dnnl_ABcd16b16a4b, + ABcd16b32a4b = dnnl_ABcd16b32a4b, + ABcd16b48a4b = dnnl_ABcd16b48a4b, + ABcd16b64a4b = dnnl_ABcd16b64a4b, + ABcd16b16a2b = dnnl_ABcd16b16a2b, + ABcd16b32a2b = dnnl_ABcd16b32a2b, + ABcd16b48a2b = dnnl_ABcd16b48a2b, + ABcd16b64a2b = dnnl_ABcd16b64a2b, + aBCd16b16c2b = dnnl_aBCd16b16c2b, + aBCd16c16b4c = dnnl_aBCd16c16b4c, + aBCd16c16b2c = dnnl_aBCd16c16b2c, + aBCd4c4b = dnnl_aBCd4c4b, + aBCd4b4c = dnnl_aBCd4b4c, + ABcd8a16b2a = dnnl_ABcd8a16b2a, + ABcd8a8b = dnnl_ABcd8a8b, + ABcd8a4b = dnnl_ABcd8a4b, + ABcd8a2b = dnnl_ABcd8a2b, + /// 4D tensor blocked by 2nd dimension with block size 8 + aBcd8b = dnnl_aBcd8b, + ABcd8b16a2b = dnnl_ABcd8b16a2b, + AcdB8b16a2b = dnnl_AcdB8b16a2b, + ABcd8b32a2b = dnnl_ABcd8b32a2b, + AcdB8b32a2b = dnnl_AcdB8b32a2b, + ABcd8b64a2b = dnnl_ABcd8b64a2b, + AcdB8b64a2b = dnnl_AcdB8b64a2b, + aBCd8b16c2b = dnnl_aBCd8b16c2b, + /// 4D tensor blocked by 1st and 2nd dimension with block size 8 + ABcd8b8a = dnnl_ABcd8b8a, + AcdB8b8a = dnnl_AcdB8b8a, + aBCd8b8c = dnnl_aBCd8b8c, + aBCd8b4c = dnnl_aBCd8b4c, + aBCd8c16b2c = dnnl_aBCd8c16b2c, 
+ aBCd8c8b = dnnl_aBCd8c8b, + Abcde16a = dnnl_Abcde16a, + Abcde32a = dnnl_Abcde32a, + ABcde16a16b = dnnl_ABcde16a16b, + aBcde16b = dnnl_aBcde16b, + aBcde32b = dnnl_aBcde32b, + ABcde16b16a = dnnl_ABcde16b16a, + AcdeB16b16a = dnnl_AcdeB16b16a, + ABcde16b32a = dnnl_ABcde16b32a, + AcdeB16b32a = dnnl_AcdeB16b32a, + ABcde16b48a = dnnl_ABcde16b48a, + AcdeB16b48a = dnnl_AcdeB16b48a, + ABcde16b64a = dnnl_ABcde16b64a, + AcdeB16b64a = dnnl_AcdeB16b64a, + aBCde16b16c = dnnl_aBCde16b16c, + aBCde16c16b = dnnl_aBCde16c16b, + aBCde2c8b4c = dnnl_aBCde2c8b4c, + Abcde4a = dnnl_Abcde4a, + aBcde4b = dnnl_aBcde4b, + ABcde4b4a = dnnl_ABcde4b4a, + ABcde4a4b = dnnl_ABcde4a4b, + aBCde4b4c = dnnl_aBCde4b4c, + aBCde4c16b4c = dnnl_aBCde4c16b4c, + aBCde16b16c2b = dnnl_aBCde16b16c2b, + aBCde16c16b4c = dnnl_aBCde16c16b4c, + aBCde16c16b2c = dnnl_aBCde16c16b2c, + aBCdef16c16b2c = dnnl_aBCdef16c16b2c, + aBCde4c4b = dnnl_aBCde4c4b, + Abcde8a = dnnl_Abcde8a, + ABcde8a8b = dnnl_ABcde8a8b, + ABcde8a4b = dnnl_ABcde8a4b, + aBcde8b = dnnl_aBcde8b, + ABcde8b16a2b = dnnl_ABcde8b16a2b, + AcdeB8b16a2b = dnnl_AcdeB8b16a2b, + ABcde8b32a2b = dnnl_ABcde8b32a2b, + AcdeB8b32a2b = dnnl_AcdeB8b32a2b, + ABcde8b64a2b = dnnl_ABcde8b64a2b, + AcdeB8b64a2b = dnnl_AcdeB8b64a2b, + ABcde4b16a4b = dnnl_ABcde4b16a4b, + AcdeB4b16a4b = dnnl_AcdeB4b16a4b, + ABcde4b32a4b = dnnl_ABcde4b32a4b, + AcdeB4b32a4b = dnnl_AcdeB4b32a4b, + ABcde4b64a4b = dnnl_ABcde4b64a4b, + AcdeB4b64a4b = dnnl_AcdeB4b64a4b, + ABcde16b16a4b = dnnl_ABcde16b16a4b, + ABcde16b32a4b = dnnl_ABcde16b32a4b, + ABcde16b48a4b = dnnl_ABcde16b48a4b, + ABcde16b64a4b = dnnl_ABcde16b64a4b, + ABcde16b16a2b = dnnl_ABcde16b16a2b, + ABcde16b32a2b = dnnl_ABcde16b32a2b, + ABcde16b48a2b = dnnl_ABcde16b48a2b, + ABcde16b64a2b = dnnl_ABcde16b64a2b, + ABcde2b8a4b = dnnl_ABcde2b8a4b, + aBCde8b16c2b = dnnl_aBCde8b16c2b, + ABcde8b8a = dnnl_ABcde8b8a, + AcdeB8b8a = dnnl_AcdeB8b8a, + aBCde8b8c = dnnl_aBCde8b8c, + aBCde8b4c = dnnl_aBCde8b4c, + ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b, + ABcd2a8b8a2b 
= dnnl_ABcd2a8b8a2b, + aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c, + aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c, + aBCde8c16b2c = dnnl_aBCde8c16b2c, + aBCde8c8b = dnnl_aBCde8c8b, + aBcdef16b = dnnl_aBcdef16b, + aBCdef16b16c = dnnl_aBCdef16b16c, + aBCdef16c16b = dnnl_aBCdef16c16b, + aBcdef4b = dnnl_aBcdef4b, + aBCdef2c8b4c = dnnl_aBCdef2c8b4c, + aBCdef4c4b = dnnl_aBCdef4c4b, + aBCdef4b4c = dnnl_aBCdef4b4c, + aBCdef8b8c = dnnl_aBCdef8b8c, + aBCdef8b4c = dnnl_aBCdef8b4c, + aBCdef8c16b2c = dnnl_aBCdef8c16b2c, + aBCdef4c16b4c = dnnl_aBCdef4c16b4c, + aBCdef8c8b = dnnl_aBCdef8c8b, + aBdc16b = dnnl_aBdc16b, + aBdc4b = dnnl_aBdc4b, + aBdc8b = dnnl_aBdc8b, + aBdC8b2c = dnnl_aBdC8b2c, + aBdC8b4c = dnnl_aBdC8b4c, + aBdec16b = dnnl_aBdec16b, + aBdec4b = dnnl_aBdec4b, + aBdec8b = dnnl_aBdec8b, + aBdeC8b2c = dnnl_aBdeC8b2c, + aBdeC8b4c = dnnl_aBdeC8b4c, + aBdefc16b = dnnl_aBdefc16b, + aCBdef16c16b = dnnl_aCBdef16c16b, + aCBdef8b8c = dnnl_aCBdef8b8c, + aCBdef16b16c = dnnl_aCBdef16b16c, + aBdefc4b = dnnl_aBdefc4b, + aBdefc8b = dnnl_aBdefc8b, + aBdefC8b2c = dnnl_aBdefC8b2c, + aBdefC8b4c = dnnl_aBdefC8b4c, + Acb16a = dnnl_Acb16a, + Acb4a = dnnl_Acb4a, + Acb8a = dnnl_Acb8a, + AcB8a2b = dnnl_AcB8a2b, + AcB8a4b = dnnl_AcB8a4b, + aCBd8b8c = dnnl_aCBd8b8c, + aCBd16b16c = dnnl_aCBd16b16c, + aCBd16c16b = dnnl_aCBd16c16b, + aCBde8b8c = dnnl_aCBde8b8c, + aCBde16b16c = dnnl_aCBde16b16c, + aCBde16c16b = dnnl_aCBde16c16b, + Acdb16a = dnnl_Acdb16a, + Acdb4a = dnnl_Acdb4a, + Acdb8a = dnnl_Acdb8a, + AcdB8a2b = dnnl_AcdB8a2b, + AcdB8a4b = dnnl_AcdB8a4b, + Acdeb16a = dnnl_Acdeb16a, + Acdeb4a = dnnl_Acdeb4a, + Acdeb8a = dnnl_Acdeb8a, + AcdeB8a2b = dnnl_AcdeB8a2b, + AcdeB8a4b = dnnl_AcdeB8a4b, + BAc8a8b = dnnl_BAc8a8b, + BAc16a16b = dnnl_BAc16a16b, + BAc16b16a = dnnl_BAc16b16a, + BAcd8a8b = dnnl_BAcd8a8b, + BAcd16a16b = dnnl_BAcd16a16b, + BAcd16b16a = dnnl_BAcd16b16a, + ABcd32a32b = dnnl_ABcd32a32b, + BAcde16b16a = dnnl_BAcde16b16a, + BAcde8a8b = dnnl_BAcde8a8b, + BAcde16a16b = dnnl_BAcde16a16b, + aBdec32b = 
dnnl_aBdec32b, + Abcdef16a = dnnl_Abcdef16a, + Abcdef32a = dnnl_Abcdef32a, + Acdb32a = dnnl_Acdb32a, + aBCd2b4c2b = dnnl_aBCd2b4c2b, + aBCde2b4c2b = dnnl_aBCde2b4c2b, + aBCdef2b4c2b = dnnl_aBCdef2b4c2b, + aBCd2c4b2c = dnnl_aBCd2c4b2c, + aBCde2c4b2c = dnnl_aBCde2c4b2c, + aBCdef2c4b2c = dnnl_aBCdef2c4b2c, + aBCd4b8c2b = dnnl_aBCd4b8c2b, + aBCde4b8c2b = dnnl_aBCde4b8c2b, + aBCdef4b8c2b = dnnl_aBCdef4b8c2b, + aBCd4c8b2c = dnnl_aBCd4c8b2c, + aBCde4c8b2c = dnnl_aBCde4c8b2c, + aBCdef4c8b2c = dnnl_aBCdef4c8b2c, + AB32a32b8a4b = dnnl_AB32a32b8a4b, + AB32a32b8a2b = dnnl_AB32a32b8a2b, + AB8a4b = dnnl_AB8a4b, + AB8a2b = dnnl_AB8a2b, + abDc16d = dnnl_abDc16d, + abDc32d = dnnl_abDc32d, + abDC16d4c = dnnl_abDC16d4c, + abDC32d4c = dnnl_abDC32d4c, + abCd32c = dnnl_abCd32c, + abdEc16e = dnnl_abdEc16e, + abdEc32e = dnnl_abdEc32e, + abdEC16e4c = dnnl_abdEC16e4c, + abdEC32e2c = dnnl_abdEC32e2c, + abdEC32e4c = dnnl_abdEC32e4c, + abdCe16c = dnnl_abdCe16c, + abdCe32c = dnnl_abdCe32c, + abdCE32c2e = dnnl_abdCE32c2e, + aBCdef16c16b4c = dnnl_aBCdef16c16b4c, + aBdC16b4c = dnnl_aBdC16b4c, + aBdeC16b4c = dnnl_aBdeC16b4c, + AcB16a4b = dnnl_AcB16a4b, + AcdB16a2b = dnnl_AcdB16a2b, + aBdefC16b4c = dnnl_aBdefC16b4c, + AcdeB16a4b = dnnl_AcdeB16a4b, + + Acb32a = dnnl_Acb32a, + AcB32a2b = dnnl_AcB32a2b, + AcB32a4b = dnnl_AcB32a4b, + Acb48a = dnnl_Acb48a, + AcB48a2b = dnnl_AcB48a2b, + AcB48a4b = dnnl_AcB48a4b, + Acb64a = dnnl_Acb64a, + AcB64a2b = dnnl_AcB64a2b, + AcB64a4b = dnnl_AcB64a4b, + cBa2b = dnnl_cBa2b, + cBa4b = dnnl_cBa4b, + aBdc32b = dnnl_aBdc32b, + aBdC32b2c = dnnl_aBdC32b2c, + aBdC32b4c = dnnl_aBdC32b4c, + aBdc48b = dnnl_aBdc48b, + aBdC48b2c = dnnl_aBdC48b2c, + aBdC48b4c = dnnl_aBdC48b4c, + aBdc64b = dnnl_aBdc64b, + aBdC64b2c = dnnl_aBdC64b2c, + aBdC64b4c = dnnl_aBdC64b4c, + adcb = dnnl_adcb, + adCb2c = dnnl_adCb2c, + adCb4c = dnnl_adCb4c, + AcdB32a2b = dnnl_AcdB32a2b, + AcdB32a4b = dnnl_AcdB32a4b, + Acdb48a = dnnl_Acdb48a, + AcdB48a2b = dnnl_AcdB48a2b, + AcdB48a4b = dnnl_AcdB48a4b, + 
Acdb64a = dnnl_Acdb64a, + AcdB64a2b = dnnl_AcdB64a2b, + AcdB64a4b = dnnl_AcdB64a4b, + cdBa2b = dnnl_cdBa2b, + cdBa4b = dnnl_cdBa4b, + aBdeC32b2c = dnnl_aBdeC32b2c, + aBdeC32b4c = dnnl_aBdeC32b4c, + aBdec48b = dnnl_aBdec48b, + aBdeC48b2c = dnnl_aBdeC48b2c, + aBdeC48b4c = dnnl_aBdeC48b4c, + aBdec64b = dnnl_aBdec64b, + aBdeC64b2c = dnnl_aBdeC64b2c, + aBdeC64b4c = dnnl_aBdeC64b4c, + adecb = dnnl_adecb, + adeCb2c = dnnl_adeCb2c, + adeCb4c = dnnl_adeCb4c, + Acdeb32a = dnnl_Acdeb32a, + AcdeB32a2b = dnnl_AcdeB32a2b, + AcdeB32a4b = dnnl_AcdeB32a4b, + Acdeb48a = dnnl_Acdeb48a, + AcdeB48a2b = dnnl_AcdeB48a2b, + AcdeB48a4b = dnnl_AcdeB48a4b, + Acdeb64a = dnnl_Acdeb64a, + AcdeB64a2b = dnnl_AcdeB64a2b, + AcdeB64a4b = dnnl_AcdeB64a4b, + cdeBa2b = dnnl_cdeBa2b, + cdeBa4b = dnnl_cdeBa4b, + aBdefc32b = dnnl_aBdefc32b, + aBdefC32b2c = dnnl_aBdefC32b2c, + aBdefC32b4c = dnnl_aBdefC32b4c, + aBdefc48b = dnnl_aBdefc48b, + aBdefC48b2c = dnnl_aBdefC48b2c, + aBdefC48b4c = dnnl_aBdefC48b4c, + aBdefc64b = dnnl_aBdefc64b, + aBdefC64b2c = dnnl_aBdefC64b2c, + aBdefC64b4c = dnnl_aBdefC64b4c, + adefcb = dnnl_adefcb, + adefCb2c = dnnl_adefCb2c, + adefCb4c = dnnl_adefCb4c, + ABc32a32b = dnnl_ABc32a32b, + BAc8a16b2a = dnnl_BAc8a16b2a, + BAcd8a16b2a = dnnl_BAcd8a16b2a, + ABcde8a16b2a = dnnl_ABcde8a16b2a, + aCBd8b16c2b = dnnl_aCBd8b16c2b, + BAcde8a16b2a = dnnl_BAcde8a16b2a, + aCBde8b16c2b = dnnl_aCBde8b16c2b, + ABcde32a32b = dnnl_ABcde32a32b, + ABc4a8b8a4b = dnnl_ABc4a8b8a4b, + ABcde4a8b8a4b = dnnl_ABcde4a8b8a4b, + BAc4b8a8b4a = dnnl_BAc4b8a8b4a, + BAcd4b8a8b4a = dnnl_BAcd4b8a8b4a, + BAcde4b8a8b4a = dnnl_BAcde4b8a8b4a, + aBCd4b8c8b4c = dnnl_aBCd4b8c8b4c, + aBCdef4b8c8b4c = dnnl_aBCdef4b8c8b4c, + aBCdef8b16c2b = dnnl_aBCdef8b16c2b, + aCBdef8b16c2b = dnnl_aCBdef8b16c2b, + aBdC16b2c = dnnl_aBdC16b2c, + aBdeC16b2c = dnnl_aBdeC16b2c, + aBdefC16b2c = dnnl_aBdefC16b2c, + aBedc16b = dnnl_aBedc16b, + AcB16a2b = dnnl_AcB16a2b, + AcdB16a4b = dnnl_AcdB16a4b, + AcdeB16a2b = dnnl_AcdeB16a2b, + Adcb16a = dnnl_Adcb16a, 
+ aCBd4c8b8c4b = dnnl_aCBd4c8b8c4b, + aCBde4c8b8c4b = dnnl_aCBde4c8b8c4b, + aCBdef4c8b8c4b = dnnl_aCBdef4c8b8c4b, + ABc32a16b = dnnl_ABc32a16b, + ABcd16a32b = dnnl_ABcd16a32b, + ABcd32a16b = dnnl_ABcd32a16b, + ABcde32a16b = dnnl_ABcde32a16b, + AB48a16b = dnnl_AB48a16b, + AB48a32b = dnnl_AB48a32b, + ABc40a16b = dnnl_ABc40a16b, + ABc40a32b = dnnl_ABc40a32b, + aBC48b16c = dnnl_aBC48b16c, + aBC48b32c = dnnl_aBC48b32c, + ABcd40a16b = dnnl_ABcd40a16b, + ABcd40a32b = dnnl_ABcd40a32b, + BA16a16b = dnnl_BA16a16b, + BA16a32b = dnnl_BA16a32b, + BA16a48b = dnnl_BA16a48b, + BA16a64b = dnnl_BA16a64b, + BA16a16b2a = dnnl_BA16a16b2a, + BA16a32b2a = dnnl_BA16a32b2a, + BA16a48b2a = dnnl_BA16a48b2a, + BA16a64b2a = dnnl_BA16a64b2a, + BA16a16b4a = dnnl_BA16a16b4a, + BA16a32b4a = dnnl_BA16a32b4a, + BA16a48b4a = dnnl_BA16a48b4a, + BA16a64b4a = dnnl_BA16a64b4a, + decbA16a = dnnl_decbA16a, + decbA8a = dnnl_decbA8a, + defcbA16a = dnnl_defcbA16a, + defcbA8a = dnnl_defcbA8a, + aCB16b16c = dnnl_aCB16b16c, + aCB16b32c = dnnl_aCB16b32c, + aCB16b48c = dnnl_aCB16b48c, + aCB16b64c = dnnl_aCB16b64c, + aCB16b16c2b = dnnl_aCB16b16c2b, + aCB16b32c2b = dnnl_aCB16b32c2b, + aCB16b48c2b = dnnl_aCB16b48c2b, + aCB16b64c2b = dnnl_aCB16b64c2b, + aCB16b16c4b = dnnl_aCB16b16c4b, + aCB16b32c4b = dnnl_aCB16b32c4b, + aCB16b48c4b = dnnl_aCB16b48c4b, + aCB16b64c4b = dnnl_aCB16b64c4b, + Acb24a = dnnl_Acb24a, + Acdb24a = dnnl_Acdb24a, + Acdeb24a = dnnl_Acdeb24a, + aBdc24b = dnnl_aBdc24b, + aBdec24b = dnnl_aBdec24b, + aBdefc24b = dnnl_aBdefc24b, + AcB24a2b = dnnl_AcB24a2b, + AcdB24a2b = dnnl_AcdB24a2b, + AcdeB24a2b = dnnl_AcdeB24a2b, + aBdC24b2c = dnnl_aBdC24b2c, + aBdeC24b2c = dnnl_aBdeC24b2c, + aBdefC24b2c = dnnl_aBdefC24b2c, + AcB24a4b = dnnl_AcB24a4b, + AcdB24a4b = dnnl_AcdB24a4b, + AcdeB24a4b = dnnl_AcdeB24a4b, + aBdC24b4c = dnnl_aBdC24b4c, + aBdeC24b4c = dnnl_aBdeC24b4c, + aBdefC24b4c = dnnl_aBdefC24b4c, + AB8b32a = dnnl_AB8b32a, + ABc8b32a = dnnl_ABc8b32a, + AcB8b32a = dnnl_AcB8b32a, + ABcd8b32a = dnnl_ABcd8b32a, 
+ AcdB8b32a = dnnl_AcdB8b32a, + ABcde8b32a = dnnl_ABcde8b32a, + AcdeB8b32a = dnnl_AcdeB8b32a, + AB8b24a = dnnl_AB8b24a, + ABc8b24a = dnnl_ABc8b24a, + AcB8b24a = dnnl_AcB8b24a, + ABcd8b24a = dnnl_ABcd8b24a, + AcdB8b24a = dnnl_AcdB8b24a, + ABcde8b24a = dnnl_ABcde8b24a, + AcdeB8b24a = dnnl_AcdeB8b24a, + AB8b16a = dnnl_AB8b16a, + ABc8b16a = dnnl_ABc8b16a, + AcB8b16a = dnnl_AcB8b16a, + ABcd8b16a = dnnl_ABcd8b16a, + AcdB8b16a = dnnl_AcdB8b16a, + ABcde8b16a = dnnl_ABcde8b16a, + AcdeB8b16a = dnnl_AcdeB8b16a, + AB8b8a = dnnl_AB8b8a, + + format_tag_last = dnnl_format_tag_last, + + nCdhw16c = dnnl_nCdhw16c, + nCdhw4c = dnnl_nCdhw4c, + nCdhw8c = dnnl_nCdhw8c, + nChw16c = dnnl_nChw16c, + nChw4c = dnnl_nChw4c, + nChw8c = dnnl_nChw8c, + nCw16c = dnnl_nCw16c, + nCw4c = dnnl_nCw4c, + nCw8c = dnnl_nCw8c, + NCw16n16c = dnnl_NCw16n16c, + NChw16n16c = dnnl_NChw16n16c, + NCdhw16n16c = dnnl_NCdhw16n16c, + NCdhw32n32c = dnnl_NCdhw32n32c, + NChw32n32c = dnnl_NChw32n32c, + IOhw16i16o = dnnl_IOhw16i16o, + OI16i16o = dnnl_OI16i16o, + OI16i32o = dnnl_OI16i32o, + OI16i48o = dnnl_OI16i48o, + OI16i64o = dnnl_OI16i64o, + OI8i16o2i = dnnl_OI8i16o2i, + OI8i32o2i = dnnl_OI8i32o2i, + OI8i64o2i = dnnl_OI8i64o2i, + OI4i8o4i = dnnl_OI4i8o4i, + OI4i16o4i = dnnl_OI4i16o4i, + OI4i24o4i = dnnl_OI4i24o4i, + OI4i32o4i = dnnl_OI4i32o4i, + OI4i64o4i = dnnl_OI4i64o4i, + Ohwi32o = dnnl_Ohwi32o, + IOdhw16i16o = dnnl_IOdhw16i16o, + gIOhw16i16o = dnnl_gIOhw16i16o, + gOhwi32o = dnnl_gOhwi32o, + Goidhw16g = dnnl_Goidhw16g, + IOw8o8i = dnnl_IOw8o8i, + IOw16o16i = dnnl_IOw16o16i, + OIw16i16o = dnnl_OIw16i16o, + OwI16i16o = dnnl_OwI16i16o, + OIw16i32o = dnnl_OIw16i32o, + OwI16i32o = dnnl_OwI16i32o, + OIw16i48o = dnnl_OIw16i48o, + OwI16i48o = dnnl_OwI16i48o, + OIw16i64o = dnnl_OIw16i64o, + OwI16i64o = dnnl_OwI16i64o, + IOw16i16o = dnnl_IOw16i16o, + gIOw16i16o = dnnl_gIOw16i16o, + OIw16o16i = dnnl_OIw16o16i, + Oiw16o = dnnl_Oiw16o, + OIw4i8o4i = dnnl_OIw4i8o4i, + OwI4i8o4i = dnnl_OwI4i8o4i, + OIw4i16o4i = dnnl_OIw4i16o4i, + 
OwI4i16o4i = dnnl_OwI4i16o4i, + OIw4i24o4i = dnnl_OIw4i24o4i, + OwI4i24o4i = dnnl_OwI4i24o4i, + OIw4i32o4i = dnnl_OIw4i32o4i, + OwI4i32o4i = dnnl_OwI4i32o4i, + OIw4i64o4i = dnnl_OIw4i64o4i, + OwI4i64o4i = dnnl_OwI4i64o4i, + OIw2i8o4i = dnnl_OIw2i8o4i, + OIw4i4o = dnnl_OIw4i4o, + OIw4o4i = dnnl_OIw4o4i, + Oiw4o = dnnl_Oiw4o, + OIw8i16o2i = dnnl_OIw8i16o2i, + OwI8i16o2i = dnnl_OwI8i16o2i, + OIw8i32o2i = dnnl_OIw8i32o2i, + OwI8i32o2i = dnnl_OwI8i32o2i, + OIw8i64o2i = dnnl_OIw8i64o2i, + OwI8i64o2i = dnnl_OwI8i64o2i, + OIw8i8o = dnnl_OIw8i8o, + OwI8i8o = dnnl_OwI8i8o, + OIw8o16i2o = dnnl_OIw8o16i2o, + OIw8o8i = dnnl_OIw8o8i, + OIw8o4i = dnnl_OIw8o4i, + OIw16i16o4i = dnnl_OIw16i16o4i, + OIw16i32o4i = dnnl_OIw16i32o4i, + OIw16i48o4i = dnnl_OIw16i48o4i, + OIw16i64o4i = dnnl_OIw16i64o4i, + OIw16i16o2i = dnnl_OIw16i16o2i, + OIw16i32o2i = dnnl_OIw16i32o2i, + OIw16i48o2i = dnnl_OIw16i48o2i, + OIw16i64o2i = dnnl_OIw16i64o2i, + OIw16o16i2o = dnnl_OIw16o16i2o, + Owi16o = dnnl_Owi16o, + OwI16o2i = dnnl_OwI16o2i, + Iwo16i = dnnl_Iwo16i, + IwO16i2o = dnnl_IwO16i2o, + IwO16i4o = dnnl_IwO16i4o, + Owi4o = dnnl_Owi4o, + Owi8o = dnnl_Owi8o, + OwI8o2i = dnnl_OwI8o2i, + OwI8o4i = dnnl_OwI8o4i, + IOhw8o8i = dnnl_IOhw8o8i, + IOhw16o16i = dnnl_IOhw16o16i, + Ohwi16o = dnnl_Ohwi16o, + OhwI16o2i = dnnl_OhwI16o2i, + Ihwo16i = dnnl_Ihwo16i, + IhwO16i2o = dnnl_IhwO16i2o, + IhwO16i4o = dnnl_IhwO16i4o, + Ohwi4o = dnnl_Ohwi4o, + Ohwi8o = dnnl_Ohwi8o, + OhwI8o2i = dnnl_OhwI8o2i, + OhwI8o4i = dnnl_OhwI8o4i, + OIhw16i16o = dnnl_OIhw16i16o, + OhwI16i16o = dnnl_OhwI16i16o, + OIhw16i32o = dnnl_OIhw16i32o, + OhwI16i32o = dnnl_OhwI16i32o, + OIhw16i48o = dnnl_OIhw16i48o, + OhwI16i48o = dnnl_OhwI16i48o, + OIhw16i64o = dnnl_OIhw16i64o, + OhwI16i64o = dnnl_OhwI16i64o, + OIhw16o16i = dnnl_OIhw16o16i, + Oihw16o = dnnl_Oihw16o, + OIhw4i8o4i = dnnl_OIhw4i8o4i, + OhwI4i8o4i = dnnl_OhwI4i8o4i, + OIhw4i16o4i = dnnl_OIhw4i16o4i, + OhwI4i16o4i = dnnl_OhwI4i16o4i, + OIhw4i24o4i = dnnl_OIhw4i24o4i, + OhwI4i24o4i = 
dnnl_OhwI4i24o4i, + OIhw4i32o4i = dnnl_OIhw4i32o4i, + OhwI4i32o4i = dnnl_OhwI4i32o4i, + OIhw4i64o4i = dnnl_OIhw4i64o4i, + OhwI4i64o4i = dnnl_OhwI4i64o4i, + OIhw4i4o = dnnl_OIhw4i4o, + OIhw4o4i = dnnl_OIhw4o4i, + Oihw4o = dnnl_Oihw4o, + OIhw8i16o2i = dnnl_OIhw8i16o2i, + OhwI8i16o2i = dnnl_OhwI8i16o2i, + OIhw8i32o2i = dnnl_OIhw8i32o2i, + OhwI8i32o2i = dnnl_OhwI8i32o2i, + OIhw8i64o2i = dnnl_OIhw8i64o2i, + OhwI8i64o2i = dnnl_OhwI8i64o2i, + OIhw8i8o = dnnl_OIhw8i8o, + OhwI8i8o = dnnl_OhwI8i8o, + OIhw8o16i2o = dnnl_OIhw8o16i2o, + OIhw8o8i = dnnl_OIhw8o8i, + OIhw8o4i = dnnl_OIhw8o4i, + OIhw2i8o4i = dnnl_OIhw2i8o4i, + IOdhw8o8i = dnnl_IOdhw8o8i, + IOdhw16o16i = dnnl_IOdhw16o16i, + Odhwi16o = dnnl_Odhwi16o, + OdhwI16o2i = dnnl_OdhwI16o2i, + Idhwo16i = dnnl_Idhwo16i, + IdhwO16i2o = dnnl_IdhwO16i2o, + IdhwO16i4o = dnnl_IdhwO16i4o, + Odhwi4o = dnnl_Odhwi4o, + Odhwi8o = dnnl_Odhwi8o, + OdhwI8o2i = dnnl_OdhwI8o2i, + OdhwI8o4i = dnnl_OdhwI8o4i, + OIdhw16i16o = dnnl_OIdhw16i16o, + OdhwI16i16o = dnnl_OdhwI16i16o, + OIdhw16i32o = dnnl_OIdhw16i32o, + OdhwI16i32o = dnnl_OdhwI16i32o, + OIdhw16i48o = dnnl_OIdhw16i48o, + OdhwI16i48o = dnnl_OdhwI16i48o, + OIdhw16i64o = dnnl_OIdhw16i64o, + OdhwI16i64o = dnnl_OdhwI16i64o, + OIdhw16o16i = dnnl_OIdhw16o16i, + OIdhw16o16i2o = dnnl_OIdhw16o16i2o, + Oidhw16o = dnnl_Oidhw16o, + OIdhw4i4o = dnnl_OIdhw4i4o, + OIdhw4o4i = dnnl_OIdhw4o4i, + Oidhw4o = dnnl_Oidhw4o, + OIdhw8i16o2i = dnnl_OIdhw8i16o2i, + OdhwI8i16o2i = dnnl_OdhwI8i16o2i, + OIdhw8i32o2i = dnnl_OIdhw8i32o2i, + OdhwI8i32o2i = dnnl_OdhwI8i32o2i, + OIdhw8i64o2i = dnnl_OIdhw8i64o2i, + OdhwI8i64o2i = dnnl_OdhwI8i64o2i, + OIdhw4i8o4i = dnnl_OIdhw4i8o4i, + OdhwI4i8o4i = dnnl_OdhwI4i8o4i, + OIdhw4i16o4i = dnnl_OIdhw4i16o4i, + OdhwI4i16o4i = dnnl_OdhwI4i16o4i, + OIdhw16i16o4i = dnnl_OIdhw16i16o4i, + OIdhw16i32o4i = dnnl_OIdhw16i32o4i, + OIdhw16i48o4i = dnnl_OIdhw16i48o4i, + OIdhw16i64o4i = dnnl_OIdhw16i64o4i, + OIdhw16i16o2i = dnnl_OIdhw16i16o2i, + OIdhw16i32o2i = dnnl_OIdhw16i32o2i, + 
OIdhw16i48o2i = dnnl_OIdhw16i48o2i, + OIdhw16i64o2i = dnnl_OIdhw16i64o2i, + OIdhw4i24o4i = dnnl_OIdhw4i24o4i, + OdhwI4i24o4i = dnnl_OdhwI4i24o4i, + OIdhw4i32o4i = dnnl_OIdhw4i32o4i, + OdhwI4i32o4i = dnnl_OdhwI4i32o4i, + OIdhw4i64o4i = dnnl_OIdhw4i64o4i, + OdhwI4i64o4i = dnnl_OdhwI4i64o4i, + OIdhw2i8o4i = dnnl_OIdhw2i8o4i, + OIdhw8i8o = dnnl_OIdhw8i8o, + OdhwI8i8o = dnnl_OdhwI8i8o, + OIdhw8o8i = dnnl_OIdhw8o8i, + OIdhw8o4i = dnnl_OIdhw8o4i, + gIOw8o8i = dnnl_gIOw8o8i, + gIOw16o16i = dnnl_gIOw16o16i, + gOIw16i16o = dnnl_gOIw16i16o, + gOIw16o16i = dnnl_gOIw16o16i, + gOiw16o = dnnl_gOiw16o, + gOIw4i16o4i = dnnl_gOIw4i16o4i, + gOIw2i8o4i = dnnl_gOIw2i8o4i, + gOIw4i4o = dnnl_gOIw4i4o, + gOIw4o4i = dnnl_gOIw4o4i, + gOiw4o = dnnl_gOiw4o, + gOIw8i16o2i = dnnl_gOIw8i16o2i, + gOIw8i8o = dnnl_gOIw8i8o, + gOIw8o16i2o = dnnl_gOIw8o16i2o, + gOIw8o8i = dnnl_gOIw8o8i, + gOIw8o4i = dnnl_gOIw8o4i, + gOIw16i16o4i = dnnl_gOIw16i16o4i, + gOIw16i16o2i = dnnl_gOIw16i16o2i, + gOIw16o16i2o = dnnl_gOIw16o16i2o, + gOwi16o = dnnl_gOwi16o, + gOwI16o2i = dnnl_gOwI16o2i, + gIwo16i = dnnl_gIwo16i, + gIwO16i2o = dnnl_gIwO16i2o, + gIwO16i4o = dnnl_gIwO16i4o, + gOwi4o = dnnl_gOwi4o, + gOwi8o = dnnl_gOwi8o, + gOwI8o2i = dnnl_gOwI8o2i, + gOwI8o4i = dnnl_gOwI8o4i, + Goiw8g = dnnl_Goiw8g, + Goiw16g = dnnl_Goiw16g, + gIOhw8o8i = dnnl_gIOhw8o8i, + gIOhw16o16i = dnnl_gIOhw16o16i, + gOhwi16o = dnnl_gOhwi16o, + gOhwI16o2i = dnnl_gOhwI16o2i, + gIhwo16i = dnnl_gIhwo16i, + gIhwO16i2o = dnnl_gIhwO16i2o, + gIhwO16i4o = dnnl_gIhwO16i4o, + gOhwi4o = dnnl_gOhwi4o, + gOhwi8o = dnnl_gOhwi8o, + gOhwI8o2i = dnnl_gOhwI8o2i, + gOhwI8o4i = dnnl_gOhwI8o4i, + Goihw16g = dnnl_Goihw16g, + gOIhw16i16o = dnnl_gOIhw16i16o, + gOIhw16o16i = dnnl_gOIhw16o16i, + gOihw16o = dnnl_gOihw16o, + gOIhw4i16o4i = dnnl_gOIhw4i16o4i, + gOIhw2i8o4i = dnnl_gOIhw2i8o4i, + gOIhw4i4o = dnnl_gOIhw4i4o, + gOIhw4o4i = dnnl_gOIhw4o4i, + gOihw4o = dnnl_gOihw4o, + Goihw8g = dnnl_Goihw8g, + gOIhw8i16o2i = dnnl_gOIhw8i16o2i, + gOIhw8i8o = dnnl_gOIhw8i8o, + 
gOIhw8o16i2o = dnnl_gOIhw8o16i2o, + OIw4o8i8o4i = dnnl_OIw4o8i8o4i, + OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i, + OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i, + OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i, + gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i, + gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i, + gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i, + gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i, + OIhw16i16o4i = dnnl_OIhw16i16o4i, + OIhw16i32o4i = dnnl_OIhw16i32o4i, + OIhw16i48o4i = dnnl_OIhw16i48o4i, + OIhw16i64o4i = dnnl_OIhw16i64o4i, + OIhw16i16o2i = dnnl_OIhw16i16o2i, + OIhw16i32o2i = dnnl_OIhw16i32o2i, + OIhw16i48o2i = dnnl_OIhw16i48o2i, + OIhw16i64o2i = dnnl_OIhw16i64o2i, + OIhw16o16i2o = dnnl_OIhw16o16i2o, + gOIhw16i16o4i = dnnl_gOIhw16i16o4i, + gOIhw16i16o2i = dnnl_gOIhw16i16o2i, + gOIhw16o16i2o = dnnl_gOIhw16o16i2o, + gOIhw8o8i = dnnl_gOIhw8o8i, + gOIhw8o4i = dnnl_gOIhw8o4i, + gIOdhw16i16o = dnnl_gIOdhw16i16o, + gIOdhw8o8i = dnnl_gIOdhw8o8i, + gIOdhw16o16i = dnnl_gIOdhw16o16i, + gOdhwi16o = dnnl_gOdhwi16o, + gOdhwI16o2i = dnnl_gOdhwI16o2i, + gIdhwo16i = dnnl_gIdhwo16i, + gIdhwO16i2o = dnnl_gIdhwO16i2o, + gIdhwO16i4o = dnnl_gIdhwO16i4o, + gOdhwi4o = dnnl_gOdhwi4o, + gOdhwi8o = dnnl_gOdhwi8o, + gOdhwI8o2i = dnnl_gOdhwI8o2i, + gOdhwI8o4i = dnnl_gOdhwI8o4i, + gOIdhw16i16o = dnnl_gOIdhw16i16o, + gOIdhw16o16i = dnnl_gOIdhw16o16i, + gOIdhw16o16i2o = dnnl_gOIdhw16o16i2o, + gOidhw16o = dnnl_gOidhw16o, + gOIdhw4i4o = dnnl_gOIdhw4i4o, + gOIdhw4o4i = dnnl_gOIdhw4o4i, + gOidhw4o = dnnl_gOidhw4o, + gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i, + gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i, + gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i, + gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i, + gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i, + gOIdhw8i8o = dnnl_gOIdhw8i8o, + gOIdhw8o8i = dnnl_gOIdhw8o8i, + gOIdhw8o4i = dnnl_gOIdhw8o4i, + gOIw2i4o2i = dnnl_gOIw2i4o2i, + gOIhw2i4o2i = dnnl_gOIhw2i4o2i, + gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i, + gOIw2o4i2o = dnnl_gOIw2o4i2o, + gOIhw2o4i2o = dnnl_gOIhw2o4i2o, + gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o, + gOIw4i8o2i = dnnl_gOIw4i8o2i, + gOIhw4i8o2i = 
dnnl_gOIhw4i8o2i, + gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i, + gOIw4o8i2o = dnnl_gOIw4o8i2o, + gOIhw4o8i2o = dnnl_gOIhw4o8i2o, + gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o, + + ldOi16o = abDc16d, + ldOi32o = abDc32d, + ldOI16o4i = abDC16d4c, + ldOI32o4i = abDC32d4c, + ldgOi16o = abdEc16e, + ldgOI16o4i = abdEC16e4c, + ldgOi32o = abdEc32e, + ldgOI32o2i = abdEC32e2c, + ldgOI32o4i = abdEC32e4c, + OwI16o4i = dnnl_OwI16o4i, + OhwI16o4i = dnnl_OhwI16o4i, + gOwI16o4i = dnnl_gOwI16o4i, + gOhwI16o4i = dnnl_gOhwI16o4i, + OdhwI16o4i = dnnl_OdhwI16o4i, + gOdhwI16o4i = dnnl_gOdhwI16o4i, + + Owi32o = dnnl_Owi32o, + OwI32o2i = dnnl_OwI32o2i, + OwI32o4i = dnnl_OwI32o4i, + Owi48o = dnnl_Owi48o, + OwI48o2i = dnnl_OwI48o2i, + OwI48o4i = dnnl_OwI48o4i, + Owi64o = dnnl_Owi64o, + OwI64o2i = dnnl_OwI64o2i, + OwI64o4i = dnnl_OwI64o4i, + Iwo32i = dnnl_Iwo32i, + IwO32i2o = dnnl_IwO32i2o, + IwO32i4o = dnnl_IwO32i4o, + Iwo48i = dnnl_Iwo48i, + IwO48i2o = dnnl_IwO48i2o, + IwO48i4o = dnnl_IwO48i4o, + Iwo64i = dnnl_Iwo64i, + IwO64i2o = dnnl_IwO64i2o, + IwO64i4o = dnnl_IwO64i4o, + wIo2i = dnnl_wIo2i, + wIo4i = dnnl_wIo4i, + gOwi32o = dnnl_gOwi32o, + gOwI32o2i = dnnl_gOwI32o2i, + gOwI32o4i = dnnl_gOwI32o4i, + gOwi48o = dnnl_gOwi48o, + gOwI48o2i = dnnl_gOwI48o2i, + gOwI48o4i = dnnl_gOwI48o4i, + gOwi64o = dnnl_gOwi64o, + gOwI64o2i = dnnl_gOwI64o2i, + gOwI64o4i = dnnl_gOwI64o4i, + gIwo32i = dnnl_gIwo32i, + gIwO32i2o = dnnl_gIwO32i2o, + gIwO32i4o = dnnl_gIwO32i4o, + gIwo48i = dnnl_gIwo48i, + gIwO48i2o = dnnl_gIwO48i2o, + gIwO48i4o = dnnl_gIwO48i4o, + gIwo64i = dnnl_gIwo64i, + gIwO64i2o = dnnl_gIwO64i2o, + gIwO64i4o = dnnl_gIwO64i4o, + gwio = dnnl_gwio, + gwIo2i = dnnl_gwIo2i, + gwIo4i = dnnl_gwIo4i, + OhwI32o = dnnl_OhwI32o, + OhwI32o2i = dnnl_OhwI32o2i, + OhwI32o4i = dnnl_OhwI32o4i, + Ohwi48o = dnnl_Ohwi48o, + OhwI48o2i = dnnl_OhwI48o2i, + OhwI48o4i = dnnl_OhwI48o4i, + Ohwi64o = dnnl_Ohwi64o, + OhwI64o2i = dnnl_OhwI64o2i, + OhwI64o4i = dnnl_OhwI64o4i, + Ihwo32i = dnnl_Ihwo32i, + IhwO32i2o = dnnl_IhwO32i2o, + IhwO32i4o 
= dnnl_IhwO32i4o, + Ihwo48i = dnnl_Ihwo48i, + IhwO48i2o = dnnl_IhwO48i2o, + IhwO48i4o = dnnl_IhwO48i4o, + Ihwo64i = dnnl_Ihwo64i, + IhwO64i2o = dnnl_IhwO64i2o, + IhwO64i4o = dnnl_IhwO64i4o, + hwIo2i = dnnl_hwIo2i, + hwIo4i = dnnl_hwIo4i, + gOhwI32o = dnnl_gOhwI32o, + gOhwI32o2i = dnnl_gOhwI32o2i, + gOhwI32o4i = dnnl_gOhwI32o4i, + gOhwi48o = dnnl_gOhwi48o, + gOhwI48o2i = dnnl_gOhwI48o2i, + gOhwI48o4i = dnnl_gOhwI48o4i, + gOhwi64o = dnnl_gOhwi64o, + gOhwI64o2i = dnnl_gOhwI64o2i, + gOhwI64o4i = dnnl_gOhwI64o4i, + gIhwo32i = dnnl_gIhwo32i, + gIhwO32i2o = dnnl_gIhwO32i2o, + gIhwO32i4o = dnnl_gIhwO32i4o, + gIhwo48i = dnnl_gIhwo48i, + gIhwO48i2o = dnnl_gIhwO48i2o, + gIhwO48i4o = dnnl_gIhwO48i4o, + gIhwo64i = dnnl_gIhwo64i, + gIhwO64i2o = dnnl_gIhwO64i2o, + gIhwO64i4o = dnnl_gIhwO64i4o, + ghwio = dnnl_ghwio, + ghwIo2i = dnnl_ghwIo2i, + ghwIo4i = dnnl_ghwIo4i, + Odhwi32o = dnnl_Odhwi32o, + OdhwI32o2i = dnnl_OdhwI32o2i, + OdhwI32o4i = dnnl_OdhwI32o4i, + Odhwi48o = dnnl_Odhwi48o, + OdhwI48o2i = dnnl_OdhwI48o2i, + OdhwI48o4i = dnnl_OdhwI48o4i, + Odhwi64o = dnnl_Odhwi64o, + OdhwI64o2i = dnnl_OdhwI64o2i, + OdhwI64o4i = dnnl_OdhwI64o4i, + Idhwo32i = dnnl_Idhwo32i, + IdhwO32i2o = dnnl_IdhwO32i2o, + IdhwO32i4o = dnnl_IdhwO32i4o, + Idhwo48i = dnnl_Idhwo48i, + IdhwO48i2o = dnnl_IdhwO48i2o, + IdhwO48i4o = dnnl_IdhwO48i4o, + Idhwo64i = dnnl_Idhwo64i, + IdhwO64i2o = dnnl_IdhwO64i2o, + IdhwO64i4o = dnnl_IdhwO64i4o, + dhwIo2i = dnnl_dhwIo2i, + dhwIo4i = dnnl_dhwIo4i, + gOdhwi32o = dnnl_gOdhwi32o, + gOdhwI32o2i = dnnl_gOdhwI32o2i, + gOdhwI32o4i = dnnl_gOdhwI32o4i, + gOdhwi48o = dnnl_gOdhwi48o, + gOdhwI48o2i = dnnl_gOdhwI48o2i, + gOdhwI48o4i = dnnl_gOdhwI48o4i, + gOdhwi64o = dnnl_gOdhwi64o, + gOdhwI64o2i = dnnl_gOdhwI64o2i, + gOdhwI64o4i = dnnl_gOdhwI64o4i, + gIdhwo32i = dnnl_gIdhwo32i, + gIdhwO32i2o = dnnl_gIdhwO32i2o, + gIdhwO32i4o = dnnl_gIdhwO32i4o, + gIdhwo48i = dnnl_gIdhwo48i, + gIdhwO48i2o = dnnl_gIdhwO48i2o, + gIdhwO48i4o = dnnl_gIdhwO48i4o, + gIdhwo64i = dnnl_gIdhwo64i, + 
gIdhwO64i2o = dnnl_gIdhwO64i2o, + gIdhwO64i4o = dnnl_gIdhwO64i4o, + gdhwio = dnnl_gdhwio, + gdhwIo2i = dnnl_gdhwIo2i, + gdhwIo4i = dnnl_gdhwIo4i, + ldIo32i = dnnl_ldIo32i, + ldgIo16i = dnnl_ldgIo16i, + ldgIo32i = dnnl_ldgIo32i, + ldgIO32i2o = dnnl_ldgIO32i2o, + nCdhw32c = dnnl_nCdhw32c, + nChw32c = dnnl_nChw32c, + nCw32c = dnnl_nCw32c, + NCw32n16c = dnnl_NCw32n16c, + NChw32n16c = dnnl_NChw32n16c, + NCdhw32n16c = dnnl_NCdhw32n16c, + NCw32n32c = dnnl_NCw32n32c, + OI16i16o4i = dnnl_OI16i16o4i, + IOw8o16i2o = dnnl_IOw8o16i2o, + IOhw8o16i2o = dnnl_IOhw8o16i2o, + Owhi16o = dnnl_Owhi16o, + OIdhw8o16i2o = dnnl_OIdhw8o16i2o, + IOdhw8o16i2o = dnnl_IOdhw8o16i2o, + Goiw4g = dnnl_Goiw4g, + gIOw8o16i2o = dnnl_gIOw8o16i2o, + Goiw32g = dnnl_Goiw32g, + Goihw4g = dnnl_Goihw4g, + gIOhw8o16i2o = dnnl_gIOhw8o16i2o, + Goihw32g = dnnl_Goihw32g, + gOwhi16o = dnnl_gOwhi16o, + IOw4i8o8i4o = dnnl_IOw4i8o8i4o, + IOhw4i8o8i4o = dnnl_IOhw4i8o8i4o, + IOdhw4i8o8i4o = dnnl_IOdhw4i8o8i4o, + gIOw4i8o8i4o = dnnl_gIOw4i8o8i4o, + gIOhw4i8o8i4o = dnnl_gIOhw4i8o8i4o, + gIOdhw4i8o8i4o = dnnl_gIOdhw4i8o8i4o, + gOIdhw8o16i2o = dnnl_gOIdhw8o16i2o, + gIOdhw8o16i2o = dnnl_gIOdhw8o16i2o, + Goidhw32g = dnnl_Goidhw32g, + OI16i32o4i = dnnl_OI16i32o4i, + OI16i48o4i = dnnl_OI16i48o4i, + OI16i64o4i = dnnl_OI16i64o4i, + OI16i16o2i = dnnl_OI16i16o2i, + OI16i32o2i = dnnl_OI16i32o2i, + OI16i48o2i = dnnl_OI16i48o2i, + OI16i64o2i = dnnl_OI16i64o2i, + aBdeC16c16b4c = dnnl_aBdeC16c16b4c, + AcB16b16a2b = dnnl_AcB16b16a2b, + aBdC16c16b2c = dnnl_aBdC16c16b2c, + AcB16b16a4b = dnnl_AcB16b16a4b, + aBdC16c16b4c = dnnl_aBdC16c16b4c, + AcdB16b16a2b = dnnl_AcdB16b16a2b, + aBdefC16c16b4c = dnnl_aBdefC16c16b4c, + AcdeB16b16a4b = dnnl_AcdeB16b16a4b, + AcB16b32a2b = dnnl_AcB16b32a2b, + AcB16b32a4b = dnnl_AcB16b32a4b, + AcB16b48a2b = dnnl_AcB16b48a2b, + AcB16b48a4b = dnnl_AcB16b48a4b, + AcB16b64a2b = dnnl_AcB16b64a2b, + AcB16b64a4b = dnnl_AcB16b64a4b, + aBdC16c32b2c = dnnl_aBdC16c32b2c, + aBdC16c32b4c = dnnl_aBdC16c32b4c, + aBdC16c48b2c = 
dnnl_aBdC16c48b2c, + aBdC16c48b4c = dnnl_aBdC16c48b4c, + aBdC16c64b2c = dnnl_aBdC16c64b2c, + aBdC16c64b4c = dnnl_aBdC16c64b4c, + AcdB16b32a2b = dnnl_AcdB16b32a2b, + AcdB16b32a4b = dnnl_AcdB16b32a4b, + AcdB16b48a2b = dnnl_AcdB16b48a2b, + AcdB16b48a4b = dnnl_AcdB16b48a4b, + AcdB16b64a2b = dnnl_AcdB16b64a2b, + AcdB16b64a4b = dnnl_AcdB16b64a4b, + aBdeC16c32b2c = dnnl_aBdeC16c32b2c, + aBdeC16c32b4c = dnnl_aBdeC16c32b4c, + aBdeC16c48b2c = dnnl_aBdeC16c48b2c, + aBdeC16c48b4c = dnnl_aBdeC16c48b4c, + aBdeC16c64b2c = dnnl_aBdeC16c64b2c, + aBdeC16c64b4c = dnnl_aBdeC16c64b4c, + AcdeB16b32a2b = dnnl_AcdeB16b32a2b, + AcdeB16b32a4b = dnnl_AcdeB16b32a4b, + AcdeB16b48a2b = dnnl_AcdeB16b48a2b, + AcdeB16b48a4b = dnnl_AcdeB16b48a4b, + AcdeB16b64a2b = dnnl_AcdeB16b64a2b, + AcdeB16b64a4b = dnnl_AcdeB16b64a4b, + aBdefC16c32b2c = dnnl_aBdefC16c32b2c, + aBdefC16c32b4c = dnnl_aBdefC16c32b4c, + aBdefC16c48b2c = dnnl_aBdefC16c48b2c, + aBdefC16c48b4c = dnnl_aBdefC16c48b4c, + aBdefC16c64b2c = dnnl_aBdefC16c64b2c, + aBdefC16c64b4c = dnnl_aBdefC16c64b4c, + OwI16i16o2i = dnnl_OwI16i16o2i, + gOwI16i16o2i = dnnl_gOwI16i16o2i, + OhwI16i16o2i = dnnl_OhwI16i16o2i, + gOhwI16i16o2i = dnnl_gOhwI16i16o2i, + OdhwI16i16o2i = dnnl_OdhwI16i16o2i, + gOdhwI16i16o2i = dnnl_gOdhwI16i16o2i, + OwI16i16o4i = dnnl_OwI16i16o4i, + gOwI16i16o4i = dnnl_gOwI16i16o4i, + OhwI16i16o4i = dnnl_OhwI16i16o4i, + gOhwI16i16o4i = dnnl_gOhwI16i16o4i, + OdhwI16i16o4i = dnnl_OdhwI16i16o4i, + gOdhwI16i16o4i = dnnl_gOdhwI16i16o4i, + OwI16i32o2i = dnnl_OwI16i32o2i, + OwI16i32o4i = dnnl_OwI16i32o4i, + OwI16i48o2i = dnnl_OwI16i48o2i, + OwI16i48o4i = dnnl_OwI16i48o4i, + OwI16i64o2i = dnnl_OwI16i64o2i, + OwI16i64o4i = dnnl_OwI16i64o4i, + gOwI16i32o2i = dnnl_gOwI16i32o2i, + gOwI16i32o4i = dnnl_gOwI16i32o4i, + gOwI16i48o2i = dnnl_gOwI16i48o2i, + gOwI16i48o4i = dnnl_gOwI16i48o4i, + gOwI16i64o2i = dnnl_gOwI16i64o2i, + gOwI16i64o4i = dnnl_gOwI16i64o4i, + OhwI16i32o2i = dnnl_OhwI16i32o2i, + OhwI16i32o4i = dnnl_OhwI16i32o4i, + OhwI16i48o2i = 
dnnl_OhwI16i48o2i, + OhwI16i48o4i = dnnl_OhwI16i48o4i, + OhwI16i64o2i = dnnl_OhwI16i64o2i, + OhwI16i64o4i = dnnl_OhwI16i64o4i, + gOhwI16i32o2i = dnnl_gOhwI16i32o2i, + gOhwI16i32o4i = dnnl_gOhwI16i32o4i, + gOhwI16i48o2i = dnnl_gOhwI16i48o2i, + gOhwI16i48o4i = dnnl_gOhwI16i48o4i, + gOhwI16i64o2i = dnnl_gOhwI16i64o2i, + gOhwI16i64o4i = dnnl_gOhwI16i64o4i, + OdhwI16i32o2i = dnnl_OdhwI16i32o2i, + OdhwI16i32o4i = dnnl_OdhwI16i32o4i, + OdhwI16i48o2i = dnnl_OdhwI16i48o2i, + OdhwI16i48o4i = dnnl_OdhwI16i48o4i, + OdhwI16i64o2i = dnnl_OdhwI16i64o2i, + OdhwI16i64o4i = dnnl_OdhwI16i64o4i, + IdhwO16o32i2o = dnnl_IdhwO16o32i2o, + IdhwO16o32i4o = dnnl_IdhwO16o32i4o, + IdhwO16o48i2o = dnnl_IdhwO16o48i2o, + IdhwO16o48i4o = dnnl_IdhwO16o48i4o, + IdhwO16o64i2o = dnnl_IdhwO16o64i2o, + IdhwO16o64i4o = dnnl_IdhwO16o64i4o, + gOdhwI16i32o2i = dnnl_gOdhwI16i32o2i, + gOdhwI16i32o4i = dnnl_gOdhwI16i32o4i, + gOdhwI16i48o2i = dnnl_gOdhwI16i48o2i, + gOdhwI16i48o4i = dnnl_gOdhwI16i48o4i, + gOdhwI16i64o2i = dnnl_gOdhwI16i64o2i, + gOdhwI16i64o4i = dnnl_gOdhwI16i64o4i, + gIdhwO16o32i2o = dnnl_gIdhwO16o32i2o, + gIdhwO16o32i4o = dnnl_gIdhwO16o32i4o, + gIdhwO16o48i2o = dnnl_gIdhwO16o48i2o, + gIdhwO16o48i4o = dnnl_gIdhwO16o48i4o, + gIdhwO16o64i2o = dnnl_gIdhwO16o64i2o, + gIdhwO16o64i4o = dnnl_gIdhwO16o64i4o, + IwO16o16i2o = dnnl_IwO16o16i2o, + IwO16o16i4o = dnnl_IwO16o16i4o, + IhwO16o16i2o = dnnl_IhwO16o16i2o, + IhwO16o16i4o = dnnl_IhwO16o16i4o, + IdhwO16o16i2o = dnnl_IdhwO16o16i2o, + IdhwO16o16i4o = dnnl_IdhwO16o16i4o, + gIwO16o16i2o = dnnl_gIwO16o16i2o, + gIwO16o16i4o = dnnl_gIwO16o16i4o, + gIhwO16o16i2o = dnnl_gIhwO16o16i2o, + gIhwO16o16i4o = dnnl_gIhwO16o16i4o, + gIdhwO16o16i2o = dnnl_gIdhwO16o16i2o, + gIdhwO16o16i4o = dnnl_gIdhwO16o16i4o, + IwO16o32i2o = dnnl_IwO16o32i2o, + IwO16o32i4o = dnnl_IwO16o32i4o, + IwO16o48i2o = dnnl_IwO16o48i2o, + IwO16o48i4o = dnnl_IwO16o48i4o, + IwO16o64i2o = dnnl_IwO16o64i2o, + IwO16o64i4o = dnnl_IwO16o64i4o, + gIwO16o32i2o = dnnl_gIwO16o32i2o, + gIwO16o32i4o = 
dnnl_gIwO16o32i4o, + gIwO16o48i2o = dnnl_gIwO16o48i2o, + gIwO16o48i4o = dnnl_gIwO16o48i4o, + gIwO16o64i2o = dnnl_gIwO16o64i2o, + gIwO16o64i4o = dnnl_gIwO16o64i4o, + IhwO16o32i2o = dnnl_IhwO16o32i2o, + IhwO16o32i4o = dnnl_IhwO16o32i4o, + IhwO16o48i2o = dnnl_IhwO16o48i2o, + IhwO16o48i4o = dnnl_IhwO16o48i4o, + IhwO16o64i2o = dnnl_IhwO16o64i2o, + IhwO16o64i4o = dnnl_IhwO16o64i4o, + gIhwO16o32i2o = dnnl_gIhwO16o32i2o, + gIhwO16o32i4o = dnnl_gIhwO16o32i4o, + gIhwO16o48i2o = dnnl_gIhwO16o48i2o, + gIhwO16o48i4o = dnnl_gIhwO16o48i4o, + gIhwO16o64i2o = dnnl_gIhwO16o64i2o, + gIhwO16o64i4o = dnnl_gIhwO16o64i4o, + aBdeC16c16b2c = dnnl_aBdeC16c16b2c, + aBdefC16c16b2c = dnnl_aBdefC16c16b2c, + AcdB16b16a4b = dnnl_AcdB16b16a4b, + AcdeB16b16a2b = dnnl_AcdeB16b16a2b, + hwioG16g = dnnl_hwioG16g, + hwioG8g = dnnl_hwioG8g, + dhwioG16g = dnnl_dhwioG16g, + dhwioG8g = dnnl_dhwioG8g, + ABc4a2b = dnnl_ABc4a2b, + ABc8a2b = dnnl_ABc8a2b, + ABcd4a2b = dnnl_ABcd4a2b, + ABcde4a2b = dnnl_ABcde4a2b, + ABcde8a2b = dnnl_ABcde8a2b, + ABcd4a8b8a2b = dnnl_ABcd4a8b8a2b, + NCdhw40n32c = dnnl_NCdhw40n32c, + NChw40n32c = dnnl_NChw40n32c, + NCw40n32c = dnnl_NCw40n32c, + OIdhw4o8i8o2i = dnnl_OIdhw4o8i8o2i, + OIhw4o8i8o2i = dnnl_OIhw4o8i8o2i, + OIw4o8i8o2i = dnnl_OIw4o8i8o2i, + gOIdhw4o8i8o2i = dnnl_gOIdhw4o8i8o2i, + gOIhw4o8i8o2i = dnnl_gOIhw4o8i8o2i, + gOIw4o8i8o2i = dnnl_gOIw4o8i8o2i, + IOdhw4i8o8i2o = dnnl_IOdhw4i8o8i2o, + IOhw4i8o8i2o = dnnl_IOhw4i8o8i2o, + IOw4i8o8i2o = dnnl_IOw4i8o8i2o, + gIOdhw4i8o8i2o = dnnl_gIOdhw4i8o8i2o, + gIOhw4i8o8i2o = dnnl_gIOhw4i8o8i2o, + gIOw4i8o8i2o = dnnl_gIOw4i8o8i2o, + aBCd8b2c = dnnl_aBCd8b2c, + ABcde40a16b = dnnl_ABcde40a16b, + ABcde40a32b = dnnl_ABcde40a32b, + aBCde8b2c = dnnl_aBCde8b2c, + ABcde4a8b8a2b = dnnl_ABcde4a8b8a2b, + ABc4a8b8a2b = dnnl_ABc4a8b8a2b, + aBCdef4b8c8b2c = dnnl_aBCdef4b8c8b2c, + aBCde4b8c8b2c = dnnl_aBCde4b8c8b2c, + aBCd4b8c8b2c = dnnl_aBCd4b8c8b2c, + BAcde4b8a8b2a = dnnl_BAcde4b8a8b2a, + BAcd4b8a8b2a = dnnl_BAcd4b8a8b2a, + BAc4b8a8b2a = 
dnnl_BAc4b8a8b2a, + aCBdef4c8b8c2b = dnnl_aCBdef4c8b8c2b, + aCBde4c8b8c2b = dnnl_aCBde4c8b8c2b, + aCBd4c8b8c2b = dnnl_aCBd4c8b8c2b, + aBCdef8b2c = dnnl_aBCdef8b2c, + AB32a16b = dnnl_AB32a16b, + AB32a32b = dnnl_AB32a32b, + BA4b8a8b2a = dnnl_BA4b8a8b2a, + BA4b8a8b4a = dnnl_BA4b8a8b4a, + aBC32b16c = dnnl_aBC32b16c, + aBC32b32c = dnnl_aBC32b32c, + aCB4c8b8c2b = dnnl_aCB4c8b8c2b, + aCB4c8b8c4b = dnnl_aCB4c8b8c4b, + ABc2b8a16b4a = dnnl_ABc2b8a16b4a, + ABcd2b8a16b4a = dnnl_ABcd2b8a16b4a, + ABcde2b8a16b4a = dnnl_ABcde2b8a16b4a, + ABc2a8b16a4b = dnnl_ABc2a8b16a4b, + ABc2a8b16a2b = dnnl_ABc2a8b16a2b, + ABc2b32a8b = dnnl_ABc2b32a8b, + ABcd2a8b16a4b = dnnl_ABcd2a8b16a4b, + ABcd2a8b16a2b = dnnl_ABcd2a8b16a2b, + aCBd2c8b16c2b = dnnl_aCBd2c8b16c2b, + ABcd2b32a8b = dnnl_ABcd2b32a8b, + aBCd2c8b16c2b = dnnl_aBCd2c8b16c2b, + ABcde2a8b16a4b = dnnl_ABcde2a8b16a4b, + ABcde2a8b16a2b = dnnl_ABcde2a8b16a2b, + aCBde2c8b16c2b = dnnl_aCBde2c8b16c2b, + ABcde2b32a8b = dnnl_ABcde2b32a8b, + aBC2b8c16b2c = dnnl_aBC2b8c16b2c, + aBCd2b8c16b2c = dnnl_aBCd2b8c16b2c, + aBCde2b8c16b2c = dnnl_aBCde2b8c16b2c, + aBCdef2b8c16b2c = dnnl_aBCdef2b8c16b2c, + BAcde2b8a16b4a = dnnl_BAcde2b8a16b4a, + BAcd2b8a16b4a = dnnl_BAcd2b8a16b4a, + BAc2b8a16b4a = dnnl_BAc2b8a16b4a, + BAcde2b8a16b2a = dnnl_BAcde2b8a16b2a, + BAcd2b8a16b2a = dnnl_BAcd2b8a16b2a, + BAc2b8a16b2a = dnnl_BAc2b8a16b2a, + aBCde2c8b16c2b = dnnl_aBCde2c8b16c2b, + aBCdef2c8b16c2b = dnnl_aBCdef2c8b16c2b, + aCBdef2c8b16c2b = dnnl_aCBdef2c8b16c2b, + aBCd2b8c16b4c = dnnl_aBCd2b8c16b4c, + aBCde2b8c16b4c = dnnl_aBCde2b8c16b4c, + NCdhw40n16c = dnnl_NCdhw40n16c, + NCw40n16c = dnnl_NCw40n16c, + NChw40n16c = dnnl_NChw40n16c, + NCw2c32n8c = dnnl_NCw2c32n8c, + NChw2c32n8c = dnnl_NChw2c32n8c, + NCdhw2c32n8c = dnnl_NCdhw2c32n8c, + OIw2i8o16i4o = dnnl_OIw2i8o16i4o, + OIhw2i8o16i4o = dnnl_OIhw2i8o16i4o, + OIdhw2i8o16i4o = dnnl_OIdhw2i8o16i4o, + OIw2o8i16o4i = dnnl_OIw2o8i16o4i, + OIw2o8i16o2i = dnnl_OIw2o8i16o2i, + IOw2i8o16i4o = dnnl_IOw2i8o16i4o, + IOw2i8o16i2o = 
dnnl_IOw2i8o16i2o, + OIhw2o8i16o4i = dnnl_OIhw2o8i16o4i, + OIhw2o8i16o2i = dnnl_OIhw2o8i16o2i, + IOhw2i8o16i4o = dnnl_IOhw2i8o16i4o, + IOhw2i8o16i2o = dnnl_IOhw2i8o16i2o, + OIdhw2o8i16o4i = dnnl_OIdhw2o8i16o4i, + OIdhw2o8i16o2i = dnnl_OIdhw2o8i16o2i, + IOdhw2i8o16i4o = dnnl_IOdhw2i8o16i4o, + IOdhw2i8o16i2o = dnnl_IOdhw2i8o16i2o, + gOIw2o8i16o2i = dnnl_gOIw2o8i16o2i, + gIOw2i8o16i2o = dnnl_gIOw2i8o16i2o, + gIOhw2i8o16i2o = dnnl_gIOhw2i8o16i2o, + gIOdhw2i8o16i2o = dnnl_gIOdhw2i8o16i2o, + gOIhw2o8i16o2i = dnnl_gOIhw2o8i16o2i, + gOIdhw2o8i16o2i = dnnl_gOIdhw2o8i16o2i, + gOIw2o8i16o4i = dnnl_gOIw2o8i16o4i, + gOIhw2o8i16o4i = dnnl_gOIhw2o8i16o4i, + BA4b8a16b2a = dnnl_BA4b8a16b2a, + BA4b8a16b4a = dnnl_BA4b8a16b4a, + aCB4c8b16c2b = dnnl_aCB4c8b16c2b, + aCB4c8b16c4b = dnnl_aCB4c8b16c4b, + aCB16c2b = dnnl_aCB16c2b, + aCB16c4b = dnnl_aCB16c4b, + BA16b2a = dnnl_BA16b2a, + BA16b4a = dnnl_BA16b4a, + BA4b4a = dnnl_BA4b4a, + BA8b4a = dnnl_BA8b4a, + aBC16b16c = dnnl_aBC16b16c, + aBC16b32c = dnnl_aBC16b32c, + AB16a16b = dnnl_AB16a16b, + AB16a32b = dnnl_AB16a32b, + ABcde16a16b2a = dnnl_ABcde16a16b2a, + aBCdef16b16c2b = dnnl_aBCdef16b16c2b, + Acedb16a = dnnl_Acedb16a, + aBdfec16b = dnnl_aBdfec16b, + Odwhi16o = dnnl_Odwhi16o, + gOdwhi16o = dnnl_gOdwhi16o, + abdEC64e2c = dnnl_abdEC64e2c, + abdEC64e4c = dnnl_abdEC64e4c, + ldgOI64o2i = abdEC64e2c, + ldgOI64o4i = abdEC64e4c, + abCd4c = dnnl_abCd4c, + abCde4c = dnnl_abCde4c, + abCdef4c = dnnl_abCdef4c, + abCde32c = dnnl_abCde32c, + abCdef32c = dnnl_abCdef32c, + aCdefB16b32c2b = dnnl_aCdefB16b32c2b, + aCdefB16b32c4b = dnnl_aCdefB16b32c4b, + aCdefB16b48c2b = dnnl_aCdefB16b48c2b, + aCdefB16b48c4b = dnnl_aCdefB16b48c4b, + aCdefB16b64c2b = dnnl_aCdefB16b64c2b, + aCdefB16b64c4b = dnnl_aCdefB16b64c4b, + BcdeA16a32b2a = dnnl_BcdeA16a32b2a, + BcdeA16a32b4a = dnnl_BcdeA16a32b4a, + BcdeA16a48b2a = dnnl_BcdeA16a48b2a, + BcdeA16a48b4a = dnnl_BcdeA16a48b4a, + BcdeA16a64b2a = dnnl_BcdeA16a64b2a, + BcdeA16a64b4a = dnnl_BcdeA16a64b4a, + aCdefb32c = 
dnnl_aCdefb32c, + aCdefB32c2b = dnnl_aCdefB32c2b, + aCdefB32c4b = dnnl_aCdefB32c4b, + aCdefb48c = dnnl_aCdefb48c, + aCdefB48c2b = dnnl_aCdefB48c2b, + aCdefB48c4b = dnnl_aCdefB48c4b, + aCdefb64c = dnnl_aCdefb64c, + aCdefB64c2b = dnnl_aCdefB64c2b, + aCdefB64c4b = dnnl_aCdefB64c4b, + Bcdea32b = dnnl_Bcdea32b, + BcdeA32b2a = dnnl_BcdeA32b2a, + BcdeA32b4a = dnnl_BcdeA32b4a, + Bcdea48b = dnnl_Bcdea48b, + BcdeA48b2a = dnnl_BcdeA48b2a, + BcdeA48b4a = dnnl_BcdeA48b4a, + Bcdea64b = dnnl_Bcdea64b, + BcdeA64b2a = dnnl_BcdeA64b2a, + BcdeA64b4a = dnnl_BcdeA64b4a, + Bca32b = dnnl_Bca32b, + BcA32b2a = dnnl_BcA32b2a, + BcA32b4a = dnnl_BcA32b4a, + Bca48b = dnnl_Bca48b, + BcA48b2a = dnnl_BcA48b2a, + BcA48b4a = dnnl_BcA48b4a, + Bca64b = dnnl_Bca64b, + BcA64b2a = dnnl_BcA64b2a, + BcA64b4a = dnnl_BcA64b4a, + aCdb32c = dnnl_aCdb32c, + aCdB32c2b = dnnl_aCdB32c2b, + aCdB32c4b = dnnl_aCdB32c4b, + aCdb48c = dnnl_aCdb48c, + aCdB48c2b = dnnl_aCdB48c2b, + aCdB48c4b = dnnl_aCdB48c4b, + aCdb64c = dnnl_aCdb64c, + aCdB64c2b = dnnl_aCdB64c2b, + aCdB64c4b = dnnl_aCdB64c4b, + BcA16a16b2a = dnnl_BcA16a16b2a, + BcA16a16b4a = dnnl_BcA16a16b4a, + BcdA16a16b2a = dnnl_BcdA16a16b2a, + BcdA16a16b4a = dnnl_BcdA16a16b4a, + BcdeA16a16b2a = dnnl_BcdeA16a16b2a, + BcdeA16a16b4a = dnnl_BcdeA16a16b4a, + aCdB16b16c2b = dnnl_aCdB16b16c2b, + aCdB16b16c4b = dnnl_aCdB16b16c4b, + aCdeB16b16c2b = dnnl_aCdeB16b16c2b, + aCdeB16b16c4b = dnnl_aCdeB16b16c4b, + aCdefB16b16c2b = dnnl_aCdefB16b16c2b, + aCdefB16b16c4b = dnnl_aCdefB16b16c4b, + BcA16a32b2a = dnnl_BcA16a32b2a, + BcA16a32b4a = dnnl_BcA16a32b4a, + BcA16a48b2a = dnnl_BcA16a48b2a, + BcA16a48b4a = dnnl_BcA16a48b4a, + BcA16a64b2a = dnnl_BcA16a64b2a, + BcA16a64b4a = dnnl_BcA16a64b4a, + aCdB16b32c2b = dnnl_aCdB16b32c2b, + aCdB16b32c4b = dnnl_aCdB16b32c4b, + aCdB16b48c2b = dnnl_aCdB16b48c2b, + aCdB16b48c4b = dnnl_aCdB16b48c4b, + aCdB16b64c2b = dnnl_aCdB16b64c2b, + aCdB16b64c4b = dnnl_aCdB16b64c4b, + BcdA16a32b2a = dnnl_BcdA16a32b2a, + BcdA16a32b4a = dnnl_BcdA16a32b4a, + 
BcdA16a48b2a = dnnl_BcdA16a48b2a, + BcdA16a48b4a = dnnl_BcdA16a48b4a, + BcdA16a64b2a = dnnl_BcdA16a64b2a, + BcdA16a64b4a = dnnl_BcdA16a64b4a, + aCdeB16b32c2b = dnnl_aCdeB16b32c2b, + aCdeB16b32c4b = dnnl_aCdeB16b32c4b, + aCdeB16b48c2b = dnnl_aCdeB16b48c2b, + aCdeB16b48c4b = dnnl_aCdeB16b48c4b, + aCdeB16b64c2b = dnnl_aCdeB16b64c2b, + aCdeB16b64c4b = dnnl_aCdeB16b64c4b, + Bca16b = dnnl_Bca16b, + BcA16b2a = dnnl_BcA16b2a, + BcA16b4a = dnnl_BcA16b4a, + Bcda16b = dnnl_Bcda16b, + BcdA16b2a = dnnl_BcdA16b2a, + BcdA16b4a = dnnl_BcdA16b4a, + Bcdea16b = dnnl_Bcdea16b, + BcdeA16b2a = dnnl_BcdeA16b2a, + BcdeA16b4a = dnnl_BcdeA16b4a, + aCdb16c = dnnl_aCdb16c, + aCdB16c2b = dnnl_aCdB16c2b, + aCdB16c4b = dnnl_aCdB16c4b, + aCdeb16c = dnnl_aCdeb16c, + aCdeB16c2b = dnnl_aCdeB16c2b, + aCdeB16c4b = dnnl_aCdeB16c4b, + aCdefb16c = dnnl_aCdefb16c, + aCdefB16c2b = dnnl_aCdefB16c2b, + aCdefB16c4b = dnnl_aCdefB16c4b, + Bcda32b = dnnl_Bcda32b, + BcdA32b2a = dnnl_BcdA32b2a, + BcdA32b4a = dnnl_BcdA32b4a, + Bcda48b = dnnl_Bcda48b, + BcdA48b2a = dnnl_BcdA48b2a, + BcdA48b4a = dnnl_BcdA48b4a, + Bcda64b = dnnl_Bcda64b, + BcdA64b2a = dnnl_BcdA64b2a, + BcdA64b4a = dnnl_BcdA64b4a, + aCdeb32c = dnnl_aCdeb32c, + aCdeB32c2b = dnnl_aCdeB32c2b, + aCdeB32c4b = dnnl_aCdeB32c4b, + aCdeb48c = dnnl_aCdeb48c, + aCdeB48c2b = dnnl_aCdeB48c2b, + aCdeB48c4b = dnnl_aCdeB48c4b, + aCdeb64c = dnnl_aCdeb64c, + aCdeB64c2b = dnnl_aCdeB64c2b, + aCdeB64c4b = dnnl_aCdeB64c4b, + NChw16n32c = dnnl_NChw16n32c, + goIw4i = dnnl_goIw4i, + goIw32i = dnnl_goIw32i, + goIhw4i = dnnl_goIhw4i, + goIhw32i = dnnl_goIhw32i, + goIdhw4i = dnnl_goIdhw4i, + goIdhw32i = dnnl_goIdhw32i, + cab = dnnl_cab, + cdab = dnnl_cdab, + cdeab = dnnl_cdeab, + woi = dnnl_woi, + hwoi = dnnl_hwoi, + dhwoi = dnnl_dhwoi, + Owi24o = dnnl_Owi24o, + Ohwi24o = dnnl_Ohwi24o, + Odhwi24o = dnnl_Odhwi24o, + gOwi24o = dnnl_gOwi24o, + gOhwi24o = dnnl_gOhwi24o, + gOdhwi24o = dnnl_gOdhwi24o, + OwI24o2i = dnnl_OwI24o2i, + OhwI24o2i = dnnl_OhwI24o2i, + OdhwI24o2i = 
dnnl_OdhwI24o2i, + gOwI24o2i = dnnl_gOwI24o2i, + gOhwI24o2i = dnnl_gOhwI24o2i, + gOdhwI24o2i = dnnl_gOdhwI24o2i, + OwI24o4i = dnnl_OwI24o4i, + OhwI24o4i = dnnl_OhwI24o4i, + OdhwI24o4i = dnnl_OdhwI24o4i, + gOwI24o4i = dnnl_gOwI24o4i, + gOhwI24o4i = dnnl_gOhwI24o4i, + gOdhwI24o4i = dnnl_gOdhwI24o4i, + OI8i32o = dnnl_OI8i32o, + OIw8i32o = dnnl_OIw8i32o, + OwI8i32o = dnnl_OwI8i32o, + OIhw8i32o = dnnl_OIhw8i32o, + OhwI8i32o = dnnl_OhwI8i32o, + OIdhw8i32o = dnnl_OIdhw8i32o, + OdhwI8i32o = dnnl_OdhwI8i32o, + OI8i24o = dnnl_OI8i24o, + OIw8i24o = dnnl_OIw8i24o, + OwI8i24o = dnnl_OwI8i24o, + OIhw8i24o = dnnl_OIhw8i24o, + OhwI8i24o = dnnl_OhwI8i24o, + OIdhw8i24o = dnnl_OIdhw8i24o, + OdhwI8i24o = dnnl_OdhwI8i24o, + OI8i16o = dnnl_OI8i16o, + OIw8i16o = dnnl_OIw8i16o, + OwI8i16o = dnnl_OwI8i16o, + OIhw8i16o = dnnl_OIhw8i16o, + OhwI8i16o = dnnl_OhwI8i16o, + OIdhw8i16o = dnnl_OIdhw8i16o, + OdhwI8i16o = dnnl_OdhwI8i16o, + OI8i8o = dnnl_OI8i8o, + AB4b8a4b = dnnl_AB4b8a4b, + AB4b24a4b = dnnl_AB4b24a4b, + ABc4b8a4b = dnnl_ABc4b8a4b, + AcB4b8a4b = dnnl_AcB4b8a4b, + ABc4b24a4b = dnnl_ABc4b24a4b, + AcB4b24a4b = dnnl_AcB4b24a4b, + ABcd4b8a4b = dnnl_ABcd4b8a4b, + AcdB4b8a4b = dnnl_AcdB4b8a4b, + ABcd4b24a4b = dnnl_ABcd4b24a4b, + AcdB4b24a4b = dnnl_AcdB4b24a4b, + ABcde4b8a4b = dnnl_ABcde4b8a4b, + AcdeB4b8a4b = dnnl_AcdeB4b8a4b, + ABcde4b24a4b = dnnl_ABcde4b24a4b, + AcdeB4b24a4b = dnnl_AcdeB4b24a4b, + Bca8b = dnnl_Bca8b, + BcA8b2a = dnnl_BcA8b2a, + Bcda8b = dnnl_Bcda8b, + BcdA8b2a = dnnl_BcdA8b2a, + Bcdea8b = dnnl_Bcdea8b, + BcdeA8b2a = dnnl_BcdeA8b2a, + aCdb8c = dnnl_aCdb8c, + aCdB8c2b = dnnl_aCdB8c2b, + aCdeb8c = dnnl_aCdeb8c, + aCdeB8c2b = dnnl_aCdeB8c2b, + aCdefb8c = dnnl_aCdefb8c, + aCdefB8c2b = dnnl_aCdefB8c2b, + Bca24b = dnnl_Bca24b, + BcA24b2a = dnnl_BcA24b2a, + Bcda24b = dnnl_Bcda24b, + BcdA24b2a = dnnl_BcdA24b2a, + Bcdea24b = dnnl_Bcdea24b, + BcdeA24b2a = dnnl_BcdeA24b2a, + aCdb24c = dnnl_aCdb24c, + aCdB24c2b = dnnl_aCdB24c2b, + aCdeb24c = dnnl_aCdeb24c, + aCdeB24c2b = 
dnnl_aCdeB24c2b, + aCdefb24c = dnnl_aCdefb24c, + aCdefB24c2b = dnnl_aCdefB24c2b, + Iwo8i = dnnl_Iwo8i, + IwO8i2o = dnnl_IwO8i2o, + Iwo24i = dnnl_Iwo24i, + IwO24i2o = dnnl_IwO24i2o, + Ihwo8i = dnnl_Ihwo8i, + IhwO8i2o = dnnl_IhwO8i2o, + Ihwo24i = dnnl_Ihwo24i, + IhwO24i2o = dnnl_IhwO24i2o, + Idhwo8i = dnnl_Idhwo8i, + IdhwO8i2o = dnnl_IdhwO8i2o, + Idhwo24i = dnnl_Idhwo24i, + IdhwO24i2o = dnnl_IdhwO24i2o, + gIwo8i = dnnl_gIwo8i, + gIwO8i2o = dnnl_gIwO8i2o, + gIwo24i = dnnl_gIwo24i, + gIwO24i2o = dnnl_gIwO24i2o, + gIhwo8i = dnnl_gIhwo8i, + gIhwO8i2o = dnnl_gIhwO8i2o, + gIhwo24i = dnnl_gIhwo24i, + gIhwO24i2o = dnnl_gIhwO24i2o, + gIdhwo8i = dnnl_gIdhwo8i, + gIdhwO8i2o = dnnl_gIdhwO8i2o, + gIdhwo24i = dnnl_gIdhwo24i, + gIdhwO24i2o = dnnl_gIdhwO24i2o, + OhwI24o = dnnl_OhwI24o, + gOhwI24o = dnnl_gOhwI24o, + AB8b24a2b = dnnl_AB8b24a2b, + ABc8b24a2b = dnnl_ABc8b24a2b, + AcB8b24a2b = dnnl_AcB8b24a2b, + ABcd8b24a2b = dnnl_ABcd8b24a2b, + AcdB8b24a2b = dnnl_AcdB8b24a2b, + ABcde8b24a2b = dnnl_ABcde8b24a2b, + AcdeB8b24a2b = dnnl_AcdeB8b24a2b, + AB8b8a2b = dnnl_AB8b8a2b, + ABc8b8a2b = dnnl_ABc8b8a2b, + AcB8b8a2b = dnnl_AcB8b8a2b, + ABcd8b8a2b = dnnl_ABcd8b8a2b, + AcdB8b8a2b = dnnl_AcdB8b8a2b, + ABcde8b8a2b = dnnl_ABcde8b8a2b, + AcdeB8b8a2b = dnnl_AcdeB8b8a2b, + OI8i8o2i = dnnl_OI8i8o2i, + OI8i24o2i = dnnl_OI8i24o2i, + OIw8i8o2i = dnnl_OIw8i8o2i, + OwI8i8o2i = dnnl_OwI8i8o2i, + OIw8i24o2i = dnnl_OIw8i24o2i, + OwI8i24o2i = dnnl_OwI8i24o2i, + OIhw8i8o2i = dnnl_OIhw8i8o2i, + OhwI8i8o2i = dnnl_OhwI8i8o2i, + OIhw8i24o2i = dnnl_OIhw8i24o2i, + OhwI8i24o2i = dnnl_OhwI8i24o2i, + OIdhw8i8o2i = dnnl_OIdhw8i8o2i, + OdhwI8i8o2i = dnnl_OdhwI8i8o2i, + OIdhw8i24o2i = dnnl_OIdhw8i24o2i, + OdhwI8i24o2i = dnnl_OdhwI8i24o2i, + BcA8b4a = dnnl_BcA8b4a, + BcdA8b4a = dnnl_BcdA8b4a, + BcdeA8b4a = dnnl_BcdeA8b4a, + aCdB8c4b = dnnl_aCdB8c4b, + aCdeB8c4b = dnnl_aCdeB8c4b, + aCdefB8c4b = dnnl_aCdefB8c4b, + BcA24b4a = dnnl_BcA24b4a, + BcdA24b4a = dnnl_BcdA24b4a, + BcdeA24b4a = dnnl_BcdeA24b4a, + aCdB24c4b = 
dnnl_aCdB24c4b, + aCdeB24c4b = dnnl_aCdeB24c4b, + aCdefB24c4b = dnnl_aCdefB24c4b, + ABc16a4b = dnnl_ABc16a4b, + ABcd16a4b = dnnl_ABcd16a4b, + ABcde16a4b = dnnl_ABcde16a4b, + IwO8i4o = dnnl_IwO8i4o, + IwO24i4o = dnnl_IwO24i4o, + IhwO8i4o = dnnl_IhwO8i4o, + IhwO24i4o = dnnl_IhwO24i4o, + IdhwO8i4o = dnnl_IdhwO8i4o, + IdhwO24i4o = dnnl_IdhwO24i4o, + gIwO8i4o = dnnl_gIwO8i4o, + gIwO24i4o = dnnl_gIwO24i4o, + gIhwO8i4o = dnnl_gIhwO8i4o, + gIhwO24i4o = dnnl_gIhwO24i4o, + gIdhwO8i4o = dnnl_gIdhwO8i4o, + gIdhwO24i4o = dnnl_gIdhwO24i4o, + BA2a24b = dnnl_BA2a24b, + aCB2b24c = dnnl_aCB2b24c, + BA2a8b = dnnl_BA2a8b, + aCB2b8c = dnnl_aCB2b8c, + BA8a24b = dnnl_BA8a24b, + aCB8b24c = dnnl_aCB8b24c, + BA8a16b = dnnl_BA8a16b, + aCB8b16c = dnnl_aCB8b16c, + BA8a8b = dnnl_BA8a8b, + aCB8b8c = dnnl_aCB8b8c, + bcad = dnnl_bcad, + cabd = dnnl_cabd, + dabc = dnnl_dabc, + }; + + /// A memory descriptor. + struct desc : public handle { + using handle::handle; + + friend struct memory; + + /// Constructs a zero (empty) memory descriptor. Such a memory + /// descriptor can be used to indicate absence of an argument. + desc() { + dnnl_memory_desc_t zero_md = nullptr; + error::wrap_c_api( + dnnl_memory_desc_create_with_tag(&zero_md, 0, nullptr, + dnnl_data_type_undef, dnnl_format_tag_undef), + "could not create a zero memory descriptor"); + reset(zero_md); + } + + /// Constructs a memory descriptor. + /// + /// @note + /// The logical order of dimensions corresponds to the `abc...` + /// format tag, and the physical meaning of the dimensions depends + /// both on the primitive that would operate on this memory and + /// the operation context. + /// + /// @param adims Tensor dimensions. + /// @param adata_type Data precision/type. + /// @param aformat_tag Memory format tag. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case a + /// zero memory descriptor will be constructed. 
This flag is + /// optional and defaults to false. + desc(const dims &adims, data_type adata_type, format_tag aformat_tag, + bool allow_empty = false) { + validate_dims(adims); + dnnl_memory_desc_t md = nullptr; + dnnl_status_t status = dnnl_memory_desc_create_with_tag(&md, + (int)adims.size(), adims.data(), convert_to_c(adata_type), + convert_to_c(aformat_tag)); + if (!allow_empty) + error::wrap_c_api(status, + "could not construct a memory descriptor using a " + "format tag"); + reset(md); + } + + /// Constructs a memory descriptor by strides. + /// + /// @note + /// The logical order of dimensions corresponds to the `abc...` + /// format tag, and the physical meaning of the dimensions depends + /// both on the primitive that would operate on this memory and + /// the operation context. + /// + /// @param adims Tensor dimensions. + /// @param adata_type Data precision/type. + /// @param strides Strides for each dimension. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case a + /// zero memory descriptor will be constructed. This flag is + /// optional and defaults to false. + desc(const dims &adims, data_type adata_type, const dims &strides, + bool allow_empty = false) { + validate_dims(adims); + if (!strides.empty()) validate_dims(strides, (int)adims.size()); + dnnl_memory_desc_t md = nullptr; + dnnl_status_t status = dnnl_memory_desc_create_with_strides(&md, + (int)adims.size(), adims.data(), convert_to_c(adata_type), + strides.empty() ? nullptr : &strides[0]); + if (!allow_empty) + error::wrap_c_api(status, + "could not construct a memory descriptor using " + "strides"); + reset(md); + } +#ifdef DNNL_EXPERIMENTAL_SPARSE + /// Function for creating a memory descriptor for CSR sparse encoding. + /// + /// The created memory descriptor will describe a memory object that + /// contains 3 buffers. 
/// The buffers have the following meaning and
        /// assigned numbers (index):
        ///  - 0: values
        ///  - 1: indices
        ///  - 2: pointers
        ///
        /// @param adims Tensor dimensions.
        /// @param adata_type Data precision/type.
        /// @param nnz Number of non-zero entries.
        /// @param index_dt Data type of indices.
        /// @param pointer_dt Data type of pointers.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be constructed. This flag is
        ///     optional and defaults to false.
        static desc csr(const dims &adims, data_type adata_type, dim nnz,
                data_type index_dt, data_type pointer_dt,
                bool allow_empty = false) {
            validate_dims(adims);
            dnnl_memory_desc_t md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_create_with_csr_encoding(
                    &md, (int)adims.size(), adims.data(),
                    convert_to_c(adata_type), nnz, convert_to_c(index_dt),
                    convert_to_c(pointer_dt));
            // When allow_empty is set, a failed status is swallowed and the
            // returned desc wraps a null handle (zero descriptor).
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a memory descriptor for CSR sparse "
                        "encoding");
            return desc {md};
        }

        /// Function for creating a memory descriptor for COO sparse encodings.
        ///
        /// The created memory descriptor will describe a memory object that
        /// contains n+1 buffers for an n-dimensional tensor.
        /// The buffers have the following meaning and assigned numbers (index):
        ///  - 0: values
        ///  - 1: indices for dimension 0
        ///  - 2: indices for dimension 1 ...
        ///  - n: indices for dimension n-1
        ///
        /// @param adims Tensor dimensions.
        /// @param adata_type Data precision/type.
        /// @param nnz Number of non-zero entries.
        /// @param index_dt Data type of indices.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be constructed. This flag is
        ///     optional and defaults to false.
        static desc coo(const dims &adims, data_type adata_type, dim nnz,
                data_type index_dt, bool allow_empty = false) {
            validate_dims(adims);
            dnnl_memory_desc_t md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_create_with_coo_encoding(
                    &md, (int)adims.size(), adims.data(),
                    convert_to_c(adata_type), nnz, convert_to_c(index_dt));
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a memory descriptor for COO sparse "
                        "encoding");
            return desc {md};
        }

        /// Function for creating a memory descriptor for packed sparse
        /// encoding.
        ///
        /// The created memory descriptor cannot be used to create a memory
        /// object. It can only be used to create a primitive descriptor to
        /// query the actual memory descriptor (similar to the format tag
        /// `any`).
        ///
        /// @warning
        ///     The meaning and content of the handles of the memory object that
        ///     is created using the queried memory descriptor are unspecified
        ///     therefore using the content is an undefined behavior.
        ///
        /// @param adims Tensor dimensions.
        /// @param adata_type Data precision/type.
        /// @param nnz Number of non-zero entries.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be constructed. This flag is
        ///     optional and defaults to false.
        static desc packed(const dims &adims, data_type adata_type, dim nnz,
                bool allow_empty = false) {
            validate_dims(adims);
            dnnl_memory_desc_t md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_create_with_packed_encoding(
                    &md, (int)adims.size(), adims.data(),
                    convert_to_c(adata_type), nnz);
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a memory descriptor for packed "
                        "sparse encoding");
            return desc {md};
        }
#endif
        /// Construct a memory descriptor from a C API ::dnnl_memory_desc_t
        /// handle.
The resulting handle is not weak and the C handle will be + /// destroyed during the destruction of the C++ object. + /// + /// @param md The C API memory descriptor. + desc(dnnl_memory_desc_t md) : handle(md) {} + + /// Construct a memory descriptor from a binary blob. + /// + /// @param blob A binary blob previously queried from a memory descriptor. + desc(const std::vector &blob) { + dnnl_memory_desc_t md = nullptr; + error::wrap_c_api( + dnnl_memory_desc_create_with_blob(&md, blob.data()), + "could not create a memory descriptor from blob"); + reset(md); + } + + /// Constructs a memory descriptor for a region inside an area + /// described by this memory descriptor. + // + /// @param adims Sizes of the region. + /// @param offsets Offsets to the region from the encompassing + /// memory object in each dimension. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case a + /// zero memory descriptor will be returned. This flag is optional + /// and defaults to false. + /// @returns A memory descriptor for the region. + desc submemory_desc(const dims &adims, const dims &offsets, + bool allow_empty = false) const { + validate_dims(adims, get_ndims()); + validate_dims(offsets, get_ndims()); + dnnl_memory_desc_t sub_md = nullptr; + dnnl_status_t status = dnnl_memory_desc_create_submemory( + &sub_md, get(), adims.data(), offsets.data()); + if (!allow_empty) + error::wrap_c_api(status, "could not construct a sub-memory"); + return desc(sub_md); + } + + /// Constructs a memory descriptor by reshaping an existing one. The + /// new memory descriptor inherits the data type. This operation is + /// valid only for memory descriptors that have format_kind set to + /// #dnnl::memory::format_kind::blocked or + /// #dnnl::memory::format_kind::any. + /// + /// The operation ensures that the transformation of the physical memory + /// format corresponds to the transformation of the logical dimensions. 
        /// If such transformation is impossible, the function either throws an
        /// exception (default) or returns a zero memory descriptor depending on
        /// the `allow_empty` flag.
        ///
        /// The reshape operation can be described as a combination of the
        /// following basic operations:
        /// 1. Add a dimension of size `1`. This is always possible.
        /// 2. Remove a dimension of size `1`. This is possible only if the
        ///    dimension has no padding (i.e.
        ///    `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
        /// 3. Split a dimension into multiple ones. This is possible only if
        ///    the product of all tensor dimensions stays constant and the
        ///    dimension being split does not have padding (i.e.
        ///    `padded_dims[dim] = dims[dim]`).
        /// 4. Join multiple consecutive dimensions into a single one. As in
        ///    the cases above, this requires that the dimensions do not have
        ///    padding and that the memory format is such that in physical
        ///    memory these dimensions are dense and have the same order as
        ///    their logical counterparts. This also assumes that these
        ///    dimensions are not blocked.
        ///    - Here, 'dense' means:
        ///      `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
        ///    - And 'same order' means:
        ///      `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
        ///
        /// @warning
        ///     Some combinations of physical memory layout and/or offsets or
        ///     dimensions may result in a failure to make a reshape.
        ///
        /// @param adims New dimensions. The product of dimensions must
        ///     remain constant.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be returned. This flag is optional
        ///     and defaults to false.
        /// @returns A new memory descriptor with new dimensions.
        desc reshape(const dims &adims, bool allow_empty = false) const {
            // A zero (empty) descriptor has nothing to validate against;
            // otherwise the new shape must contain at least one dimension.
            if (get_ndims()) validate_dims(adims, 1);
            dnnl_memory_desc_t out_md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_reshape(
                    &out_md, get(), (int)adims.size(), adims.data());
            if (!allow_empty)
                error::wrap_c_api(
                        status, "could not reshape a memory descriptor");
            return desc(out_md);
        }

        /// Constructs a memory descriptor by permuting axes in an existing
        /// one.
        ///
        /// The physical memory layout representation is adjusted accordingly
        /// to maintain the consistency between the logical and physical parts
        /// of the memory descriptor. The new memory descriptor inherits the
        /// data type.
        ///
        /// The new memory descriptor inherits the data type. This operation is
        /// valid only for memory descriptors that have format_kind set to
        /// #dnnl::memory::format_kind::blocked or
        /// #dnnl::memory::format_kind::any.
        ///
        /// The logical axes will be permuted in the following manner:
        /// @code
        /// for (i = 0; i < get_ndims(); i++)
        ///     new_desc.dims()[permutation[i]] = dims()[i];
        /// @endcode
        ///
        /// Example:
        /// @code
        ///     std::vector<int> permutation = {1, 0}; // swap the first and
        ///                                            // the second axes
        ///     dnnl::memory::desc in_md(
        ///             {2, 3}, data_type, memory::format_tag::ab);
        ///     dnnl::memory::desc expect_out_md(
        ///             {3, 2}, data_type, memory::format_tag::ba);
        ///
        ///     assert(in_md.permute_axes(permutation) == expect_out_md);
        /// @endcode
        ///
        /// @param permutation Axes permutation.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be returned. This flag is optional
        ///     and defaults to false.
        /// @returns A new memory descriptor with new dimensions.
+ desc permute_axes(const std::vector &permutation, + bool allow_empty = false) const { + validate_dims(permutation, get_ndims()); + dnnl_memory_desc_t out_md = nullptr; + dnnl_status_t status = dnnl_memory_desc_permute_axes( + &out_md, get(), permutation.data()); + if (!allow_empty) + error::wrap_c_api(status, + "could not permute axes of a memory descriptor"); + return desc(out_md); + } + + /// Returns a number of dimensions of the memory descriptor. + /// + /// @returns A number of dimensions. + int get_ndims() const { return query_s32(query::ndims_s32); } + + /// Returns padded dimensions of the memory descriptor. + /// + /// @returns A copy of the padded dimensions vector. + memory::dims get_padded_dims() const { + return query_dims(query::padded_dims); + } + + /// Returns padded offsets of the memory descriptor. + /// + /// @returns A copy of the padded offsets vector. + memory::dims get_padded_offsets() const { + return query_dims(query::padded_offsets); + } + + /// Returns a submemory offset of the memory descriptor. + /// + /// @returns A submemory offset. + memory::dim get_submemory_offset() const { + dnnl_dim_t submemory_offset; + dnnl_status_t status = dnnl_memory_desc_query( + get(), dnnl_query_submemory_offset_s64, &submemory_offset); + return status == dnnl_success ? submemory_offset : 0; + } + + /// Returns strides of the memory descriptor. + /// + /// @note + /// This API is only applicable to memory descriptors with format + /// kind #dnnl_blocked. + /// + /// @returns A copy of the strides vector. + /// @returns An empty #dnnl::memory::dims if the memory descriptor + /// does not have strides. + memory::dims get_strides() const { return query_dims(query::strides); } + + /// Returns a number of inner blocks of the memory descriptor. + /// + /// @note + /// This API is only applicable to memory descriptors with format + /// kind #dnnl_blocked. + /// + /// @returns A number of inner blocks. 
+ int get_inner_nblks() const { + return query_s32(query::inner_nblks_s32); + } + + /// Returns inner blocks of the memory descriptor. + /// + /// @note + /// This API is only applicable to memory descriptors with format + /// kind #dnnl_blocked. + /// + /// @returns A copy of the inner blocks vector. + /// @returns An empty #dnnl::memory::dims if the memory descriptor + /// does not have inner blocks. + memory::dims get_inner_blks() const { + return query_dims(query::inner_blks); + } + + /// Returns inner indices of the memory descriptor. + /// + /// @note + /// This API is only applicable to memory descriptors with format + /// kind #dnnl_blocked. + /// + /// @returns A copy of the inner indices vector. + /// @returns An empty #dnnl::memory::dims if the memory descriptor + /// does not have inner indices. + memory::dims get_inner_idxs() const { + return query_dims(query::inner_idxs); + } + +#ifdef DNNL_EXPERIMENTAL_SPARSE + /// Returns number of handles. + /// + /// @returns A number of handles. + int get_num_handles() const { + int nhandles; + dnnl_status_t status = dnnl_memory_desc_query_v2( + get(), dnnl_query_num_handles_s32, 0, &nhandles); + return status == dnnl_success ? nhandles : 0; + } + + /// Returns a number of non-zero entries of the memory descriptor. + /// + /// @returns A number non-zero entries. + dim get_nnz() const { + dnnl_dim_t nnz; + dnnl_status_t status = dnnl_memory_desc_query_v2( + get(), dnnl_query_nnz_s64, 0, &nnz); + return status == dnnl_success ? nnz : 0; + } + + /// Returns the sparse encoding of the memory descriptor. + /// + /// @returns the sparse encoding kind. + memory::sparse_encoding get_sparse_encoding() const { + dnnl_sparse_encoding_t sparse_encoding; + dnnl_status_t status = dnnl_memory_desc_query_v2( + get(), dnnl_query_sparse_encoding, 0, &sparse_encoding); + return status == dnnl_success + ? 
static_cast( + sparse_encoding) + : dnnl::memory::sparse_encoding::undef; + } + + /// Returns the data type of the memory descriptor. + /// + /// @returns The data type. + memory::data_type get_data_type(int index = 0) const { + return query_data_type(query::data_type, index); + } +#else + /// Returns the data type of the memory descriptor. + /// + /// @returns The data type. + memory::data_type get_data_type() const { + return query_data_type(query::data_type); + } +#endif + + /// Returns the format kind of the memory descriptor. + /// + /// @returns the format kind. + memory::format_kind get_format_kind() const { + dnnl_format_kind_t format_kind; + dnnl_status_t status = dnnl_memory_desc_query( + get(), dnnl_query_format_kind, &format_kind); + return status == dnnl_success + ? static_cast(format_kind) + : dnnl::memory::format_kind::undef; + } + + /// Returns dimensions of the memory descriptor. + /// + /// Potentially expensive due to the data copy involved. + /// @returns A copy of the dimensions vector. + memory::dims get_dims() const { return query_dims(query::dims); } + +#ifdef DNNL_EXPERIMENTAL_SPARSE + /// Returns size of the memory descriptor in bytes. + /// @param index Data index. Defaults to 0. + /// @returns The number of bytes required to allocate a memory buffer + /// for data with a particular @p index described by this memory + /// descriptor including the padding area. + size_t get_size(int index = 0) const { + return dnnl_memory_desc_get_size_v2(get(), index); + } +#else + /// Returns size of the memory descriptor in bytes. + /// @returns The number of bytes required to allocate a memory buffer + /// for the memory object described by this memory descriptor + /// including the padding area. 
+ size_t get_size() const { return dnnl_memory_desc_get_size(get()); } +#endif + + /// Returns a binary blob associated with the given memory descriptor + /// @returns The memory descriptor blob associated with the memory descriptor + std::vector get_blob() { + size_t size; + dnnl_status_t status + = dnnl_memory_desc_get_blob(nullptr, &size, get()); + error::wrap_c_api( + status, "could not get memory descriptor blob size"); + + std::vector out_blob(size); + status = dnnl_memory_desc_get_blob(out_blob.data(), &size, get()); + error::wrap_c_api(status, "could not get memory descriptor blob"); + return out_blob; + } + + /// Checks whether the memory descriptor is zero (empty). + /// @returns @c true if the memory descriptor describes an empty + /// memory and @c false otherwise. + bool is_zero() const { return get_ndims() == 0; } + + /// An equality operator. + /// @param other Another memory descriptor. + /// @returns Whether this and the other memory descriptors have + /// the same format tag, dimensions, strides, blocking, etc. + bool operator==(const desc &other) const { + return dnnl_memory_desc_equal(get(), other.get()) != 0; + } + + /// An inequality operator. + /// @param other Another memory descriptor. + /// @returns Whether this and the other memory descriptors describe + /// different memory. + bool operator!=(const desc &other) const { return !operator==(other); } + + private: +#ifdef DNNL_EXPERIMENTAL_SPARSE + memory::data_type query_data_type(query what, int index) const { + dnnl_data_type_t data_type; + dnnl_status_t status = dnnl_memory_desc_query_v2( + get(), dnnl::convert_to_c(what), index, &data_type); + return status == dnnl_success + ? static_cast(data_type) + : dnnl::memory::data_type::undef; + } +#else + memory::data_type query_data_type(query what) const { + dnnl_data_type_t data_type; + dnnl_status_t status = dnnl_memory_desc_query( + get(), dnnl::convert_to_c(what), &data_type); + return status == dnnl_success + ? 
static_cast(data_type) + : dnnl::memory::data_type::undef; + } +#endif + + int query_s32(query what) const { + int res; + dnnl_status_t status = dnnl_memory_desc_query( + get(), dnnl::convert_to_c(what), &res); + return status == dnnl_success ? res : 0; + } + + memory::dims query_dims(query what) const { + dnnl_dims_t *c_dims; + dnnl_status_t status = dnnl_memory_desc_query( + get(), dnnl::convert_to_c(what), &c_dims); + + const int ndims + = (what == query::inner_idxs || what == query::inner_blks) + ? get_inner_nblks() + : get_ndims(); + + return status == dnnl_success + ? memory::dims(*c_dims, *c_dims + ndims) + : memory::dims {}; + } + }; + + /// Default constructor. + /// + /// Constructs an empty memory object, which can be used to indicate + /// absence of a parameter. + memory() = default; + +#ifdef DNNL_EXPERIMENTAL_SPARSE + /// Constructs a memory object. + /// + /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory + /// object will have the underlying buffer set. In this case, the buffer + /// will be initialized as if #dnnl::memory::set_data_handle() had been + /// called. + /// + /// @sa memory::set_data_handle() + /// + /// @param md Memory descriptor. + /// @param aengine Engine to store the data on. + /// @param handle Handle of the memory buffer to use. + /// - A pointer to the user-allocated buffer. In this case the library + /// doesn't own the buffer. + /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to + /// allocate the buffer for the memory object. In this case the + /// library owns the buffer. + /// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying + /// buffer. + memory(const desc &md, const engine &aengine, void *handle) + : memory(md, aengine, std::vector {handle}) {} + + /// Constructs a memory object with multiple handles. + /// + /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory + /// object will have the underlying buffer set. 
In this case, the buffer + /// will be initialized as if #dnnl::memory::set_data_handle() had been + /// called. + /// + /// @sa memory::set_data_handle() + /// + /// @param md Memory descriptor. + /// @param aengine Engine to store the data on. + /// @param handles Handles of the memory buffers to use. + /// For each element of the @p handles vector the following applies: + /// - A pointer to the user-allocated buffer. In this case the library + /// doesn't own the buffer. + /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to + /// allocate the buffer for the memory object. In this case the + /// library owns the buffer. + /// - #DNNL_MEMORY_NONE Instructs the library to skip allocation of the + /// memory buffer. + memory(const desc &md, const engine &aengine, std::vector handles) { + dnnl_memory_t result; + dnnl_status_t status = dnnl_memory_create_v2(&result, md.get(), + aengine.get(), (int)handles.size(), handles.data()); + error::wrap_c_api(status, "could not create a memory object"); + reset(result); + } + + /// Constructs a memory object. + /// + /// The underlying buffer(s) for the memory will be allocated by the + /// library. + /// @param md Memory descriptor. + /// @param aengine Engine to store the data on. + memory(const desc &md, const engine &aengine) { + dnnl_status_t status; + dnnl_memory_t result; + const int nhandles = md.get_num_handles(); + + std::vector handles(nhandles, DNNL_MEMORY_ALLOCATE); + status = dnnl_memory_create_v2(&result, md.get(), aengine.get(), + (int)handles.size(), handles.data()); + + error::wrap_c_api(status, "could not create a memory object"); + reset(result); + } +#else + /// Constructs a memory object. + /// + /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory + /// object will have the underlying buffer set. In this case, the buffer + /// will be initialized as if #dnnl::memory::set_data_handle() had been + /// called. 
    ///
    /// @sa memory::set_data_handle()
    ///
    /// @param md Memory descriptor.
    /// @param aengine Engine to store the data on.
    /// @param handle Handle of the memory buffer to use.
    ///     - A pointer to the user-allocated buffer. In this case the library
    ///       doesn't own the buffer.
    ///     - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
    ///       allocate the buffer for the memory object. In this case the
    ///       library owns the buffer.
    ///     - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying
    ///       buffer.
    memory(const desc &md, const engine &aengine, void *handle) {
        dnnl_memory_t result;
        error::wrap_c_api(
                dnnl_memory_create(&result, md.get(), aengine.get(), handle),
                "could not create a memory object");
        reset(result);
    }

    /// Constructs a memory object.
    ///
    /// The underlying buffer for the memory will be allocated by the library.
    ///
    /// @param md Memory descriptor.
    /// @param aengine Engine to store the data on.
    memory(const desc &md, const engine &aengine)
        : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
#endif

    /// Returns the associated memory descriptor.
    desc get_desc() const {
        const_dnnl_memory_desc_t cdesc;
        error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
                "could not get a memory descriptor from a memory object");
        // The queried C descriptor is owned by the memory object; clone it so
        // the returned C++ desc owns an independent handle.
        dnnl_memory_desc_t cloned_md = nullptr;
        error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
                "could not clone a memory descriptor");
        return desc(cloned_md);
    }

    /// Returns the associated engine.
    engine get_engine() const {
        dnnl_engine_t c_engine;
        error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
                "could not get an engine from a memory object");
        // NOTE(review): the second argument presumably marks the wrapper as a
        // weak (non-owning) handle since the engine is owned by the memory
        // object — confirm against the handle constructor.
        return engine(c_engine, true);
    }

#ifdef DNNL_EXPERIMENTAL_SPARSE
    /// Returns an underlying memory buffer that corresponds to the given index.
    ///
    /// On the CPU engine, or when using USM, this is a pointer to the
    /// allocated memory.
    /// @param index Index of the buffer. Defaults to 0.
    /// @returns The underlying memory buffer.
    void *get_data_handle(int index = 0) const {
        void *handle;
        error::wrap_c_api(dnnl_memory_get_data_handle_v2(get(), &handle, index),
                "could not get a native handle from a memory object");
        return handle;
    }

    /// Sets an underlying memory buffer that corresponds to the given index.
    ///
    /// @param handle Memory buffer to use. On the CPU engine or when USM is
    ///     used, the memory buffer is a pointer to the actual data. For OpenCL
    ///     it is a cl_mem. It must have at least
    ///     #dnnl::memory::desc::get_size() bytes allocated.
    /// @param index Memory index to attach the buffer. Defaults to 0.
    void set_data_handle(void *handle, int index = 0) const {
        error::wrap_c_api(dnnl_memory_set_data_handle_v2(get(), handle, index),
                "could not set native handle of a memory object");
    }

    /// Maps a memory object and returns a host-side pointer to a memory
    /// buffer with a copy of its contents. The memory buffer corresponds to
    /// the given index.
    ///
    /// Mapping enables read/write directly from/to the memory contents for
    /// engines that do not support direct memory access.
    ///
    /// Mapping is an exclusive operation - a memory object cannot be used in
    /// other operations until it is unmapped via #dnnl::memory::unmap_data()
    /// call.
    ///
    /// @note
    ///     Any primitives working with the memory should be completed before
    ///     the memory is mapped. Use #dnnl::stream::wait() to synchronize the
    ///     corresponding execution stream.
    ///
    /// @note
    ///     The map_data and unmap_data functions are provided mainly for
    ///     debug and testing purposes and their performance may be suboptimal.
    ///
    /// @tparam T Data type to return a pointer to.
    /// @param index Index of the buffer. Defaults to 0.
    /// @returns Pointer to the mapped memory.
+ template + T *map_data(int index = 0) const { + void *mapped_ptr; + error::wrap_c_api(dnnl_memory_map_data_v2(get(), &mapped_ptr, index), + "could not map memory object data"); + return static_cast(mapped_ptr); + } + + /// Unmaps a memory object and writes back any changes made to the + /// previously mapped memory buffer. The memory buffer corresponds to + /// the given index. + /// + /// @note + /// The map_data and unmap_data functions are provided mainly for + /// debug and testing purposes and their performance may be + /// suboptimal. + /// + /// @param mapped_ptr A pointer previously returned by + /// #dnnl::memory::map_data(). + /// @param index Index of the buffer. Defaults to 0. + void unmap_data(void *mapped_ptr, int index = 0) const { + error::wrap_c_api(dnnl_memory_unmap_data_v2(get(), mapped_ptr, index), + "could not unmap memory object data"); + } +#else + /// Returns the underlying memory buffer. + /// + /// On the CPU engine, or when using USM, this is a pointer to the + /// allocated memory. + void *get_data_handle() const { + void *handle; + error::wrap_c_api(dnnl_memory_get_data_handle(get(), &handle), + "could not get a native handle from a memory object"); + return handle; + } + + /// Sets the underlying memory buffer. + /// + /// @param handle Memory buffer to use. On the CPU engine or when USM is + /// used, the memory buffer is a pointer to the actual data. For OpenCL + /// it is a cl_mem. It must have at least + /// #dnnl::memory::desc::get_size() bytes allocated. + void set_data_handle(void *handle) const { + error::wrap_c_api(dnnl_memory_set_data_handle(get(), handle), + "could not set native handle of a memory object"); + } + + /// Maps a memory object and returns a host-side pointer to a memory + /// buffer with a copy of its contents. + /// + /// Mapping enables read/write directly from/to the memory contents for + /// engines that do not support direct memory access. 
+ /// + /// Mapping is an exclusive operation - a memory object cannot be used in + /// other operations until it is unmapped via #dnnl::memory::unmap_data() + /// call. + /// + /// @note + /// Any primitives working with the memory should be completed before + /// the memory is mapped. Use #dnnl::stream::wait() to synchronize the + /// corresponding execution stream. + /// + /// @note + /// The map_data and unmap_data functions are provided mainly for + /// debug and testing purposes and their performance may be suboptimal. + /// + /// @tparam T Data type to return a pointer to. + /// @returns Pointer to the mapped memory. + template + T *map_data() const { + void *mapped_ptr; + error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr), + "could not map memory object data"); + return static_cast(mapped_ptr); + } + + /// Unmaps a memory object and writes back any changes made to the + /// previously mapped memory buffer. + /// + /// @note + /// The map_data and unmap_data functions are provided mainly for + /// debug and testing purposes and their performance may be + /// suboptimal. + /// + /// @param mapped_ptr A pointer previously returned by + /// #dnnl::memory::map_data(). 
+ void unmap_data(void *mapped_ptr) const { + error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr), + "could not unmap memory object data"); + } +#endif + + static dnnl_data_type_t convert_to_c(data_type adata_type) { + return static_cast(adata_type); + } + static dnnl_format_tag_t convert_to_c(format_tag format) { + return static_cast(format); + } +}; + +inline bool operator==(dnnl_data_type_t a, memory::data_type b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(dnnl_data_type_t a, memory::data_type b) { + return !(a == b); +} +inline bool operator==(memory::data_type a, dnnl_data_type_t b) { + return b == a; +} +inline bool operator!=(memory::data_type a, dnnl_data_type_t b) { + return !(a == b); +} + +inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) { + return !(a == b); +} +inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) { + return b == a; +} +inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) { + return !(a == b); +} + +/// @} dnnl_api_memory + +/// @addtogroup dnnl_api_primitives +/// @{ +/// @addtogroup dnnl_api_attributes Attributes +/// +/// A container for parameters that extend primitives behavior. +/// +/// @{ + +/// @cond DO_NOT_DOCUMENT_THIS +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_post_ops_t p) { + return dnnl_post_ops_destroy(p); + } +}; +/// @endcond + +/// Post-ops. +/// +/// Post-ops are computations executed after the main primitive computations +/// and are attached to the primitive via primitive attributes. +/// +/// @sa @ref dev_guide_attributes_post_ops +/// +struct post_ops : public handle { + using handle::handle; + + /// Constructs an empty sequence of post-ops. 
+ post_ops() { + dnnl_post_ops_t result; + error::wrap_c_api( + dnnl_post_ops_create(&result), "could not create post-ops"); + reset(result); + } + + /// Creates post-ops primitive attribute from a C API ::dnnl_post_ops_t + /// handle. The resulting handle is not weak and the C handle will be + /// destroyed during the destruction of the C++ object. + /// + /// @param post_ops The C API post-ops primitive attribute. + post_ops(dnnl_post_ops_t post_ops) : handle(post_ops) {} + + /// Returns the number of post-ops entries. + int len() const { return dnnl_post_ops_len(get()); } + + /// Returns the primitive kind of post-op at entry with a certain index. + /// @param index Index of the post-op to return the kind for. + /// @returns Primitive kind of the post-op at the specified index. + primitive::kind kind(int index) const { + error::wrap_c_api(index < len() ? dnnl_success : dnnl_invalid_arguments, + "post-ops index is out of range"); + return static_cast( + dnnl_post_ops_get_kind(get(), index)); + } + + /// Appends an accumulation (sum) post-op. Prior to accumulating the + /// result, the previous value will be will be reduced by zero point + /// @p zero_point and multiplied by a scaling factor @p scale. + /// + /// The kind of this post-op is #dnnl::primitive::kind::sum. + /// + /// This feature may improve performance for cases like dequantize the + /// asymmetrically quantized sum's src1 tensor to f32 domain before + /// performing the sum operation by subtracting @p zero_point before the + /// scaling. + /// + /// In the simplest case when the accumulation is the only post-op, + /// the computations will be `dst[:] := scale * (dst[:] - zero_point) + + /// op(...)` instead of `dst[:] := op(...)`. + /// + /// If @p data_type is specified, the original dst tensor will be + /// reinterpreted as a tensor with the provided data type. Because it is a + /// reinterpretation, data_type and dst data type should have the same size. 
+ /// As a result, computations will be `dst[:] <- scale * + /// (as_data_type(dst[:]) - zero_point) + op(...)` instead of + /// `dst[:] <- op(...)`. + /// + /// @note + /// This post-op executes in-place and does not change the + /// destination layout. + /// + /// @param scale Scaling factor. + /// @param zero_point Zero point. + /// @param data_type Data type. + void append_sum(float scale = 1.f, int32_t zero_point = 0, + memory::data_type data_type = memory::data_type::undef) { + error::wrap_c_api(dnnl_post_ops_append_sum(get(), scale, zero_point, + memory::convert_to_c(data_type)), + "could not append a sum post-op"); + } + + /// Returns the parameters of an accumulation (sum) post-op. + /// + /// @param index Index of the sum post-op. + /// @param scale Scaling factor of the sum post-op. + void get_params_sum(int index, float &scale) const { + error::wrap_c_api(dnnl_post_ops_get_params_sum( + get(), index, &scale, nullptr, nullptr), + "could not get parameters of a sum post-op"); + } + + /// Returns the parameters of an accumulation (sum) post-op. + /// + /// @param index Index of the sum post-op. + /// @param scale Scaling factor of the sum post-op. + /// @param data_type Data type of the sum post-op. + void get_params_sum( + int index, float &scale, memory::data_type &data_type) const { + dnnl_data_type_t c_data_type; + error::wrap_c_api(dnnl_post_ops_get_params_sum( + get(), index, &scale, nullptr, &c_data_type), + "could not get parameters of a sum post-op"); + data_type = static_cast(c_data_type); + } + + /// Returns the parameters of an accumulation (sum) post-op. + /// + /// @param index Index of the sum post-op. + /// @param scale Scaling factor of the sum post-op. + /// @param zero_point Single scalar int32_t value of zeropoint. + /// @param data_type Data type of the sum post-op. 
+ void get_params_sum(int index, float &scale, int32_t &zero_point, + memory::data_type &data_type) const { + dnnl_data_type_t c_data_type; + error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale, + &zero_point, &c_data_type), + "could not get parameters of a sum post-op"); + data_type = static_cast(c_data_type); + } + + /// Appends an elementwise post-op. + /// + /// The kind of this post-op is #dnnl::primitive::kind::eltwise. + /// + /// In the simplest case when the elementwise is the only post-op, the + /// computations would be `dst[:] := eltwise_op (op(...))` instead + /// of `dst[:] <- op(...)`, where eltwise_op is configured with the given + /// parameters. + /// + /// @param aalgorithm Elementwise algorithm. + /// @param alpha Alpha parameter for the elementwise algorithm. + /// @param beta Beta parameter for the elementwise algorithm. + void append_eltwise(algorithm aalgorithm, float alpha, float beta) { + error::wrap_c_api(dnnl_post_ops_append_eltwise( + get(), convert_to_c(aalgorithm), alpha, beta), + "could not append an elementwise post-op"); + } + + /// Returns parameters of an elementwise post-op. + /// + /// @param index Index of the post-op. + /// @param aalgorithm Output elementwise algorithm kind. + /// @param alpha Output alpha parameter for the elementwise algorithm. + /// @param beta Output beta parameter for the elementwise algorithm. + void get_params_eltwise( + int index, algorithm &aalgorithm, float &alpha, float &beta) const { + dnnl_alg_kind_t c_alg; + error::wrap_c_api(dnnl_post_ops_get_params_eltwise( + get(), index, &c_alg, &alpha, &beta), + "could not get parameters of an elementwise post-op"); + aalgorithm = static_cast(c_alg); + } + + /// Appends a depthwise post-op convolution. + /// + /// This post-op can only be fused with a 2D 1x1 convolution (convolution + /// with weights spatial dimension equal to 1 i.e., kh=kw=1). + /// + /// The kind of this post-op is #dnnl_convolution. 
    ///
    /// The number of outputs for primitive remain same as before. The output
    /// spatial size can be derived as below:
    ///
    /// output_height = ceil(output_height_1x1_convolution, stride)
    /// output_width = ceil(output_width_1x1_convolution, stride)
    ///
    /// See @ref dev_guide_attributes_post_ops_depthwise and
    /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
    ///
    /// @param weights_data_type Weights data type of depthwise post-op
    /// @param bias_data_type Bias data type of depthwise post-op
    /// @param dst_data_type Output data type of depthwise post-op
    /// @param kernel_size Size of kernel of depthwise post-op
    /// @param stride_size Size of stride of depthwise post-op
    /// @param padding_l_size Size of left and top paddings of depthwise post-op
    void append_dw(memory::data_type weights_data_type,
            memory::data_type bias_data_type, memory::data_type dst_data_type,
            memory::dim kernel_size, memory::dim stride_size,
            memory::dim padding_l_size) {

        error::wrap_c_api(dnnl_post_ops_append_dw(get(),
                                  memory::convert_to_c(weights_data_type),
                                  memory::convert_to_c(bias_data_type),
                                  memory::convert_to_c(dst_data_type),
                                  kernel_size, stride_size, padding_l_size),
                "could not append depthwise post-op");
    }

    /// Returns the parameters of an depthwise post-op.
    ///
    /// @param index Index of the depthwise post-op.
+ /// @param weights_data_type Weights data type of depthwise post-op + /// @param bias_data_type Bias data type of depthwise post-op + /// @param dst_data_type Output data type of depthwise post-op + /// @param kernel_size Size of kernel of depthwise post-op + /// @param stride_size Size of stride of depthwise post-op + /// @param padding_l_size Size of left and top paddings of depthwise post-op + void get_params_dw(int index, memory::data_type &weights_data_type, + memory::data_type &bias_data_type, memory::data_type &dst_data_type, + memory::dim &kernel_size, memory::dim &stride_size, + memory::dim &padding_l_size) const { + + dnnl_data_type_t c_weights_data_type; + dnnl_data_type_t c_bias_data_type; + dnnl_data_type_t c_dst_data_type; + dnnl_dim_t c_kernel_size; + dnnl_dim_t c_stride_size; + dnnl_dim_t c_padding_l_size; + error::wrap_c_api( + dnnl_post_ops_get_params_dw(get(), index, &c_weights_data_type, + &c_bias_data_type, &c_dst_data_type, &c_kernel_size, + &c_stride_size, &c_padding_l_size), + "could not get parameters of depthwise post-op"); + + weights_data_type = static_cast(c_weights_data_type); + bias_data_type = static_cast(c_bias_data_type); + dst_data_type = static_cast(c_dst_data_type); + kernel_size = c_kernel_size; + stride_size = c_stride_size; + padding_l_size = c_padding_l_size; + } + + /// Appends a binary post-op. + /// + /// The kind of this post operation is #dnnl_binary. + /// + /// In the simplest case when the binary is the only post operation, the + /// computations would be: + /// + /// dst[:] <- binary_op (dst[:], another_input[:]) + /// + /// where binary_op is configured with the given parameters. binary_op + /// supports broadcast semantics for a second operand. + /// + /// @param aalgorithm Binary algorithm for the post-op. + /// @param src1_desc Memory descriptor of a second operand. 
+ void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) { + error::wrap_c_api(dnnl_post_ops_append_binary(get(), + convert_to_c(aalgorithm), src1_desc.get()), + "could not append a binary post-op"); + } + + /// Returns the parameters of a binary post-op. + /// + /// @param index Index of the binary post-op. + /// @param aalgorithm Output binary algorithm kind. + /// @param src1_desc Output memory descriptor of a second operand. + void get_params_binary( + int index, algorithm &aalgorithm, memory::desc &src1_desc) const { + dnnl_alg_kind_t c_alg; + const_dnnl_memory_desc_t cdesc; + error::wrap_c_api( + dnnl_post_ops_get_params_binary(get(), index, &c_alg, &cdesc), + "could not get parameters of a binary post-op"); + aalgorithm = static_cast(c_alg); + dnnl_memory_desc_t cloned_md = nullptr; + error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc), + "could not clone a memory descriptor"); + src1_desc = memory::desc(cloned_md); + } + + /// Appends a prelu forward post-op. + /// + /// The kind of this post-op is #dnnl::primitive::kind::prelu. + /// + /// The post-op can be defined as: + /// + /// dst[:] <- prelu(dst[:], weights[:]) + /// prelu: + /// dst[:] <- dst[:] if dst[:] > 0 + /// dst[:] <- dst[:] * weights[:] if dst[:] <= 0 + /// + /// + /// Example usage: + /// @code + /// int mb = 32, oc = 32, + /// oh = 14, ow = 14; // convolution output params + /// // unique weights per output channel + /// vector weights = { ... }; + /// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ... 
+ /// + /// // construct a convolution descriptor + /// dnnl::convolution::desc conv_d; + /// + /// dnnl::primitive_attr attr; + /// attr.append_prelu(1 << oc_dim); + /// + /// dnnl::primitive_desc conv_pd(conv_d, attr, engine); + /// memory prelu_weights({{1}, dt::f32, {1}}, eng, weights.data()); + /// + /// std::unordered_map conv_args; + /// + /// conv_args.insert( + /// {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_WEIGHTS, prelu_weights}) + /// @endcode + /// + /// @note + /// The order of dimensions does not depend on how elements are laid + /// out in memory. For example: + /// - for a 2D CNN activations tensor the order is always (n, c) + /// - for a 4D CNN activations tensor the order is always (n, c, h, w) + /// - for a 5D CNN weights tensor the order is always + /// (g, oc, ic, kh, kw) + /// + /// Prelu weights tensor is passed in runtime execution phase. Prelu + /// weights tensor data type is implicitly assumed as f32 using plain + /// layout (a, ab, acb, acdb, acdeb). + /// + /// @param mask Defines the correspondence between the output tensor + /// dimensions and the prelu weights tensor. The set i-th bit indicates + /// that a dedicated weights value is used for each index along that + /// dimension. Set the mask to 0 to use a common weights value + /// for the whole output tensor. + void append_prelu(int mask) { + error::wrap_c_api(dnnl_post_ops_append_prelu(get(), mask), + "could not append a prelu post-op"); + } + + /// Returns the parameters of a prelu post-op. + /// + /// @param index Index of the prelu post-op. + /// @param mask Weights mask of prelu post-op. 
+ void get_params_prelu(int index, int &mask) const { + error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask), + "could not get parameters of a binary post-op"); + } +}; + +/// @cond DO_NOT_DOCUMENT_THIS +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_primitive_attr_t p) { + return dnnl_primitive_attr_destroy(p); + } +}; +/// @endcond + +/// Primitive attributes. +/// +/// @sa @ref dev_guide_attributes +struct primitive_attr : public handle { + using handle::handle; + + /// Constructs default (empty) primitive attributes. + primitive_attr() { + dnnl_primitive_attr_t result; + error::wrap_c_api(dnnl_primitive_attr_create(&result), + "could not create primitive attribute"); + reset(result); + } + + /// Creates primitive attributes from a C API ::dnnl_primitive_attr_t + /// handle. The resulting handle is not weak and the C handle will be + /// destroyed during the destruction of the C++ object. + /// + /// @param attr The C API primitive attributes. + primitive_attr(dnnl_primitive_attr_t attr) + : handle(attr) {} + + /// Returns the parameters of a dropout attribute. + /// + /// @param mask_desc Output memory descriptor of a dropout mask. + void get_dropout(memory::desc &mask_desc) const { + const_dnnl_memory_desc_t cdesc; + error::wrap_c_api(dnnl_primitive_attr_get_dropout(get(), &cdesc), + "could not get parameters of a dropout attribute"); + dnnl_memory_desc_t cloned_md = nullptr; + error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc), + "could not clone a memory descriptor"); + mask_desc = memory::desc(cloned_md); + } + + /// Sets dropout probability. + /// + /// @param mask_desc Output memory descriptor of a dropout mask. 
+ void set_dropout(const memory::desc &mask_desc) { + error::wrap_c_api( + dnnl_primitive_attr_set_dropout(get(), mask_desc.get()), + "could not set dropout primitive attribute"); + } + + /// Returns the fpmath mode + fpmath_mode get_fpmath_mode() const { + dnnl_fpmath_mode_t result; + error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode(get(), &result), + "could not get fpmath mode primitive attribute"); + return fpmath_mode(result); + } + + /// Returns the fpmath mode + /// + /// @param mode Specified fpmath mode. + /// @param apply_to_int Use floating-point arithmetic for integer primitives. + void get_fpmath_mode(fpmath_mode &mode, bool &apply_to_int) const { + dnnl_fpmath_mode_t c_mode; + int c_apply_to_int; + error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode_v2( + get(), &c_mode, &c_apply_to_int), + "could not get fpmath mode primitive attribute"); + mode = fpmath_mode(c_mode); + apply_to_int = static_cast(c_apply_to_int); + } + + /// Sets fpmath mode. + /// + /// @param mode Specified fpmath mode. + /// @param apply_to_int Boolean. Use of floating-point arithmetic for integer primitives. + void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false) { + error::wrap_c_api(dnnl_primitive_attr_set_fpmath_mode_v2(get(), + dnnl::convert_to_c(mode), apply_to_int), + "could not set fpmath mode primitive attribute"); + } + + /// Returns the accumulation mode + accumulation_mode get_accumulation_mode() const { + dnnl_accumulation_mode_t result; + error::wrap_c_api( + dnnl_primitive_attr_get_accumulation_mode(get(), &result), + "could not get accumulation mode primitive attribute"); + return accumulation_mode(result); + } + + /// Sets accumulation mode. + /// + /// @param mode Specified accumulation mode. 
+ void set_accumulation_mode(accumulation_mode mode) { + error::wrap_c_api(dnnl_primitive_attr_set_accumulation_mode( + get(), dnnl::convert_to_c(mode)), + "could not set accumulation mode primitive attribute"); + } + + /// Returns the deterministic attribute value + bool get_deterministic() const { + int result; + error::wrap_c_api(dnnl_primitive_attr_get_deterministic(get(), &result), + "could not get deterministic primitive attribute"); + return static_cast(result); + } + + /// Sets deterministic attribute value + /// + /// @param value Specified deterministic mode. + void set_deterministic(bool value) { + error::wrap_c_api(dnnl_primitive_attr_set_deterministic( + get(), static_cast(value)), + "could not set deterministic primitive attribute"); + } + + /// Returns the rounding mode attribute value + /// + /// @param arg Argument for which rounding mode query applies. + /// @returns The rounding mode applied to the specified argument. + rounding_mode get_rounding_mode(int arg) const { + dnnl_rounding_mode_t result; + error::wrap_c_api(dnnl_primitive_attr_get_rounding(get(), arg, &result), + "could not get rounding mode primitive attribute"); + return rounding_mode(result); + } + + /// Sets the rounding mode attribute value for a given argument + /// + /// @param arg Argument for which to set rounding mode. + /// @param mode Rounding mode to apply. + void set_rounding_mode(int arg, rounding_mode mode) { + error::wrap_c_api(dnnl_primitive_attr_set_rounding( + get(), arg, convert_to_c(mode)), + "could not set rounding mode primitive attribute"); + } + + /// Returns the scratchpad mode. + scratchpad_mode get_scratchpad_mode() const { + dnnl_scratchpad_mode_t result; + error::wrap_c_api( + dnnl_primitive_attr_get_scratchpad_mode(get(), &result), + "could not get scratchpad mode primitive attribute"); + return scratchpad_mode(result); + } + + /// Sets scratchpad mode. + /// + /// @param mode Specified scratchpad mode. 
+ void set_scratchpad_mode(scratchpad_mode mode) { + error::wrap_c_api(dnnl_primitive_attr_set_scratchpad_mode( + get(), dnnl::convert_to_c(mode)), + "could not set scratchpad mode primitive attribute"); + } + + /// Sets scaling factors for primitive operations for a given memory + /// argument. The scaling factors must be passed at execution time + /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg. + /// + /// @sa dnnl_primitive_attr_set_scales_mask + /// + /// @param arg Parameter argument index as passed to the + /// primitive::execute() call. + /// @param mask Scaling factors correspondence mask that defines the + /// correspondence between the tensor dimensions and the @p scales + /// vector. The set i-th bit indicates that a dedicated scaling factor + /// is used for each index along that dimension. Set the mask to 0 to + /// use a common scaling factor for the whole output tensor. + void set_scales_mask(int arg, int mask) { + error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask), + "could not set scales primitive attribute"); + } + + /// Sets scaling factors for primitive operations for a given memory + /// argument. The scaling factors must be passed at execution time + /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg. + /// + /// @sa dnnl_primitive_attr_set_scales + /// + /// @param arg Parameter argument index as passed to the + /// primitive::execute() call. + /// @param mask Scales correspondence mask that defines the + /// correspondence between the tensor dimensions and the @p + /// scales vector. The set i-th bit indicates that a dedicated + /// scale is used for each index along that dimension. Set the + /// mask to 0 to use a common scale for the whole output tensor. + /// @param groups Scaling factors correspondence groups that define the + /// correspondence between the tensor dimensions and the scales array. 
+ /// The set i-th dimension indicates a number of groups of scaling + /// factors used for that logical dimension in a memory indicated by @p arg. + /// @param data_type Scaling factors data_type. + void set_scales(int arg, int mask, const memory::dims &groups, + memory::data_type data_type = memory::data_type::f32) { + error::wrap_c_api(dnnl_primitive_attr_set_scales(get(), arg, mask, + (int)groups.size(), groups.data(), + memory::convert_to_c(data_type)), + "could not set scales primitive attribute"); + } + + /// Sets zero points for primitive operations for a given memory argument. + /// The zero points must be passed at execution time as an argument with + /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg. + /// + /// @sa dnnl_primitive_attr_set_zero_points_mask + /// + /// @param arg Parameter argument index as passed to the + /// primitive::execute() call. + /// @param mask Zero point correspondence mask that defines the + /// correspondence between the tensor dimensions and the @p + /// zero_points vector. The set i-th bit indicates that a dedicated + /// zero point is used for each index along that dimension. Set the + /// mask to 0 to use a common zero point for the whole output tensor. + void set_zero_points_mask(int arg, int mask) { + error::wrap_c_api( + dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask), + "could not set zero points primitive attribute"); + } + + /// Sets zero points for primitive operations for a given memory argument. + /// The zero points must be passed at execution time as an argument with + /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg. + /// + /// @sa dnnl_primitive_attr_set_zero_points + /// + /// @param arg Parameter argument index as passed to the + /// primitive::execute() call. + /// @param mask Zero point correspondence mask that defines the + /// correspondence between the tensor dimensions and the @p + /// zero_points vector. 
The set i-th bit indicates that a dedicated + /// zero point is used for each index along that dimension. Set the + /// mask to 0 to use a common zero point for the whole output tensor. + /// @param groups Zero point factors correspondence groups that define the + /// correspondence between the tensor dimensions and the zero_points array. + /// The set i-th dimension indicates a number of groups of zero point + /// factors used for that logical dimension in a memory indicated by @p arg. + /// @param data_type Zero point factors data_type. + void set_zero_points(int arg, int mask, const memory::dims &groups, + memory::data_type data_type = memory::data_type::s32) { + error::wrap_c_api(dnnl_primitive_attr_set_zero_points(get(), arg, mask, + (int)groups.size(), groups.data(), + memory::convert_to_c(data_type)), + "could not set zero points primitive attribute"); + } + + /// Returns post-ops previously set via set_post_ops(). + /// + /// @returns Post-ops. + const post_ops get_post_ops() const { + const_dnnl_post_ops_t const_c_post_ops; + error::wrap_c_api( + dnnl_primitive_attr_get_post_ops(get(), &const_c_post_ops), + "could not get post-ops primitive attribute"); + dnnl_post_ops_t c_post_ops; + error::wrap_c_api(dnnl_post_ops_clone(&c_post_ops, const_c_post_ops), + "could not clone post-ops primitive attribute"); + return post_ops(c_post_ops); + } + + /// Sets post-ops. + /// + /// @note + /// There is no way to check whether the post-ops would be supported + /// by the target primitive. Any error will be reported + /// by the respective primitive descriptor constructor. + /// + /// @param ops Post-ops object to copy post-ops from. + void set_post_ops(const post_ops ops) { + error::wrap_c_api(dnnl_primitive_attr_set_post_ops(get(), ops.get()), + "could not set post-ops primitive attribute"); + } + + /// Sets quantization scale and shift parameters for RNN data tensors. 
+ /// + /// For performance reasons, the low-precision configuration of the RNN + /// primitives expect input activations to have the unsigned 8-bit integer + /// data type. The scale and shift parameters are used to quantize + /// floating-point data to unsigned integer and must be passed to the RNN + /// primitive using attributes. + /// + /// The quantization formula is `scale * data + shift`. + /// + /// Example usage: + /// @code + /// // RNN parameters + /// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; + /// // Activations quantization parameters + /// float scale = 63.f, shift = 64.f; + /// + /// primitive_attr attr; + /// + /// // Set scale and shift for int8 quantization of activation + /// attr.set_rnn_data_qparams(scale, shift); + /// + /// // Create an RNN primitive descriptor. + /// vanilla_rnn_forward::primitive_desc rnn_d( + /// engine, /* arguments */, attr); + /// @endcode + /// + /// @note + /// Quantization scale and shift are common for src_layer, src_iter, + /// dst_iter, and dst_layer. + /// + /// @param scale The value to scale the data by. + /// @param shift The value to shift the data by. + void set_rnn_data_qparams(float scale, float shift) { + error::wrap_c_api( + dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift), + "could not set RNN data quantization parameters primitive " + "attribute"); + } + + /// Returns the quantization scale and shift parameters for RNN data + /// tensors. + /// + /// @note + /// Quantization scale and shift are common for src_layer, src_iter, + /// dst_iter, and dst_layer. + /// + /// @param scale The value to scale the data by. + /// @param shift The value to shift the data by. 
+ void get_rnn_data_qparams(float &scale, float &shift) { + float c_scale, c_shift; + error::wrap_c_api(dnnl_primitive_attr_get_rnn_data_qparams( + get(), &c_scale, &c_shift), + "could not set RNN data quantization parameters primitive " + "attribute"); + scale = c_scale; + shift = c_shift; + } + + /// Sets quantization scaling factors for RNN weights tensors. The + /// low-precision configuration of the RNN primitives expect input weights + /// to use the signed 8-bit integer data type. The scaling factors are + /// used to quantize floating-point data to signed integer and must be + /// passed to RNN primitives using attributes. + /// + /// @note + /// The dimension order is always native and does not depend on the + /// actual layout used. For example, five-dimensional weights always + /// have (l, d, i, g, o) logical dimension ordering. + /// + /// @note + /// Quantization scales are common for weights_layer and + /// weights_iteration + /// + /// @param mask Scaling factors correspondence mask that defines the + /// correspondence between the output tensor dimensions and the @p + /// scales vector. The set i-th bit indicates that a dedicated scaling + /// factor should be used each index along that dimension. Set the + /// mask to 0 to use a common scaling factor for the whole output + /// tensor. + /// @param scales Constant vector of output scaling factors. The following + /// equality must hold: + /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$ + /// Violations can only be detected when the attributes are used to + /// create a primitive descriptor. + void set_rnn_weights_qparams(int mask, const std::vector &scales) { + error::wrap_c_api(dnnl_primitive_attr_set_rnn_weights_qparams(get(), + (int)scales.size(), mask, scales.data()), + "could not set RNN weights quantization parameters primitive " + "attribute"); + } + + /// Returns the quantization scaling factors for RNN projection weights + /// tensors. 
+ /// + /// @note + /// The dimension order is always native and does not depend on the + /// actual layout used. For example, five-dimensional weights always + /// have (l, d, i, g, o) logical dimension ordering. + /// + /// @param mask Scaling factors correspondence mask that defines the + /// correspondence between the output tensor dimensions and the @p + /// scales vector. The set i-th bit indicates that a dedicated scaling + /// factor should be used each index along that dimension. Set the + /// mask to 0 to use a common scaling factor for the whole output + /// tensor. + /// @param scales Constant vector of output scaling factors. The following + /// equality must hold: + /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$ + /// Violations can only be detected when the attributes are used to + /// create a primitive descriptor. + void get_rnn_weights_qparams(int &mask, std::vector &scales) { + dnnl_dim_t count; + int c_mask; + const float *c_scales; + error::wrap_c_api(dnnl_primitive_attr_get_rnn_weights_qparams( + get(), &count, &c_mask, &c_scales), + "could not get primitive RNN weights quantization " + "parameters attributes"); + scales.resize(count); + + mask = c_mask; + for (dnnl_dim_t c = 0; c < count; c++) + scales[c] = c_scales[c]; + } + + /// Sets quantization scaling factors for RNN projection weights tensors. + // The low-precision configuration of the RNN primitives expect input + // weights to use the signed 8-bit integer data type. The scaling factors + // are used to quantize floating-point data to signed integer and must be + /// passed to RNN primitives using attributes. + /// + /// @note + /// The dimension order is always native and does not depend on the + /// actual layout used. For example, five-dimensional weights always + /// have (l, d, i, g, o) logical dimension ordering. 
+ /// + /// @note + /// Quantization scales are common for weights_layer and + /// weights_iteration + /// + /// @param mask Scaling factors correspondence mask that defines the + /// correspondence between the output tensor dimensions and the @p + /// scales vector. The set i-th bit indicates that a dedicated scaling + /// factor should be used each index along that dimension. Set the + /// mask to 0 to use a common scaling factor for the whole output + /// tensor. + /// @param scales Constant vector of output scaling factors. The following + /// equality must hold: + /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$ + /// Violations can only be detected when the attributes are used to + /// create a primitive descriptor. + void set_rnn_weights_projection_qparams( + int mask, const std::vector &scales) { + error::wrap_c_api( + dnnl_primitive_attr_set_rnn_weights_projection_qparams( + get(), (int)scales.size(), mask, scales.data()), + "could not set primitive RNN weights projection quantization " + "parameters attributes"); + } + + /// Returns the quantization scaling factors for RNN projection weights + /// tensors. + /// + /// @note + /// The dimension order is always native and does not depend on the + /// actual layout used. For example, five-dimensional weights always + /// have (l, d, i, g, o) logical dimension ordering. + /// + /// @param mask Scaling factors correspondence mask that defines the + /// correspondence between the output tensor dimensions and the @p + /// scales vector. The set i-th bit indicates that a dedicated scaling + /// factor should be used each index along that dimension. Set the + /// mask to 0 to use a common scaling factor for the whole output + /// tensor. + /// @param scales Constant vector of output scaling factors. 
The following + /// equality must hold: + /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$ + /// Violations can only be detected when the attributes are used to + /// create a primitive descriptor. + void get_rnn_weights_projection_qparams( + int &mask, std::vector &scales) { + dnnl_dim_t count; + int c_mask; + const float *c_scales; + error::wrap_c_api( + dnnl_primitive_attr_get_rnn_weights_projection_qparams( + get(), &count, &c_mask, &c_scales), + "could not get primitive RNN weights projection quantization " + "parameters attributes"); + scales.resize(count); + + mask = c_mask; + for (dnnl_dim_t c = 0; c < count; c++) + scales[c] = c_scales[c]; + } +}; + +/// @} dnnl_api_attributes + +/// @addtogroup dnnl_api_primitives_common +/// @{ + +/// Base class for all primitive descriptors. +struct primitive_desc_base : public handle { + using handle::handle; + + /// Default constructor. Produces an empty object. + primitive_desc_base() = default; + + /// Returns the engine of the primitive descriptor. + /// @returns The engine of the primitive descriptor. + engine get_engine() const { return query_engine(query::engine); } + + /// Returns implementation name. + /// @returns The implementation name. + const char *impl_info_str() const { + const char *res; + error::wrap_c_api(dnnl_primitive_desc_query( + get(), dnnl_query_impl_info_str, 0, &res), + "could not retrieve implementation info string from a " + "primitive descriptor"); + return res; + } + + /// Returns a memory::dim value (same as int64_t). + /// @param what The value to query. + /// @returns The result of the query. + memory::dim query_s64(query what) const { + memory::dim res; + dnnl_status_t status = dnnl_primitive_desc_query( + get(), dnnl::convert_to_c(what), 0, &res); + return status == dnnl_success ? res : 0; + } + + /// Returns strides. + /// @returns Strides. + /// @returns An empty #dnnl::memory::dims if the primitive does not have + /// a strides parameter. 
+ memory::dims get_strides() const { return query_dims(query::strides); } + + /// Returns dilations. + /// @returns Dilations. + /// @returns An empty #dnnl::memory::dims if the primitive does not have + /// a dilations parameter. + memory::dims get_dilations() const { return query_dims(query::dilations); } + + /// Returns a left padding. + /// @returns A left padding. + /// @returns An empty #dnnl::memory::dims if the primitive does not have + /// a left padding parameter. + memory::dims get_padding_l() const { return query_dims(query::padding_l); } + + /// Returns a right padding. + /// @returns A right padding. + /// @returns An empty #dnnl::memory::dims if the primitive does not have + /// a right padding parameter. + memory::dims get_padding_r() const { return query_dims(query::padding_r); } + + /// Returns an epsilon. + /// @returns An epsilon. + /// @returns Zero if the primitive does not have an epsilon parameter. + float get_epsilon() const { return query_f32(query::epsilon_f32); } + + /// Returns flags. + /// @tparam T Flags enumeration type. + /// @returns Flags. + /// @returns Zero if the primitive does not have a flags parameter. + template + T get_flags() const { + unsigned res; + dnnl_status_t status + = dnnl_primitive_desc_query(get(), dnnl_query_flags, 0, &res); + return static_cast(status == dnnl_success ? res : 0x0U); + } + + /// Returns an algorithm kind. + /// @returns An algorithm kind. + /// @returns #dnnl::algorithm::undef if the primitive does not have an + /// algorithm parameter. + dnnl::algorithm get_algorithm() const { return query_alg(query::alg_kind); } + + /// Returns an alpha. + /// @returns An alpha. + /// @returns Zero if the primitive does not have an alpha parameter. + float get_alpha() const { return query_f32(query::alpha_f32); } + + /// Returns a beta. + /// @returns A beta. + /// @returns Zero if the primitive does not have a beta parameter. 
+ float get_beta() const { return query_f32(query::beta_f32); } + + /// Returns an axis. + /// @returns An axis. + /// @returns A negative number if the primitive does not have an axis + /// parameter. + int get_axis() const { + int res; + dnnl_status_t status = dnnl_primitive_desc_query( + get(), dnnl_query_axis_s32, 0, &res); + return status == dnnl_success ? res : -1; + } + + /// Returns an LRN local size parameter. + /// @returns An LRN local size parameter. + /// @returns Zero if the primitive does not have an LRN local size + /// parameter. + memory::dim get_local_size() const { + return query_s64(query::local_size_s64); + } + + /// Returns an LRN K parameter. + /// @returns An LRN K parameter. + /// @returns Zero if the primitive does not have an LRN K parameter. + float get_k() const { return query_f32(query::k_f32); } + + /// Returns a reduction P parameter. + /// @returns A reduction P parameter. + /// @returns Zero if the primitive does not have a reduction P parameter. + float get_p() const { return query_f32(query::p_f32); } + + /// Returns a resampling factors parameters. + /// @returns A vector of factors. + /// @returns An empty vector if the primitive does not have a resampling + /// factors parameter. + std::vector get_factors() const { + float *factors; + dnnl_status_t status = dnnl_primitive_desc_query( + get(), dnnl_query_factors, 0, &factors); + + const bool is_backward = get_prop_kind() != prop_kind::forward_training + && get_prop_kind() != prop_kind::forward_inference; + const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(), + is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0); + + int ndims; + error::wrap_c_api( + dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims), + "could not query ndims from a memory descriptor"); + + return status == dnnl_success + ? std::vector(factors, factors + (ndims - 2)) + : std::vector {}; + } + + /// Returns an RNN cell kind parameter. + /// @returns An RNN cell kind parameter. 
+ /// @returns #dnnl::algorithm::undef if the primitive does not have an + /// RNN cell kind parameter. + dnnl::algorithm get_cell_kind() const { + return query_alg(query::cell_kind); + } + + /// Returns an RNN direction parameter. + /// @returns An RNN direction parameter. + /// @returns #dnnl::rnn_direction::undef if the primitive does not have + /// an RNN direction parameter. + dnnl::rnn_direction get_direction() const { + dnnl_rnn_direction_t direction; + dnnl_status_t status = dnnl_primitive_desc_query( + get(), dnnl_query_direction, 0, &direction); + return status == dnnl_success + ? static_cast(direction) + : dnnl::rnn_direction::undef; + } + + /// Returns an RNN activation kind parameter. + /// @returns An RNN activation kind parameter. + /// @returns #dnnl::algorithm::undef if the primitive does not have an + /// RNN activation kind parameter. + dnnl::algorithm get_activation_kind() const { + return query_alg(query::activation_kind); + } + + /// Returns a pooling kernel parameter. + /// @returns A pooling kernel parameter. + /// @returns An empty #dnnl::memory::dims if the primitive does not have + /// a pooling kernel parameter. + memory::dims get_kernel() const { return query_dims(query::kernel); } + + /// Returns a group size parameter. + /// @returns A group size parameter. + /// @returns Zero if the primitive does not have a group size + /// parameter. + memory::dim get_group_size() const { + return query_s64(query::group_size_s64); + } + + /// Returns a propagation kind. + /// @returns A propagation kind. + /// @returns #dnnl::prop_kind::undef if the primitive does not have + /// a propagation parameter. + dnnl::prop_kind get_prop_kind() const { + dnnl_prop_kind_t prop_kind; + dnnl_status_t status = dnnl_primitive_desc_query( + get(), dnnl_query_prop_kind, 0, &prop_kind); + return status == dnnl_success ? static_cast(prop_kind) + : dnnl::prop_kind::undef; + } + + /// Returns a memory descriptor. 
+ /// + /// @note + /// There are also convenience methods + /// #dnnl::primitive_desc_base::src_desc(), + /// #dnnl::primitive_desc_base::dst_desc(), and others. + /// + /// @param what The kind of parameter to query; can be + /// #dnnl::query::src_md, #dnnl::query::dst_md, etc. + /// @param idx Index of the parameter. For example, convolution bias can + /// be queried with what = #dnnl::query::weights_md and idx = 1. + /// @returns The requested memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// parameter of the specified kind or index. + memory::desc query_md(query what, int idx = 0) const { + std::vector valid_q {query::src_md, query::diff_src_md, + query::weights_md, query::diff_weights_md, query::dst_md, + query::diff_dst_md, query::workspace_md, query::scratchpad_md, + query::exec_arg_md}; + if (!std::any_of(valid_q.cbegin(), valid_q.cend(), + [=](query q) { return what == q; })) + DNNL_THROW_ERROR(dnnl_invalid_arguments, + "memory descriptor query is invalid"); + + const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md( + get(), dnnl::convert_to_c(what), idx); + if (!cdesc) return memory::desc(); + + dnnl_memory_desc_t cloned_md = nullptr; + error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc), + "could not clone a memory descriptor"); + + return memory::desc(cloned_md); + } + + /// Returns a source memory descriptor. + /// @param idx Source index. + /// @returns Source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// source parameter with index @p idx. + memory::desc src_desc(int idx) const { + return query_md(query::src_md, idx); + } + + /// Returns a destination memory descriptor. + /// @param idx Destination index. + /// @returns Destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// destination parameter with index @p idx. 
+ memory::desc dst_desc(int idx) const { + return query_md(query::dst_md, idx); + } + + /// Returns a weights memory descriptor. + /// @param idx Weights index. + /// @returns Weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// weights parameter with index @p idx. + memory::desc weights_desc(int idx) const { + return query_md(query::weights_md, idx); + } + + /// Returns a diff source memory descriptor. + /// @param idx Diff source index. + /// @returns Diff source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff source parameter with index @p idx. + memory::desc diff_src_desc(int idx) const { + return query_md(query::diff_src_md, idx); + } + + /// Returns a diff destination memory descriptor. + /// @param idx Diff destination index. + /// @returns Diff destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff destination parameter with index @p idx. + memory::desc diff_dst_desc(int idx) const { + return query_md(query::diff_dst_md, idx); + } + + /// Returns a diff weights memory descriptor. + /// @param idx Diff weights index. + /// @returns Diff weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff weights parameter with index @p idx. + memory::desc diff_weights_desc(int idx) const { + return query_md(query::diff_weights_md, idx); + } + + // Separate versions without the index argument for documentation + // purposes. + + /// Returns a source memory descriptor. + /// @returns Source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// source parameter. + memory::desc src_desc() const { return src_desc(0); } + + /// Returns a destination memory descriptor. + /// @returns Destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// destination parameter. 
+ memory::desc dst_desc() const { return dst_desc(0); } + + /// Returns a weights memory descriptor. + /// @returns Weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// weights parameter. + memory::desc weights_desc() const { return weights_desc(0); } + + /// Returns a diff source memory descriptor. + /// @returns Diff source memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff source memory with. + memory::desc diff_src_desc() const { return diff_src_desc(0); } + + /// Returns a diff destination memory descriptor. + /// @returns Diff destination memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff destination parameter. + memory::desc diff_dst_desc() const { return diff_dst_desc(0); } + + /// Returns a diff weights memory descriptor. + /// @returns Diff weights memory descriptor. + /// @returns A zero memory descriptor if the primitive does not have a + /// diff weights parameter. + memory::desc diff_weights_desc() const { return diff_weights_desc(0); } + + /// Returns the workspace memory descriptor. + /// @returns Workspace memory descriptor. + /// @returns A zero memory descriptor if the primitive does not require + /// workspace parameter. + memory::desc workspace_desc() const { + return query_md(query::workspace_md, 0); + } + + /// Returns the scratchpad memory descriptor. + /// @returns scratchpad memory descriptor. + /// @returns A zero memory descriptor if the primitive does not require + /// scratchpad parameter. + /// @sa @ref dev_guide_attributes_scratchpad + memory::desc scratchpad_desc() const { + return query_md(query::scratchpad_md, 0); + } + + /// Returns the engine on which the scratchpad memory is located. + /// @returns The engine on which the scratchpad memory is located. 
+ engine scratchpad_engine() const { + dnnl_engine_t c_engine; + error::wrap_c_api(dnnl_primitive_desc_query(get(), + dnnl::convert_to_c(query::scratchpad_engine), + 0, &c_engine), + "could not retrieve scratchpad engine from a primitive " + "descriptor"); + return engine(c_engine, true); + } + + /// Returns the primitive attributes. + /// @returns The primitive attributes. + primitive_attr get_primitive_attr() const { + const_dnnl_primitive_attr_t const_c_attr; + error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_c_attr), + "could not get attributes from a primitive descriptor"); + dnnl_primitive_attr_t c_attr; + error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr), + "could not clone primitive attributes"); + return primitive_attr(c_attr); + } + + /// Returns the kind of the primitive descriptor. + /// @returns The kind of the primitive descriptor. + dnnl::primitive::kind get_kind() const { + dnnl_primitive_kind_t kind; + error::wrap_c_api(dnnl_primitive_desc_query(get(), + dnnl_query_primitive_kind, 0, (void *)&kind), + "could not get primitive kind from a primitive descriptor"); + return static_cast(kind); + } + + /// Returns the cache blob ID of the primitive descriptor. + /// @returns The cache blob ID of the primitive descriptor. + std::vector get_cache_blob_id() const { + dnnl_dim_t count; + const uint8_t *c_id; + error::wrap_c_api( + dnnl_primitive_desc_query(get(), + dnnl::convert_to_c(query::cache_blob_id_size_s64), 0, + (void *)&count), + "could not get size of cache blob ID from a primitive " + "descriptor"); + error::wrap_c_api(dnnl_primitive_desc_query(get(), + dnnl::convert_to_c(query::cache_blob_id), 0, + (void **)&c_id), + "could not get cache blob ID from a primitive descriptor"); + std::vector id(c_id, c_id + count); + return id; + } + +protected: + /// Returns a float value. + /// @param what The value to query. + /// @returns The result of the query. + /// @returns Zero if the primitive doesn't support the query. 
+ float query_f32(query what) const { + float res; + dnnl_status_t status = dnnl_primitive_desc_query( + get(), dnnl::convert_to_c(what), 0, &res); + return status == dnnl_success ? res : 0.0f; + } + + /// Returns an #dnnl::algorithm value. + /// @param what The value to query. + /// @returns The result of the query. + /// @returns #dnnl::algorithm::undef if the primitive doesn't support + /// the query. + algorithm query_alg(query what) const { + dnnl_alg_kind_t res; + dnnl_status_t status = dnnl_primitive_desc_query( + get(), dnnl::convert_to_c(what), 0, &res); + return status == dnnl_success ? static_cast(res) + : algorithm::undef; + } + + /// Returns a memory::dims value. + /// @param what The value to query. + /// @returns The result of the query. + /// @returns An empty #dnnl::memory::dims if the primitive doesn't support + /// the query. + memory::dims query_dims(query what) const { + const bool is_backward = get_prop_kind() != prop_kind::forward_training + && get_prop_kind() != prop_kind::forward_inference; + const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(), + is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0); + + int nspatial_dims = 0; + if (md) { + int ndims; + error::wrap_c_api( + dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims), + "could not query ndims from a memory descriptor"); + nspatial_dims = ndims - 2; + } + + dnnl_dims_t *c_dims; + dnnl_status_t status = dnnl_primitive_desc_query( + get(), dnnl::convert_to_c(what), 0, &c_dims); + return status == dnnl_success + ? memory::dims(*c_dims, *c_dims + nspatial_dims) + : memory::dims {}; + } + + /// Returns an #dnnl::engine value. + /// @param what The value to query. + /// @returns The result of the query. + /// @returns A weak handle to the engine that the primitive descriptor was + /// created with. 
+ engine query_engine(query what) const { + dnnl_engine_t c_engine; + error::wrap_c_api(dnnl_primitive_desc_query(get(), + dnnl::convert_to_c(what), 0, &c_engine), + "could not get an engine from a primitive_desc"); + return engine(c_engine, true); + } + + /// Resets the value of the handle to a clone of a C API primitive + /// descriptor. + /// @param pd A C API primitive descriptor to clone. + void reset_with_clone(const_dnnl_primitive_desc_t pd) { + dnnl_primitive_desc_t new_pd; + error::wrap_c_api(dnnl_primitive_desc_clone(&new_pd, pd), + "could not clone a primitive descriptor"); + reset(new_pd); + } + + /// Constructs a primitive descriptor base object from a clone of a C API + /// primitive descriptor after verifying that it is what the caller + /// expects. + /// + /// @note + /// The @p prim_kind should map to a primitive that does not have + /// different values of propagation kind (e.g. #dnnl::binary). + /// @note + /// Primitive descriptor base constructed this way does not support + /// next_impl() (will throw). + /// + /// @param pd C API primitive descriptor to clone. + /// @param prim_kind Expected primitive kind. + primitive_desc_base( + dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind) + : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {} + + /// Constructs a primitive descriptor base object from a clone of a C API + /// primitive descriptor after verifying that it is what the caller + /// expects. + /// + /// @note + /// Primitive descriptor base constructed this way does not support + /// next_impl() (will throw). + /// + /// @param pd C API primitive descriptor to clone. + /// @param prim_kind Expected primitive kind. + /// @param aprop_kind Expected propagation kind. 
+ primitive_desc_base(dnnl_primitive_desc_t pd, + dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind) + : primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {} + + /// Constructs a primitive descriptor base object from a clone of a C API + /// primitive descriptor after verifying that it is what the caller + /// expects. + /// + /// @note + /// Primitive descriptor base constructed this way does not support + /// next_impl() (will throw). + /// + /// @param pd C API primitive descriptor to clone. + /// @param prim_kind Expected primitive kind. + /// @param prop_kind1 Expected propagation kind (option 1). + /// @param prop_kind2 Expected propagation kind (option 2). This value is + /// checked if the check with @p prop_kind1 fails. + primitive_desc_base(dnnl_primitive_desc_t pd, + dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1, + dnnl::prop_kind prop_kind2) { + // It is OK to pass an empty primitive descriptor + if (pd == nullptr) return; + + dnnl_status_t rc; + + dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind); + dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1); + dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2); + + // Check that primitive kind matches + dnnl_primitive_kind_t pd_kind; + rc = dnnl_primitive_desc_query( + pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind); + error::wrap_c_api( + rc, "could not get primitive kind from a primitive descriptor"); + if (pd_kind != c_prim_kind) + DNNL_THROW_ERROR(dnnl_invalid_arguments, + "primitive descriptor operation kind mismatch"); + + // Check that propagation kind matches + dnnl_prop_kind_t pd_prop_kind; + rc = dnnl_primitive_desc_query( + pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind); + + // Something went wrong + if (rc != dnnl_success && rc != dnnl_unimplemented) + DNNL_THROW_ERROR(dnnl_invalid_arguments, + "could not get propagation kind from the primitive " + "descriptor"); + + // Everything is fine + if ((rc == dnnl_unimplemented && c_prop_kind1 
== dnnl_prop_kind_undef) + || (rc == dnnl_success + && (pd_prop_kind == c_prop_kind1 + || pd_prop_kind == c_prop_kind2))) { + reset_with_clone(pd); + return; + } + + // We could get the propagation kind but there is a mismatch + DNNL_THROW_ERROR(dnnl_invalid_arguments, + "primitive descriptor propagation kind mismatch"); + } + + /// Returns a constant reference to a static instance of default constructed + /// primitive attributes + static const primitive_attr &default_attr() { + static const primitive_attr attr; + return attr; + } + + const_dnnl_memory_desc_t optional_arg(const memory::desc *md) { + return md ? md->get() : nullptr; + } + + const dnnl_dim_t *optional_arg(const memory::dims *dims) { + return dims ? dims->data() : nullptr; + } + + const float *optional_arg(const std::vector *arg) { + return arg ? arg->data() : nullptr; + } + + using base = primitive_desc_base; +}; + +/// @} dnnl_api_primitives_common + +/// @addtogroup dnnl_api_reorder Reorder +/// +/// A primitive to copy data between two memory objects. This primitive is +/// typically used to change the way the data is laid out in memory. +/// +/// @sa @ref dev_guide_reorder in developer guide +/// +/// @{ + +/// Reorder primitive. +struct reorder : public primitive { + /// Primitive descriptor for a reorder primitive. + struct primitive_desc : public primitive_desc_base { + using primitive_desc_base::primitive_desc_base; + + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for reorder primitive. + /// + /// @note + /// If @p allow_empty is true, the constructor does not throw if a + /// primitive descriptor cannot be created. + /// + /// @param src_engine Engine on which the source memory object will be + /// located. + /// @param src_md Source memory descriptor. + /// @param dst_engine Engine on which the destination memory object + /// will be located. + /// @param dst_md Destination memory descriptor. 
+ /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is allowed + /// to fail without throwing an exception. In this case an empty + /// object will be produced. This flag is optional and defaults to + /// false. + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + dnnl_primitive_desc_t result; + dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result, + src_md.get(), src_engine.get(), dst_md.get(), + dst_engine.get(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the reorder primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); + } + + /// Constructs a primitive descriptor for reorder primitive. + /// + /// @param src Source memory object. It is used to obtain the source + /// memory descriptor and engine. + /// @param dst Destination memory object. It is used to obtain the + /// destination memory descriptor and engine. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is allowed + /// to fail without throwing an exception. In this case an empty + /// object will be produced. This flag is optional and defaults to + /// false. 
+ primitive_desc(const memory &src, const memory &dst, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + dnnl_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result, + src_md.get(), src.get_engine().get(), dst_md.get(), + dst.get_engine().get(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the reorder primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); + } + + /// Constructs a primitive descriptor for reorder primitive from a C + /// API primitive descriptor which must have a matching kind. + /// + /// @param pd C API primitive descriptor for reorder primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : primitive_desc_base(pd, dnnl::primitive::kind::reorder) {} + + /// Returns the engine on which the source memory is allocated. + /// @returns The engine on which the source memory is allocated. + engine get_src_engine() const { + return query_engine(dnnl::query::reorder_src_engine); + } + + /// Returns the engine on which the destination memory is allocated. + /// @returns The engine on which the destination memory is allocated. + engine get_dst_engine() const { + return query_engine(dnnl::query::reorder_dst_engine); + } + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + }; + + /// Default constructor. Produces an empty object. + reorder() = default; + + /// Constructs a reorder primitive. + /// @param pd Primitive descriptor for reorder primitive. 
+ reorder(const primitive_desc &pd) : primitive(pd.get()) {} + + /// Constructs a reorder primitive from a cache blob. + /// @param pd Primitive descriptor for reorder primitive. + /// @param cache_blob Cache blob. + reorder(const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd.get(), cache_blob) {} + + /// Constructs a reorder primitive that would reorder data between memory + /// objects having the same memory descriptors as memory objects @p src and + /// @p dst. + /// + /// @param src Source memory object. + /// @param dst Destination memory object. + /// @param attr Primitive attributes to use (optional). + reorder(const memory &src, const memory &dst, + const primitive_attr &attr = primitive_attr()) + : primitive(primitive_desc(src, dst, attr).get()) {} + + using primitive::execute; + + /// Executes the reorder primitive. + /// + /// @param astream Stream object. The stream must belong to the same engine + /// as the primitive. + /// @param src Source memory object. + /// @param dst Destination memory object. + void execute(const stream &astream, memory &src, memory &dst) const { + primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}}); + } +}; + +/// @} dnnl_api_reorder + +/// @addtogroup dnnl_api_concat Concat +/// +/// A primitive to concatenate data by arbitrary dimension. +/// +/// @sa @ref dev_guide_concat in developer guide +/// +/// @{ + +/// @cond DO_NOT_DOCUMENT_THIS +inline std::vector convert_to_c( + const std::vector &mds) { + std::vector c_mds; + c_mds.reserve(mds.size()); + for (const auto &md : mds) + c_mds.push_back(md.get()); + return c_mds; +} +/// @endcond + +/// Tensor concatenation (concat) primitive. +struct concat : public primitive { + /// Primitive descriptor for a concat primitive. + struct primitive_desc : public primitive_desc_base { + using primitive_desc_base::primitive_desc_base; + + /// Default constructor. Produces an empty object. 
+ primitive_desc() = default; + + /// Constructs a primitive descriptor for an out-of-place concatenation + /// primitive. + /// + /// @param aengine Engine to perform the operation on. + /// @param dst Destination memory descriptor. + /// @param concat_dimension Source tensors will be concatenated over + /// dimension with this index. Note that order of dimensions does + /// not depend on memory format. + /// @param srcs Vector of source memory descriptors. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, const memory::desc &dst, + int concat_dimension, const std::vector &srcs, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + auto c_srcs = convert_to_c(srcs); + + dnnl_primitive_desc_t result; + dnnl_status_t status = dnnl_concat_primitive_desc_create(&result, + aengine.get(), dst.get(), (int)c_srcs.size(), + concat_dimension, c_srcs.data(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the concat primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); + } + + /// Constructs a primitive descriptor for an out-of-place concatenation + /// primitive. + /// + /// This version derives the destination memory descriptor + /// automatically. + /// + /// @param aengine Engine to perform the operation on. + /// @param concat_dimension Source tensors will be concatenated over + /// dimension with this index. Note that order of dimensions does + /// not depend on memory format. 
+ /// @param srcs Vector of source memory descriptors. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, int concat_dimension, + const std::vector &srcs, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + auto c_api_srcs = convert_to_c(srcs); + + dnnl_primitive_desc_t result; + dnnl_status_t status = dnnl_concat_primitive_desc_create(&result, + aengine.get(), nullptr, (int)c_api_srcs.size(), + concat_dimension, c_api_srcs.data(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the concat primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); + } + + /// Constructs a primitive descriptor for concat primitive from a C + /// API primitive descriptor which must have a matching kind. + /// + /// @param pd C API primitive descriptor for concat primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : primitive_desc_base(pd, dnnl::primitive::kind::concat) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc(int)const + memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + }; + + /// Default constructor. Produces an empty object. + concat() = default; + + /// Constructs a concatenation primitive. + /// @param pd Primitive descriptor for concatenation primitive. + concat(const primitive_desc &pd) : primitive(pd.get()) {} + + /// Constructs a concatenation primitive from a cache blob. 
+ /// @param pd Primitive descriptor for concatenation primitive. + /// @param cache_blob Cache blob. + concat(const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd.get(), cache_blob) {} +}; + +/// @} dnnl_api_concat + +/// @addtogroup dnnl_api_sum Sum +/// +/// A primitive to sum multiple tensors. +/// +/// @sa @ref dev_guide_sum in developer guide +/// +/// @{ + +/// Out-of-place summation (sum) primitive. +struct sum : public primitive { + /// Primitive descriptor for a sum primitive. + struct primitive_desc : public primitive_desc_base { + using primitive_desc_base::primitive_desc_base; + + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a sum primitive. + /// + /// @param aengine Engine to perform the operation on. + /// @param dst Destination memory descriptor. + /// @param scales Vector of scales to multiply data in each source + /// memory by. + /// @param srcs Vector of source memory descriptors. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, const memory::desc &dst, + const std::vector &scales, + const std::vector &srcs, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + validate_container_size(scales, + "counts of scales and sources are not equal", + (int)srcs.size(), (int)srcs.size()); + + auto c_api_srcs = convert_to_c(srcs); + + dnnl_primitive_desc_t result; + dnnl_status_t status = dnnl_sum_primitive_desc_create(&result, + aengine.get(), dst.get(), (int)c_api_srcs.size(), + scales.data(), c_api_srcs.data(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the sum primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); + } + + /// Constructs a primitive descriptor for a sum primitive. + /// + /// This version derives the destination memory descriptor + /// automatically. + /// + /// @param aengine Engine on which to perform the operation. + /// @param scales Vector of scales by which to multiply data in each + /// source memory object. + /// @param srcs Vector of source memory descriptors. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, const std::vector &scales, + const std::vector &srcs, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + validate_container_size(scales, + "counts of scales and sources are not equal", + (int)srcs.size(), (int)srcs.size()); + + auto c_api_srcs = convert_to_c(srcs); + dnnl_primitive_desc_t result; + dnnl_status_t status = dnnl_sum_primitive_desc_create(&result, + aengine.get(), nullptr, (int)c_api_srcs.size(), + scales.data(), c_api_srcs.data(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the sum primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(status == dnnl_success ? result : dnnl_primitive_desc_t()); + } + + /// Constructs a primitive descriptor for sum primitive from a C API + /// primitive descriptor which must have a matching kind. + /// + /// @param pd C API primitive descriptor for sum primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : primitive_desc_base(pd, dnnl::primitive::kind::sum) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc(int)const + memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + }; + + /// Default constructor. Produces an empty object. + sum() = default; + + /// Constructs a sum primitive. + /// @param pd Primitive descriptor for sum primitive. + sum(const primitive_desc &pd) : primitive(pd.get()) {} + + /// Constructs a sum primitive from a cache blob. + /// @param pd Primitive descriptor for sum primitive. + /// @param cache_blob Cache blob. 
+ sum(const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd.get(), cache_blob) {} +}; + +/// @} dnnl_api_sum + +/// @addtogroup dnnl_api_primitives_common +/// @{ + +/// A base class for descriptors of all primitives that support iteration +/// over multiple implementations. +struct primitive_desc : public primitive_desc_base { + using primitive_desc_base::primitive_desc_base; + + primitive_desc() = default; + + /// Changes the primitive descriptor to point to the next available + /// implementation. + /// + /// @returns @c true on success and @c false if the last available + /// implementation has already been reached. In the latter case, the + /// primitive descriptor itself is kept unchanged. + bool next_impl() { + dnnl_status_t status = dnnl_primitive_desc_next_impl(get()); + if (status == dnnl_last_impl_reached) return false; + error::wrap_c_api(status, "last available implementation is reached"); + return true; + } +}; + +/// @} dnnl_api_primitives_common + +/// @addtogroup dnnl_api_convolution Convolution +/// +/// A primitive to perform 1D, 2D or 3D convolution. Supported variants are +/// forward propagation, backward propagation, and weights gradient with or +/// without bias. +/// +/// @sa @ref dev_guide_convolution in developer guide +/// +/// @{ + +/// Convolution forward propagation primitive. +struct convolution_forward : public primitive { + /// Primitive descriptor for a convolution forward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a convolution forward + /// propagation primitive with bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. 
+ /// + /// Arrays @p strides, @p padding_l, and @p padding_r contain values + /// for spatial dimensions only and hence must have the same number of + /// elements as there are spatial dimensions. The order of values is + /// the same as in the tensor: depth (for 3D tensors), height (for 3D + /// and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. + /// @param src_desc Source memory descriptor. + /// @param weights_desc Weights memory descriptor. + /// @param bias_desc Bias memory descriptor. Passing zero memory + /// descriptor disables the bias term. + /// @param dst_desc Destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &weights_desc, const memory::desc &bias_desc, + const memory::desc &dst_desc, const memory::dims &strides, + const memory::dims &padding_l, const memory::dims &padding_r, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, + weights_desc, &bias_desc, dst_desc, strides, nullptr, + padding_l, padding_r, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution forward + /// propagation primitive without bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p padding_l, and @p padding_r contain values + /// for spatial dimensions only and hence must have the same number of + /// elements as there are spatial dimensions. The order of values is + /// the same as in the tensor: depth (for 3D tensors), height (for 3D + /// and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. + /// @param src_desc Source memory descriptor. + /// @param weights_desc Weights memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. 
+ /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &weights_desc, const memory::desc &dst_desc, + const memory::dims &strides, const memory::dims &padding_l, + const memory::dims &padding_r, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, + weights_desc, nullptr, dst_desc, strides, nullptr, + padding_l, padding_r, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution forward + /// propagation primitive with bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r + /// contain values for spatial dimensions only and hence must have the + /// same number of elements as there are spatial dimensions. The order + /// of values is the same as in the tensor: depth (for 3D tensors), + /// height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. + /// @param src_desc Source memory descriptor. + /// @param weights_desc Weights memory descriptor. + /// @param bias_desc Bias memory descriptor. 
Passing zero memory + /// descriptor disables the bias term. + /// @param dst_desc Destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param dilates Dilations for each spatial dimension. A zero value + /// means no dilation in the corresponding dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &weights_desc, const memory::desc &bias_desc, + const memory::desc &dst_desc, const memory::dims &strides, + const memory::dims &dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, + weights_desc, &bias_desc, dst_desc, strides, &dilates, + padding_l, padding_r, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution forward + /// propagation primitive without bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r + /// contain values for spatial dimensions only and hence must have the + /// same number of elements as there are spatial dimensions. 
The order + /// of values is the same as in the tensor: depth (for 3D tensors), + /// height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. + /// @param src_desc Source memory descriptor. + /// @param weights_desc Weights memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param dilates Dilations for each spatial dimension. A zero value + /// means no dilation in the corresponding dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &weights_desc, const memory::desc &dst_desc, + const memory::dims &strides, const memory::dims &dilates, + const memory::dims &padding_l, const memory::dims &padding_r, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, + weights_desc, nullptr, dst_desc, strides, &dilates, + padding_l, padding_r, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution forward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a convolution forward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// Returns the bias memory descriptor. + /// @returns The bias memory descriptor. + /// @returns A zero memory descriptor of the primitive does not have a + /// bias parameter. 
+ memory::desc bias_desc() const { return base::weights_desc(1); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_strides()const + memory::dims get_strides() const { return base::get_strides(); } + + /// @copydoc dnnl::primitive_desc_base::get_dilations()const + memory::dims get_dilations() const { return base::get_dilations(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_l()const + memory::dims get_padding_l() const { return base::get_padding_l(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_r()const + memory::dims get_padding_r() const { return base::get_padding_r(); } + + private: + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &weights_desc, const memory::desc *bias_desc, + const memory::desc &dst_desc, const memory::dims &strides, + const memory::dims *dilates, const memory::dims &padding_l, + const memory::dims &padding_r, const primitive_attr &attr, + bool allow_empty) { + + memory::validate_dims(strides, src_desc.get_ndims() - 2); + memory::validate_dims(padding_l, src_desc.get_ndims() - 2); + memory::validate_dims(padding_r, src_desc.get_ndims() - 2); + + if (dilates) + memory::validate_dims(*dilates, src_desc.get_ndims() - 2); + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_convolution_forward_primitive_desc_create(&pd, + aengine.get(), dnnl::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), src_desc.get(), + weights_desc.get(), optional_arg(bias_desc), + dst_desc.get(), &strides[0], optional_arg(dilates), + &padding_l[0], &padding_r[0], attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the 
convolution forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + convolution_forward() = default; + + /// Constructs a convolution forward propagation primitive. + /// @param pd Primitive descriptor for a convolution forward propagation + /// primitive. + convolution_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a convolution forward propagation primitive from a cache + /// blob. + /// @param pd Primitive descriptor for a convolution forward propagation + /// primitive. + /// @param cache_blob Cache blob. + convolution_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Convolution backward propagation primitive. +struct convolution_backward_data : public primitive { + /// Primitive descriptor for a convolution backward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a convolution backward + /// propagation primitive. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p padding_l, and @p padding_r contain values + /// for spatial dimensions only and hence must have the same number of + /// elements as there are spatial dimensions. The order of values is + /// the same as in the tensor: depth (for 3D tensors), height (for 3D + /// and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. 
+ /// @param diff_src_desc Diff source memory descriptor. + /// @param weights_desc Weights memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a convolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &padding_l, const memory::dims &padding_r, + const convolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc, + diff_dst_desc, strides, nullptr, padding_l, padding_r, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution backward + /// propagation primitive. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r + /// contain values for spatial dimensions only and hence must have the + /// same number of elements as there are spatial dimensions. 
The order + /// of values is the same as in the tensor: depth (for 3D tensors), + /// height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. + /// @param diff_src_desc Diff source memory descriptor. + /// @param weights_desc Weights memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param dilates Dilations for each spatial dimension. A zero value + /// means no dilation in the corresponding dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a convolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const convolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc, + diff_dst_desc, strides, &dilates, padding_l, padding_r, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a convolution backward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution, + dnnl::prop_kind::backward_data) {} + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_strides()const + memory::dims get_strides() const { return base::get_strides(); } + + /// @copydoc dnnl::primitive_desc_base::get_dilations()const + memory::dims get_dilations() const { return base::get_dilations(); } + + /// @copydoc 
dnnl::primitive_desc_base::get_padding_l()const + memory::dims get_padding_l() const { return base::get_padding_l(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_r()const + memory::dims get_padding_r() const { return base::get_padding_r(); } + + private: + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims *dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const convolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr, bool allow_empty) { + + memory::validate_dims(strides, diff_src_desc.get_ndims() - 2); + memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2); + memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2); + + if (dilates) + memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2); + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_convolution_backward_data_primitive_desc_create(&pd, + aengine.get(), convert_to_c(aalgorithm), + diff_src_desc.get(), weights_desc.get(), + diff_dst_desc.get(), &strides[0], + optional_arg(dilates), &padding_l[0], &padding_r[0], + hint_fwd_pd.get(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the convolution backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + convolution_backward_data() = default; + + /// Constructs a convolution backward propagation primitive. + /// @param pd Primitive descriptor for a convolution backward propagation + /// primitive. + convolution_backward_data(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a convolution backward propagation primitive from a cache + /// blob. 
+ /// @param pd Primitive descriptor for a convolution backward propagation + /// primitive. + /// @param cache_blob Cache blob. + convolution_backward_data( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Convolution weights gradient primitive. +struct convolution_backward_weights : public primitive { + /// Primitive descriptor for a convolution weights gradient primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a convolution weights gradient + /// primitive with bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p padding_l, and @p padding_r contain values + /// for spatial dimensions only and hence must have the same number of + /// elements as there are spatial dimensions. The order of values is + /// the same as in the tensor: depth (for 3D tensors), height (for 3D + /// and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. + /// @param src_desc Source memory descriptor. + /// @param diff_weights_desc Diff weights memory descriptor. + /// @param diff_bias_desc Diff bias memory descriptor. Passing zero + /// memory descriptor disables the bias term. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. 
+ /// @param hint_fwd_pd Primitive descriptor for a convolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &padding_l, const memory::dims &padding_r, + const convolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, + &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l, + padding_r, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution weights gradient + /// primitive without bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p padding_l, and @p padding_r contain values + /// for spatial dimensions only and hence must have the same number of + /// elements as there are spatial dimensions. The order of values is + /// the same as in the tensor: depth (for 3D tensors), height (for 3D + /// and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. + /// @param src_desc Source memory descriptor. 
+ /// @param diff_weights_desc Diff weights memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a convolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &padding_l, const memory::dims &padding_r, + const convolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, + nullptr, diff_dst_desc, strides, nullptr, padding_l, + padding_r, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution weights + /// gradient primitive with bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r + /// contain values for spatial dimensions only and hence must have the + /// same number of elements as there are spatial dimensions. 
The order + /// of values is the same as in the tensor: depth (for 3D tensors), + /// height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. + /// @param src_desc Source memory descriptor. + /// @param diff_weights_desc Diff weights memory descriptor. + /// @param diff_bias_desc Diff bias memory descriptor. Passing zero + /// memory descriptor disables the bias term. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param dilates Dilations for each spatial dimension. A zero value + /// means no dilation in the corresponding dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a convolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const convolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, + &diff_bias_desc, diff_dst_desc, strides, &dilates, + padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution weights + /// gradient primitive without bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r + /// contain values for spatial dimensions only and hence must have the + /// same number of elements as there are spatial dimensions. The order + /// of values is the same as in the tensor: depth (for 3D tensors), + /// height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Convolution algorithm. Possible values are + /// #dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd, and + /// #dnnl::algorithm::convolution_auto. + /// @param src_desc Source memory descriptor. + /// @param diff_weights_desc Diff weights memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param dilates Dilations for each spatial dimension. A zero value + /// means no dilation in the corresponding dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. 
+ /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a convolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const convolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, + nullptr, diff_dst_desc, strides, &dilates, padding_l, + padding_r, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a convolution weights gradient + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for a convolution weights + /// gradient primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution, + dnnl::prop_kind::backward_weights) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const + memory::desc diff_weights_desc() const { + return base::diff_weights_desc(0); + } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// Returns the diff bias memory descriptor. + /// @returns The diff bias memory descriptor. + /// @returns A zero memory descriptor of the primitive does not have a + /// diff bias parameter. + memory::desc diff_bias_desc() const { + return base::diff_weights_desc(1); + } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_strides()const + memory::dims get_strides() const { return base::get_strides(); } + + /// @copydoc dnnl::primitive_desc_base::get_dilations()const + memory::dims get_dilations() const { return base::get_dilations(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_l()const + memory::dims get_padding_l() const { return base::get_padding_l(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_r()const + memory::dims get_padding_r() const { return base::get_padding_r(); } + + private: + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc *diff_bias_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims *dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const 
convolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr, bool allow_empty) { + + memory::validate_dims(strides, src_desc.get_ndims() - 2); + memory::validate_dims(padding_l, src_desc.get_ndims() - 2); + memory::validate_dims(padding_r, src_desc.get_ndims() - 2); + + if (dilates) + memory::validate_dims(*dilates, src_desc.get_ndims() - 2); + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_convolution_backward_weights_primitive_desc_create( + &pd, aengine.get(), convert_to_c(aalgorithm), + src_desc.get(), diff_weights_desc.get(), + optional_arg(diff_bias_desc), diff_dst_desc.get(), + &strides[0], optional_arg(dilates), &padding_l[0], + &padding_r[0], hint_fwd_pd.get(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the convolution weights update primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + convolution_backward_weights() = default; + + /// Constructs a convolution weights gradient primitive. + /// @param pd Primitive descriptor for a convolution weights gradient + /// primitive. + convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a convolution weights gradient primitive from a cache blob. + /// @param pd Primitive descriptor for a convolution weights gradient + /// primitive. + /// @param cache_blob Cache blob. + convolution_backward_weights( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_convolution +// +/// @addtogroup dnnl_api_deconvolution Deconvolution +/// +/// A primitive to perform 1D, 2D or 3D deconvolution. Supported variants are +/// forward propagation, backward propagation, and weights gradient with or +/// without bias. 
///
/// @{

/// Deconvolution forward propagation primitive.
struct deconvolution_forward : public primitive {
    /// Primitive descriptor for a deconvolution forward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a deconvolution forward
        /// propagation primitive with bias.
        ///
        /// @note
        ///     All the memory descriptors may be initialized with the
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// Arrays @p strides, @p padding_l, and @p padding_r contain values
        /// for spatial dimensions only and hence must have the same number of
        /// elements as there are spatial dimensions. The order of values is
        /// the same as in the tensor: depth (for 3D tensors), height (for 3D
        /// and 2D tensors), and width.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm Deconvolution algorithm:
        ///     #dnnl::algorithm::deconvolution_direct, and
        ///     #dnnl::algorithm::deconvolution_winograd.
        /// @param src_desc Source memory descriptor.
        /// @param weights_desc Weights memory descriptor.
        /// @param bias_desc Bias memory descriptor. Passing zero memory
        ///     descriptor disables the bias term.
        /// @param dst_desc Destination memory descriptor.
        /// @param strides Vector of strides for spatial dimension.
        /// @param padding_l Vector of padding values for low indices for each
        ///     spatial dimension `([[front,] top,] left)`.
        /// @param padding_r Vector of padding values for high indices for
        ///     each spatial dimension `([[back,] bottom,] right)`.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const memory::desc &src_desc,
                const memory::desc &weights_desc, const memory::desc &bias_desc,
                const memory::desc &dst_desc, const memory::dims &strides,
                const memory::dims &padding_l, const memory::dims &padding_r,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
                    weights_desc, &bias_desc, dst_desc, strides, nullptr,
                    padding_l, padding_r, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a deconvolution forward
        /// propagation primitive without bias.
        ///
        /// @note
        ///     All the memory descriptors may be initialized with the
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// Arrays @p strides, @p padding_l, and @p padding_r contain values
        /// for spatial dimensions only and hence must have the same number of
        /// elements as there are spatial dimensions. The order of values is
        /// the same as in the tensor: depth (for 3D tensors), height (for 3D
        /// and 2D tensors), and width.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm Deconvolution algorithm:
        ///     #dnnl::algorithm::deconvolution_direct, and
        ///     #dnnl::algorithm::deconvolution_winograd.
        /// @param src_desc Source memory descriptor.
        /// @param weights_desc Weights memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param strides Vector of strides for spatial dimension.
        /// @param padding_l Vector of padding values for low indices for each
        ///     spatial dimension `([[front,] top,] left)`.
        /// @param padding_r Vector of padding values for high indices for
        ///     each spatial dimension `([[back,] bottom,] right)`.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const memory::desc &src_desc,
                const memory::desc &weights_desc, const memory::desc &dst_desc,
                const memory::dims &strides, const memory::dims &padding_l,
                const memory::dims &padding_r,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
                    weights_desc, nullptr, dst_desc, strides, nullptr,
                    padding_l, padding_r, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a deconvolution forward
        /// propagation primitive with bias.
        ///
        /// @note
        ///     All the memory descriptors may be initialized with the
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
        /// contain values for spatial dimensions only and hence must have the
        /// same number of elements as there are spatial dimensions. The order
        /// of values is the same as in the tensor: depth (for 3D tensors),
        /// height (for 3D and 2D tensors), and width.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm Deconvolution algorithm:
        ///     #dnnl::algorithm::deconvolution_direct, and
        ///     #dnnl::algorithm::deconvolution_winograd.
        /// @param src_desc Source memory descriptor.
        /// @param weights_desc Weights memory descriptor.
        /// @param bias_desc Bias memory descriptor. Passing zero memory
        ///     descriptor disables the bias term.
        /// @param dst_desc Destination memory descriptor.
        /// @param strides Vector of strides for spatial dimension.
        /// @param dilates Dilations for each spatial dimension. A zero value
        ///     means no dilation in the corresponding dimension.
        /// @param padding_l Vector of padding values for low indices for each
        ///     spatial dimension `([[front,] top,] left)`.
        /// @param padding_r Vector of padding values for high indices for
        ///     each spatial dimension `([[back,] bottom,] right)`.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const memory::desc &src_desc,
                const memory::desc &weights_desc, const memory::desc &bias_desc,
                const memory::desc &dst_desc, const memory::dims &strides,
                const memory::dims &dilates, const memory::dims &padding_l,
                const memory::dims &padding_r,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
                    weights_desc, &bias_desc, dst_desc, strides, &dilates,
                    padding_l, padding_r, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a deconvolution forward
        /// propagation primitive without bias.
        ///
        /// @note
        ///     All the memory descriptors may be initialized with the
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
        /// contain values for spatial dimensions only and hence must have the
        /// same number of elements as there are spatial dimensions. The order
        /// of values is the same as in the tensor: depth (for 3D tensors),
        /// height (for 3D and 2D tensors), and width.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm Deconvolution algorithm:
        ///     #dnnl::algorithm::deconvolution_direct, and
        ///     #dnnl::algorithm::deconvolution_winograd.
        /// @param src_desc Source memory descriptor.
        /// @param weights_desc Weights memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param strides Vector of strides for spatial dimension.
        /// @param dilates Dilations for each spatial dimension. A zero value
        ///     means no dilation in the corresponding dimension.
        /// @param padding_l Vector of padding values for low indices for each
        ///     spatial dimension `([[front,] top,] left)`.
        /// @param padding_r Vector of padding values for high indices for
        ///     each spatial dimension `([[back,] bottom,] right)`.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const memory::desc &src_desc,
                const memory::desc &weights_desc, const memory::desc &dst_desc,
                const memory::dims &strides, const memory::dims &dilates,
                const memory::dims &padding_l, const memory::dims &padding_r,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
                    weights_desc, nullptr, dst_desc, strides, &dilates,
                    padding_l, padding_r, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a deconvolution forward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a deconvolution forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
                    dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::weights_desc()const
        memory::desc weights_desc() const { return base::weights_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
        // The bias is stored as the second weights argument, hence index 1.
        memory::desc bias_desc() const { return base::weights_desc(1); }

        /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
        algorithm get_algorithm() const { return base::get_algorithm(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_strides()const
        memory::dims get_strides() const { return base::get_strides(); }

        /// @copydoc dnnl::primitive_desc_base::get_dilations()const
        memory::dims get_dilations() const { return base::get_dilations(); }

        /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
        memory::dims get_padding_l() const { return base::get_padding_l(); }

        /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
        memory::dims get_padding_r() const { return base::get_padding_r(); }

    private:
        // Delegated constructor shared by all public overloads above.
        // A null @p bias_desc or @p dilates pointer marks the corresponding
        // parameter as absent; optional_arg() forwards null to the C API.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const memory::desc &src_desc,
                const memory::desc &weights_desc, const memory::desc *bias_desc,
                const memory::desc &dst_desc, const memory::dims &strides,
                const memory::dims *dilates, const memory::dims &padding_l,
                const memory::dims &padding_r, const primitive_attr &attr,
                bool allow_empty) {

            // Strides and paddings must carry one value per spatial
            // dimension (tensor rank minus minibatch and channel dims).
            memory::validate_dims(strides, src_desc.get_ndims() - 2);
            memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
            memory::validate_dims(padding_r, src_desc.get_ndims() - 2);

            if (dilates)
                memory::validate_dims(*dilates, src_desc.get_ndims() - 2);

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status
                    = dnnl_deconvolution_forward_primitive_desc_create(&pd,
                            aengine.get(), dnnl::convert_to_c(aprop_kind),
                            convert_to_c(aalgorithm), src_desc.get(),
                            weights_desc.get(), optional_arg(bias_desc),
                            dst_desc.get(), &strides[0], optional_arg(dilates),
                            &padding_l[0], &padding_r[0], attr.get());
            // With allow_empty the error is swallowed and an empty (null)
            // primitive descriptor is produced instead of throwing.
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the deconvolution forward propagation primitive. Run "
                        "workload with environment variable ONEDNN_VERBOSE=all "
                        "to get additional diagnostic information.");
            reset(pd);
        }
    };

    /// Default constructor. Produces an empty object.
    deconvolution_forward() = default;

    /// Constructs a deconvolution forward propagation primitive.
    /// @param pd Primitive descriptor for a deconvolution forward propagation
    ///     primitive.
+ deconvolution_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a deconvolution forward propagation primitive from a cache + /// blob. + /// @param pd Primitive descriptor for a deconvolution forward propagation + /// primitive. + /// @param cache_blob Cache blob. + deconvolution_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Deconvolution backward propagation primitive. +struct deconvolution_backward_data : public primitive { + /// Primitive descriptor for a deconvolution backward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a deconvolution backward + /// propagation primitive. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p padding_l, and @p padding_r contain values + /// for spatial dimensions only and hence must have the same number of + /// elements as there are spatial dimensions. The order of values is + /// the same as in the tensor: depth (for 3D tensors), height (for 3D + /// and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Deconvolution algorithm + /// (#dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd). + /// @param diff_src_desc Diff source memory descriptor. + /// @param weights_desc Weights memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. 
+ /// @param hint_fwd_pd Primitive descriptor for a deconvolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &padding_l, const memory::dims &padding_r, + const deconvolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc, + diff_dst_desc, strides, nullptr, padding_l, padding_r, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a deconvolution backward + /// propagation primitive. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r + /// contain values for spatial dimensions only and hence must have the + /// same number of elements as there are spatial dimensions. The order + /// of values is the same as in the tensor: depth (for 3D tensors), + /// height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Deconvolution algorithm + /// (#dnnl::algorithm::convolution_direct, + /// #dnnl::algorithm::convolution_winograd). + /// @param diff_src_desc Diff source memory descriptor. + /// @param weights_desc Weights memory descriptor. 
+ /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param dilates Dilations for each spatial dimension. A zero value + /// means no dilation in the corresponding dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a deconvolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const deconvolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc, + diff_dst_desc, strides, &dilates, padding_l, padding_r, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a deconvolution backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a deconvolution backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution, + dnnl::prop_kind::backward_data) {} + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_strides()const + memory::dims get_strides() const { return base::get_strides(); } + + /// @copydoc dnnl::primitive_desc_base::get_dilations()const + memory::dims get_dilations() const { return base::get_dilations(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_l()const + memory::dims get_padding_l() const { return base::get_padding_l(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_r()const + memory::dims get_padding_r() const { return base::get_padding_r(); } + + private: + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims *dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const deconvolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr, bool allow_empty) { + + memory::validate_dims(strides, diff_src_desc.get_ndims() - 2); + memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2); + memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2); + + if (dilates) + 
memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2); + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_deconvolution_backward_data_primitive_desc_create( + &pd, aengine.get(), convert_to_c(aalgorithm), + diff_src_desc.get(), weights_desc.get(), + diff_dst_desc.get(), &strides[0], + optional_arg(dilates), &padding_l[0], &padding_r[0], + hint_fwd_pd.get(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the deconvolution backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + deconvolution_backward_data() = default; + + /// Constructs a deconvolution backward propagation primitive. + /// @param pd Primitive descriptor for a deconvolution backward propagation + /// primitive. + deconvolution_backward_data(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a deconvolution backward propagation primitive from a cache + /// blob. + /// @param pd Primitive descriptor for a deconvolution backward propagation + /// primitive. + /// @param cache_blob Cache blob. + deconvolution_backward_data( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Deconvolution weights gradient primitive. +struct deconvolution_backward_weights : public primitive { + /// Primitive descriptor for a deconvolution weights gradient primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a deconvolution weights + /// gradient primitive with bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. 
+ /// + /// Arrays @p strides, @p padding_l, and @p padding_r contain values + /// for spatial dimensions only and hence must have the same number of + /// elements as there are spatial dimensions. The order of values is + /// the same as in the tensor: depth (for 3D tensors), height (for 3D + /// and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Deconvolution algorithm. Possible values are + /// #dnnl::algorithm::deconvolution_direct, and + /// #dnnl::algorithm::deconvolution_winograd. + /// @param src_desc Source memory descriptor. + /// @param diff_weights_desc Diff weights memory descriptor. + /// @param diff_bias_desc Diff bias memory descriptor. Passing zero + /// memory descriptor disables the bias term. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a deconvolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &padding_l, const memory::dims &padding_r, + const deconvolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, + &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l, + padding_r, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a deconvolution weights + /// gradient primitive without bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p padding_l, and @p padding_r contain values + /// for spatial dimensions only and hence must have the same number of + /// elements as there are spatial dimensions. The order of values is + /// the same as in the tensor: depth (for 3D tensors), height (for 3D + /// and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Deconvolution algorithm. Possible values are + /// #dnnl::algorithm::deconvolution_direct, and + /// #dnnl::algorithm::deconvolution_winograd. + /// @param src_desc Source memory descriptor. + /// @param diff_weights_desc Diff weights memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a deconvolution + /// forward propagation primitive. 
It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &padding_l, const memory::dims &padding_r, + const deconvolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, + nullptr, diff_dst_desc, strides, nullptr, padding_l, + padding_r, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a deconvolution weights + /// gradient primitive with bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r + /// contain values for spatial dimensions only and hence must have the + /// same number of elements as there are spatial dimensions. The order + /// of values is the same as in the tensor: depth (for 3D tensors), + /// height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Deconvolution algorithm. Possible values are + /// #dnnl::algorithm::deconvolution_direct, and + /// #dnnl::algorithm::deconvolution_winograd. + /// @param src_desc Source memory descriptor. + /// @param diff_weights_desc Diff weights memory descriptor. + /// @param diff_bias_desc Diff bias memory descriptor. Passing zero + /// memory descriptor disables the bias term. 
+ /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param dilates Dilations for each spatial dimension. A zero value + /// means no dilation in the corresponding dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a deconvolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const deconvolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, + &diff_bias_desc, diff_dst_desc, strides, &dilates, + padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a deconvolution weights + /// gradient primitive without bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. 
+ /// + /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r + /// contain values for spatial dimensions only and hence must have the + /// same number of elements as there are spatial dimensions. The order + /// of values is the same as in the tensor: depth (for 3D tensors), + /// height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Deconvolution algorithm. Possible values are + /// #dnnl::algorithm::deconvolution_direct, and + /// #dnnl::algorithm::deconvolution_winograd. + /// @param src_desc Source memory descriptor. + /// @param diff_weights_desc Diff weights memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Strides for each spatial dimension. + /// @param dilates Dilations for each spatial dimension. A zero value + /// means no dilation in the corresponding dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a deconvolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const deconvolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc, + nullptr, diff_dst_desc, strides, &dilates, padding_l, + padding_r, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a deconvolution weights + /// gradient primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a deconvolution weights + /// gradient primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution, + dnnl::prop_kind::backward_weights) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const + memory::desc diff_weights_desc() const { + return base::diff_weights_desc(0); + } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const + memory::desc diff_bias_desc() const { + return base::diff_weights_desc(1); + } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_strides()const + memory::dims get_strides() const { return 
base::get_strides(); } + + /// @copydoc dnnl::primitive_desc_base::get_dilations()const + memory::dims get_dilations() const { return base::get_dilations(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_l()const + memory::dims get_padding_l() const { return base::get_padding_l(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_r()const + memory::dims get_padding_r() const { return base::get_padding_r(); } + + private: + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc *diff_bias_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims *dilates, const memory::dims &padding_l, + const memory::dims &padding_r, + const deconvolution_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr, bool allow_empty) { + + memory::validate_dims(strides, src_desc.get_ndims() - 2); + memory::validate_dims(padding_l, src_desc.get_ndims() - 2); + memory::validate_dims(padding_r, src_desc.get_ndims() - 2); + + if (dilates) + memory::validate_dims(*dilates, src_desc.get_ndims() - 2); + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_deconvolution_backward_weights_primitive_desc_create( + &pd, aengine.get(), convert_to_c(aalgorithm), + src_desc.get(), diff_weights_desc.get(), + optional_arg(diff_bias_desc), diff_dst_desc.get(), + &strides[0], optional_arg(dilates), &padding_l[0], + &padding_r[0], hint_fwd_pd.get(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the deconvolution weights update primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + deconvolution_backward_weights() = default; + + /// Constructs a deconvolution weights gradient primitive. 
+ /// @param pd Primitive descriptor for a deconvolution weights gradient + /// primitive. + deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a deconvolution weights gradient primitive from a cache + /// blob. + /// @param pd Primitive descriptor for a deconvolution weights gradient + /// primitive. + /// @param cache_blob Cache blob. + deconvolution_backward_weights( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_deconvolution + +/// @addtogroup dnnl_api_lrn LRN +/// +/// A primitive to perform local response normalization (LRN) across or within +/// channels. +/// +/// @sa @ref dev_guide_lrn in developer guide +/// +/// @{ + +/// Local response normalization (LRN) forward propagation primitive. +struct lrn_forward : public primitive { + /// Primitive descriptor for an LRN forward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an LRN forward propagation + /// primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm LRN algorithm kind: either + /// #dnnl::algorithm::lrn_across_channels, or + /// #dnnl::algorithm::lrn_within_channel. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param local_size Regularization local size. + /// @param alpha The alpha regularization parameter. + /// @param beta The beta regularization parameter. + /// @param k The k regularization parameter. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. 
+ /// @param allow_empty A flag signifying whether construction is
+ /// allowed to fail without throwing an exception. In this case an
+ /// empty object will be produced. This flag is optional and
+ /// defaults to false.
+ primitive_desc(const engine &aengine, prop_kind aprop_kind,
+ algorithm aalgorithm, const memory::desc &src_desc,
+ const memory::desc &dst_desc, memory::dim local_size,
+ float alpha, float beta, float k,
+ const primitive_attr &attr = default_attr(),
+ bool allow_empty = false) {
+
+ dnnl_primitive_desc_t pd = nullptr;
+ dnnl_status_t status = dnnl_lrn_forward_primitive_desc_create(&pd,
+ aengine.get(), dnnl::convert_to_c(aprop_kind),
+ convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
+ local_size, alpha, beta, k, attr.get());
+
+ if (!allow_empty)
+ error::wrap_c_api(status,
+ "could not create a primitive descriptor for "
+ "the lrn forward propagation primitive. Run workload "
+ "with environment variable ONEDNN_VERBOSE=all to get "
+ "additional diagnostic information.");
+ // On failure with allow_empty == true, no exception was thrown above
+ // and pd is still nullptr, so reset(pd) produces an empty object.
+ reset(pd);
+ }
+
+ /// Constructs a primitive descriptor for an LRN forward propagation
+ /// primitive from a C API primitive descriptor that must have a
+ /// matching kind.
+ ///
+ /// @param pd C API primitive descriptor for an LRN forward
+ /// propagation primitive.
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { return base::workspace_desc(); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_alpha()const + float get_alpha() const { return base::get_alpha(); } + + /// @copydoc dnnl::primitive_desc_base::get_beta()const + float get_beta() const { return base::get_beta(); } + + /// @copydoc dnnl::primitive_desc_base::get_local_size()const + memory::dim get_local_size() const { return base::get_local_size(); } + + /// @copydoc dnnl::primitive_desc_base::get_k()const + float get_k() const { return base::get_k(); } + }; + + /// Default constructor. Produces an empty object. + lrn_forward() = default; + + /// Constructs an LRN forward propagation primitive. + /// @param pd Primitive descriptor for an LRN forward propagation + /// primitive. + lrn_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an LRN forward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an LRN forward propagation + /// primitive. + /// @param cache_blob Cache blob. + lrn_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Local response normalization (LRN) backward propagation primitive. 
+struct lrn_backward : public primitive { + /// Primitive descriptor for an LRN backward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an LRN backward propagation + /// primitive. + /// + /// @param aengine Engine to use. + /// @param aalgorithm LRN algorithm kind: either + /// #dnnl::algorithm::lrn_across_channels, or + /// #dnnl::algorithm::lrn_within_channel. + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param src_desc Source memory descriptor. + /// @param local_size Regularization local size. + /// @param alpha The alpha regularization parameter. + /// @param beta The beta regularization parameter. + /// @param k The k regularization parameter. + /// @param hint_fwd_pd Primitive descriptor for an LRN forward + /// propagation primitive. It is used as a hint for deciding which + /// memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::desc &src_desc, + memory::dim local_size, float alpha, float beta, float k, + const lrn_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_lrn_backward_primitive_desc_create(&pd, + aengine.get(), convert_to_c(aalgorithm), + diff_src_desc.get(), diff_dst_desc.get(), src_desc.get(), + local_size, alpha, beta, k, hint_fwd_pd.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the lrn backward propagation primitive. Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for an LRN backward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for an LRN backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd)
+ : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
+ dnnl::prop_kind::backward_data) {}
+
+ /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
+ memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
+
+ /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
+ memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
+
+ /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
+ memory::desc workspace_desc() const { return base::workspace_desc(); }
+
+ /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
+ algorithm get_algorithm() const { return base::get_algorithm(); }
+
+ /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
+ prop_kind get_prop_kind() const { return base::get_prop_kind(); }
+
+ /// @copydoc dnnl::primitive_desc_base::get_alpha()const
+ float get_alpha() const { return base::get_alpha(); }
+
+ /// @copydoc dnnl::primitive_desc_base::get_beta()const
+ float get_beta() const { return base::get_beta(); }
+
+ /// @copydoc dnnl::primitive_desc_base::get_local_size()const
+ memory::dim get_local_size() const { return base::get_local_size(); }
+
+ /// @copydoc dnnl::primitive_desc_base::get_k()const
+ float get_k() const { return base::get_k(); }
+ };
+
+ /// Default constructor. Produces an empty object.
+ lrn_backward() = default;
+
+ /// Constructs an LRN backward propagation primitive.
+ /// @param pd Primitive descriptor for an LRN backward propagation
+ /// primitive.
+ lrn_backward(const primitive_desc &pd) : primitive(pd) {}
+
+ /// Constructs an LRN backward propagation primitive from a cache blob.
+ /// @param pd Primitive descriptor for an LRN backward propagation
+ /// primitive.
+ /// @param cache_blob Cache blob.
+ lrn_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_lrn + +/// @addtogroup dnnl_api_eltwise Eltwise +/// +/// A primitive to perform elementwise operations such as the +/// rectifier linear unit (ReLU). +/// +/// Both forward and backward propagation primitives support in-place +/// operation; that is, src and dst can refer to the same memory for forward +/// propagation, and diff_dst and diff_src can refer to the same memory for +/// backward propagation. +/// +/// @warning +/// Because the original source data is required for backward propagation, +/// in-place forward propagation is not generally supported in the +/// training mode. However, for algorithms supporting destination as input +/// memory, dst can be used for the backward propagation, which makes it +/// possible to get performance benefit even in the training mode. +/// +/// @sa @ref dev_guide_eltwise in developer guide +/// +/// @{ + +/// Elementwise unary operation forward propagation primitive. +struct eltwise_forward : public primitive { + /// Primitive descriptor for an elementwise forward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an elementwise forward + /// propagation primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm Elementwise algorithm kind. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. 
In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &dst_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, + dst_desc, nullptr, nullptr, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an elementwise forward + /// propagation primitive with an alpha parameter. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm Elementwise algorithm kind. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param alpha The alpha parameter for the elementwise operation. + /// Specific meaning depends on the algorithm. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &dst_desc, float alpha, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, + dst_desc, &alpha, nullptr, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an elementwise forward + /// propagation primitive with an alpha and beta parameters. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. 
Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm Elementwise algorithm kind. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param alpha The alpha parameter for the elementwise operation. + /// Specific meaning depends on the algorithm. + /// @param beta The beta parameter for the elementwise operation. + /// Specific meaning depends on the algorithm. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &dst_desc, float alpha, float beta, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc, + dst_desc, &alpha, &beta, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an eltwise forward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for an eltwise forward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + dnnl::algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_alpha()const + float get_alpha() const { return base::get_alpha(); } + + /// @copydoc dnnl::primitive_desc_base::get_beta()const + float get_beta() const { return base::get_beta(); } + + private: + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &dst_desc, const float *alpha, + const float *beta, const primitive_attr &attr, + bool allow_empty) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_eltwise_forward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + dnnl::convert_to_c(aalgorithm), src_desc.get(), + dst_desc.get(), alpha ? *alpha : 0.0f, beta ? *beta : 0.0f, + attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the eltwise forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + eltwise_forward() = default; + + /// Constructs an eltwise forward propagation primitive. + /// @param pd Primitive descriptor for an eltwise forward propagation + /// primitive. 
+ eltwise_forward(const primitive_desc &pd) : primitive(pd) {}
+
+ /// Constructs an eltwise forward propagation primitive from a cache blob.
+ /// @param pd Primitive descriptor for an eltwise forward propagation
+ /// primitive.
+ /// @param cache_blob Cache blob.
+ eltwise_forward(
+ const primitive_desc &pd, const std::vector &cache_blob)
+ : primitive(pd, cache_blob) {}
+};
+
+/// Elementwise unary operation backward propagation primitive.
+struct eltwise_backward : public primitive {
+ /// Primitive descriptor for eltwise backward propagation.
+ struct primitive_desc : public dnnl::primitive_desc {
+ /// Default constructor. Produces an empty object.
+ primitive_desc() = default;
+
+ /// Constructs a primitive descriptor for an elementwise backward
+ /// propagation primitive.
+ ///
+ /// @param aengine Engine to use.
+ /// @param aalgorithm Elementwise algorithm kind.
+ /// @param diff_src_desc Diff source memory descriptor.
+ /// @param diff_dst_desc Diff destination memory descriptor.
+ /// @param data_desc Destination memory descriptor if one of the
+ /// "use_dst_for_bwd" algorithms are used (such as
+ /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
+ /// otherwise.
+ /// @param hint_fwd_pd Primitive descriptor for an elementwise
+ /// forward propagation primitive. It is used as a hint for
+ /// deciding which memory format to use.
+ /// @param attr Primitive attributes to use. Attributes are optional
+ /// and default to empty attributes.
+ /// @param allow_empty A flag signifying whether construction is
+ /// allowed to fail without throwing an exception. In this case an
+ /// empty object will be produced. This flag is optional and
+ /// defaults to false.
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::desc &data_desc, + const eltwise_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc, + data_desc, nullptr, nullptr, hint_fwd_pd, attr, + allow_empty) {} + + /// Constructs a primitive descriptor for an elementwise backward + /// propagation primitive with an alpha parameter. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Elementwise algorithm kind. + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param data_desc Destination memory descriptor if one of the + /// "use_dst_for_bwd" algorithms are used (such as + /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor + /// otherwise. + /// @param alpha The alpha parameter for the elementwise operation. + /// Specific meaning depends on the algorithm. + /// @param hint_fwd_pd Primitive descriptor for an elementwise + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::desc &data_desc, float alpha, + const eltwise_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc, + data_desc, &alpha, nullptr, hint_fwd_pd, attr, + allow_empty) {} + + /// Constructs a primitive descriptor for an elementwise backward + /// propagation primitive with an alpha and beta parameters. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Elementwise algorithm kind. + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param data_desc Destination memory descriptor if one of the + /// "use_dst_for_bwd" algorithms are used (such as + /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor + /// otherwise. + /// @param alpha The alpha parameter for the elementwise operation. + /// Specific meaning depends on the algorithm. + /// @param beta The beta parameter for the elementwise operation. + /// Specific meaning depends on the algorithm. + /// @param hint_fwd_pd Primitive descriptor for an elementwise + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::desc &data_desc, float alpha, float beta, + const eltwise_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc, + data_desc, &alpha, &beta, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an eltwise backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for an eltwise backward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise, + dnnl::prop_kind::backward_data) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + dnnl::algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_alpha()const + float get_alpha() const { return base::get_alpha(); } + + /// @copydoc dnnl::primitive_desc_base::get_beta()const + float get_beta() const { return base::get_beta(); } + + private: + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::desc &data_desc, const float *alpha, + const float *beta, + 
const eltwise_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr, bool allow_empty) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_eltwise_backward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aalgorithm), + diff_src_desc.get(), diff_dst_desc.get(), data_desc.get(), + alpha ? *alpha : 0.0f, beta ? *beta : 0.0f, + hint_fwd_pd.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the eltwise backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + eltwise_backward() = default; + + /// Constructs an eltwise backward propagation primitive. + /// @param pd Primitive descriptor for an eltwise backward propagation + /// primitive. + eltwise_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an eltwise backward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an eltwise backward propagation + /// primitive. + /// @param cache_blob Cache blob. + eltwise_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_eltwise + +/// @addtogroup dnnl_api_softmax Softmax +/// +/// A primitive to perform softmax. +/// +/// @sa @ref dev_guide_softmax in developer guide +/// +/// @{ + +/// Softmax forward propagation primitive. +struct softmax_forward : public primitive { + /// Primitive descriptor for a softmax forward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a softmax forward propagation + /// primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. 
Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm Softmax algorithm kind: either + /// #dnnl::algorithm::softmax_accurate, + /// or #dnnl::algorithm::softmax_log. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param axis Axis over which softmax is computed. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &dst_desc, int axis, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_softmax_forward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + dnnl::convert_to_c(aalgorithm), src_desc.get(), + dst_desc.get(), axis, attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the softmax forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a softmax forward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a softmax forward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + dnnl::algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_axis()const + int get_axis() const { return base::get_axis(); } + }; + + /// Default constructor. Produces an empty object. + softmax_forward() = default; + + /// Constructs a softmax forward propagation primitive. + /// @param pd Primitive descriptor for a softmax forward propagation + /// primitive. + softmax_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a softmax forward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for a softmax forward propagation + /// primitive. + /// @param cache_blob Cache blob. + softmax_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Softmax backward propagation primitive. +struct softmax_backward : public primitive { + /// Primitive descriptor for a softmax backward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a softmax backward propagation + /// primitive. + /// + /// @param aengine Engine to use. 
+ /// @param aalgorithm Softmax algorithm kind: either + /// #dnnl::algorithm::softmax_accurate, + /// or #dnnl::algorithm::softmax_log. + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param axis Axis over which softmax is computed. + /// @param hint_fwd_pd Primitive descriptor for a softmax + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::desc &dst_desc, + int axis, const softmax_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_softmax_backward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aalgorithm), + diff_src_desc.get(), diff_dst_desc.get(), dst_desc.get(), + axis, hint_fwd_pd.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the softmax backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a softmax backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a softmax backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax, + dnnl::prop_kind::backward_data) {} + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + dnnl::algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_axis()const + int get_axis() const { return base::get_axis(); } + }; + + /// Default constructor. Produces an empty object. + softmax_backward() = default; + + /// Constructs a softmax backward propagation primitive. + /// @param pd Primitive descriptor for a softmax backward propagation + /// primitive. + softmax_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a softmax backward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for a softmax backward propagation + /// primitive. + /// @param cache_blob Cache blob. + softmax_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_softmax + +/// @addtogroup dnnl_api_batch_normalization Batch Normalization +/// +/// A primitive to perform batch normalization. +/// +/// Both forward and backward propagation primitives support in-place +/// operation; that is, src and dst can refer to the same memory for forward +/// propagation, and diff_dst and diff_src can refer to the same memory for +/// backward propagation. 
+/// +/// The batch normalization primitives computations can be controlled by +/// specifying different @ref dnnl::normalization_flags values. For example, +/// batch normalization forward propagation can be configured to either +/// compute the mean and variance or take them as arguments. It can either +/// perform scaling and shifting using gamma and beta parameters or not. +/// Optionally, it can also perform a fused ReLU, which in case of training +/// would also require a workspace. +/// +/// @sa @ref dev_guide_batch_normalization in developer guide +/// +/// @{ + +/// Batch normalization forward propagation primitive. +struct batch_normalization_forward : public primitive { + /// Primitive descriptor for a batch normalization forward propagation + /// primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a batch normalization forward + /// propagation primitive. + /// + /// @note + /// In-place operation is supported: the dst can refer to the same + /// memory as the src. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param epsilon Batch normalization epsilon parameter. + /// @param flags Batch normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &dst_desc, + float epsilon, normalization_flags flags, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_batch_normalization_forward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + src_desc.get(), dst_desc.get(), epsilon, + convert_to_c(flags), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the batch normalization forward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a batch normalization + /// forward propagation primitive from a C API primitive descriptor + /// that must have a matching kind. + /// + /// @param pd C API primitive descriptor for a batch normalization + /// forward propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, + dnnl::primitive::kind::batch_normalization, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { return base::workspace_desc(); } + + /// Returns memory descriptor for mean. + /// @returns Memory descriptor for mean. + memory::desc mean_desc() const { return stat_desc(mean); } + + /// Returns memory descriptor for variance. 
+ /// @returns Memory descriptor for variance. + memory::desc variance_desc() const { return stat_desc(var); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_epsilon()const + float get_epsilon() const { return base::get_epsilon(); } + + /// Returns normalization flags. + /// @return Normalization flags. + normalization_flags get_flags() const { + return base::get_flags(); + } + + private: + enum { + mean = 1, + var = 2, + }; + memory::desc stat_desc(int kind) const { + const bool use_global_stats + = (get_flags() & normalization_flags::use_global_stats) + != normalization_flags::none; + return query_md( + use_global_stats ? query::src_md : query::dst_md, kind); + } + }; + + /// Default constructor. Produces an empty object. + batch_normalization_forward() = default; + + /// Constructs a batch normalization forward propagation primitive. + /// @param pd Primitive descriptor for a batch normalization forward + /// propagation primitive. + batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a batch normalization forward propagation primitive from + /// a cache blob. + /// @param pd Primitive descriptor for a batch normalization forward + /// propagation primitive. + /// @param cache_blob Cache blob. + batch_normalization_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Batch normalization backward propagation primitive. +struct batch_normalization_backward : public primitive { + /// Primitive descriptor for a batch normalization backward propagation + /// primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a batch normalization backward + /// propagation primitive. 
+ /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward + /// (diffs for all parameters are computed in this case). + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param src_desc Source memory descriptor. + /// @param epsilon Batch normalization epsilon parameter. + /// @param flags Batch normalization flags (@ref + /// dnnl::normalization_flags). + /// @param hint_fwd_pd Primitive descriptor for a batch normalization + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::desc &src_desc, + float epsilon, normalization_flags flags, + const batch_normalization_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_batch_normalization_backward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + diff_src_desc.get(), diff_dst_desc.get(), + src_desc.get(), epsilon, convert_to_c(flags), + hint_fwd_pd.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the batch normalization backward propagation " + "primitive. 
Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a batch normalization + /// backward propagation primitive from a C API primitive descriptor + /// that must have a matching kind. + /// + /// @param pd C API primitive descriptor for a batch normalization + /// backward propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, + dnnl::primitive::kind::batch_normalization, + dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) { + } + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const + memory::desc diff_weights_desc() const { + return base::diff_weights_desc(0); + } + + /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const + memory::desc mean_desc() const { return query_md(query::src_md, 1); } + + /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const + memory::desc variance_desc() const { + return query_md(query::src_md, 2); + } + + /// @copydoc dnnl::primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { return base::workspace_desc(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return 
base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_epsilon()const + float get_epsilon() const { return base::get_epsilon(); } + + /// Returns normalization flags. + /// @return Normalization flags. + normalization_flags get_flags() const { + return base::get_flags(); + } + }; + + /// Default constructor. Produces an empty object. + batch_normalization_backward() = default; + + /// Constructs a batch normalization backward propagation primitive. + /// @param pd Primitive descriptor for a batch normalization backward + /// propagation primitive. + batch_normalization_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a batch normalization backward propagation primitive from + /// a cache blob. + /// @param pd Primitive descriptor for a batch normalization backward + /// propagation primitive. + /// @param cache_blob Cache blob. + batch_normalization_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_batch_normalization + +/// @addtogroup dnnl_api_group_normalization Group Normalization +/// +/// A primitive to perform group normalization. +/// +/// Both forward and backward propagation primitives support in-place +/// operation; that is, src and dst can refer to the same memory for forward +/// propagation, and diff_dst and diff_src can refer to the same memory for +/// backward propagation. +/// +/// The group normalization primitives computations can be controlled by +/// specifying different @ref dnnl::normalization_flags values. For example, +/// group normalization forward propagation can be configured to either +/// compute the mean and variance or take them as arguments. It can either +/// perform scaling and shifting using gamma and beta parameters or not. +/// +/// @sa @ref dev_guide_group_normalization in developer guide +/// +/// @{ + +/// Group normalization forward propagation primitive. 
+struct group_normalization_forward : public primitive { + /// Primitive descriptor for a group normalization forward propagation + /// primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a group normalization forward + /// propagation primitive. + /// + /// @note + /// In-place operation is supported: the dst can refer to the same + /// memory as the src. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param groups Group normalization groups parameter. + /// @param epsilon Group normalization epsilon parameter. + /// @param flags Group normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &dst_desc, + memory::dim groups, float epsilon, normalization_flags flags, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_group_normalization_forward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + src_desc.get(), dst_desc.get(), groups, epsilon, + convert_to_c(flags), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the group normalization forward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a group normalization + /// forward propagation primitive from a C API primitive descriptor + /// that must have a matching kind. + /// + /// @param pd C API primitive descriptor for a group normalization + /// forward propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, + dnnl::primitive::kind::group_normalization, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { return base::workspace_desc(); } + + /// Returns memory descriptor for mean. + /// @returns Memory descriptor for mean. 
+ memory::desc mean_desc() const { return stat_desc(mean); } + + /// Returns memory descriptor for variance. + /// @returns Memory descriptor for variance. + memory::desc variance_desc() const { return stat_desc(var); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_group_size()const + memory::dim get_group_size() const { return base::get_group_size(); } + + /// @copydoc dnnl::primitive_desc_base::get_epsilon()const + float get_epsilon() const { return base::get_epsilon(); } + + /// Returns normalization flags. + /// @return Normalization flags. + normalization_flags get_flags() const { + return base::get_flags(); + } + + private: + enum { + mean = 1, + var = 2, + }; + memory::desc stat_desc(int kind) const { + const bool use_global_stats + = (get_flags() & normalization_flags::use_global_stats) + != normalization_flags::none; + return query_md( + use_global_stats ? query::src_md : query::dst_md, kind); + } + }; + + /// Default constructor. Produces an empty object. + group_normalization_forward() = default; + + /// Constructs a group normalization forward propagation primitive. + /// @param pd Primitive descriptor for a group normalization forward + /// propagation primitive. + group_normalization_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a group normalization forward propagation primitive from + /// a cache blob. + /// @param pd Primitive descriptor for a group normalization forward + /// propagation primitive. + /// @param cache_blob Cache blob. + group_normalization_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Group normalization backward propagation primitive. +struct group_normalization_backward : public primitive { + /// Primitive descriptor for a group normalization backward propagation + /// primitive. 
+ struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a group normalization backward + /// propagation primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward + /// (diffs for all parameters are computed in this case). + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param src_desc Source memory descriptor. + /// @param groups Group normalization groups parameter. + /// @param epsilon Group normalization epsilon parameter. + /// @param flags Group normalization flags (@ref + /// dnnl::normalization_flags). + /// @param hint_fwd_pd Primitive descriptor for a group normalization + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::desc &src_desc, + memory::dim groups, float epsilon, normalization_flags flags, + const group_normalization_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_group_normalization_backward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + diff_src_desc.get(), diff_dst_desc.get(), + src_desc.get(), groups, epsilon, + convert_to_c(flags), hint_fwd_pd.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the group normalization backward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a group normalization + /// backward propagation primitive from a C API primitive descriptor + /// that must have a matching kind. + /// + /// @param pd C API primitive descriptor for a group normalization + /// backward propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, + dnnl::primitive::kind::group_normalization, + dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) { + } + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const + memory::desc diff_weights_desc() const { + return base::diff_weights_desc(0); + } + + /// @copydoc dnnl::group_normalization_forward::primitive_desc::mean_desc()const + memory::desc mean_desc() const { return query_md(query::src_md, 1); } + + /// @copydoc dnnl::group_normalization_forward::primitive_desc::variance_desc()const + memory::desc variance_desc() const { + return query_md(query::src_md, 2); + } + + /// @copydoc dnnl::primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { return base::workspace_desc(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_group_size()const + memory::dim get_group_size() const { return base::get_group_size(); } + + /// @copydoc dnnl::primitive_desc_base::get_epsilon()const + float get_epsilon() const { return base::get_epsilon(); } + + /// Returns normalization flags. + /// @return Normalization flags. 
+ normalization_flags get_flags() const { + return base::get_flags(); + } + }; + + /// Default constructor. Produces an empty object. + group_normalization_backward() = default; + + /// Constructs a group normalization backward propagation primitive. + /// @param pd Primitive descriptor for a group normalization backward + /// propagation primitive. + group_normalization_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a group normalization backward propagation primitive from + /// a cache blob. + /// @param pd Primitive descriptor for a group normalization backward + /// propagation primitive. + /// @param cache_blob Cache blob. + group_normalization_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_group_normalization + +/// @addtogroup dnnl_api_layer_normalization Layer Normalization +/// +/// A primitive to perform layer normalization. Normalization is performed +/// within the last logical dimension of data tensor. +/// +/// Both forward and backward propagation primitives support in-place +/// operation; that is, src and dst can refer to the same memory for forward +/// propagation, and diff_dst and diff_src can refer to the same memory for +/// backward propagation. +/// +/// The layer normalization primitives computations can be controlled by +/// specifying different @ref dnnl::normalization_flags values. For example, +/// layer normalization forward propagation can be configured to either +/// compute the mean and variance or take them as arguments. It can either +/// perform scaling and shifting using gamma and beta parameters or not. +/// +/// @sa @ref dev_guide_layer_normalization in developer guide +/// +/// @{ + +/// Layer normalization forward propagation primitive. +struct layer_normalization_forward : public primitive { + /// Primitive descriptor for a layer normalization forward propagation + /// primitive. 
+ struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a layer normalization forward + /// propagation primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param stat_desc Statistics memory descriptors. + /// @param epsilon Layer normalization epsilon parameter. + /// @param flags Layer normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &dst_desc, + const memory::desc &stat_desc, float epsilon, + normalization_flags flags, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, + &stat_desc, memory::data_type::f32, epsilon, flags, attr, + allow_empty) {} + + /// Constructs a primitive descriptor for a layer normalization forward + /// propagation primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param epsilon Layer normalization epsilon parameter. 
+ /// @param flags Layer normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &dst_desc, + float epsilon, normalization_flags flags, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr, + memory::data_type::f32, epsilon, flags, attr, allow_empty) { + } + + /// Constructs a primitive descriptor for a layer normalization forward + /// propagation primitive with a user-provided data type for the scale + /// and shift memory objects. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param stat_desc Statistics memory descriptors. + /// @param scale_shift_data_type Data type of scale and shift memory. + /// If neither scale nor shift flag are specified the parameter + /// is ignored. + /// @param epsilon Layer normalization epsilon parameter. + /// @param flags Layer normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &dst_desc, + const memory::desc &stat_desc, + memory::data_type scale_shift_data_type, float epsilon, + normalization_flags flags, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, + &stat_desc, scale_shift_data_type, epsilon, flags, attr, + allow_empty) {} + + /// Constructs a primitive descriptor for a layer normalization forward + /// propagation primitive with a user-provided data type for the scale + /// and shift memory objects. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param scale_shift_data_type Data type of scale and shift memory. + /// If neither scale nor shift flag are specified the parameter + /// is ignored. + /// @param epsilon Layer normalization epsilon parameter. + /// @param flags Layer normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &dst_desc, + memory::data_type scale_shift_data_type, float epsilon, + normalization_flags flags, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr, + scale_shift_data_type, epsilon, flags, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a layer normalization + /// forward propagation primitive from a C API primitive descriptor + /// that must have a matching kind. + /// + /// @param pd C API primitive descriptor for a layer normalization + /// forward propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, + dnnl::primitive::kind::layer_normalization, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { return base::workspace_desc(); } + + /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const + memory::desc mean_desc() const { return stat_desc(mean); } + + /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const + memory::desc variance_desc() const { return stat_desc(var); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_epsilon()const + float get_epsilon() const { return base::get_epsilon(); } + + /// Returns 
normalization flags. + /// @return Normalization flags. + normalization_flags get_flags() const { + return base::get_flags(); + } + + private: + enum { + mean = 1, + var = 2, + }; + memory::desc stat_desc(int kind) const { + const bool use_global_stats + = (get_flags() & normalization_flags::use_global_stats) + != normalization_flags::none; + return query_md( + use_global_stats ? query::src_md : query::dst_md, kind); + } + + primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &dst_desc, + const memory::desc *stat_desc, + memory::data_type scale_shift_data_type, float epsilon, + normalization_flags flags, const primitive_attr &attr, + bool allow_empty) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_layer_normalization_forward_primitive_desc_create_v2( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + src_desc.get(), dst_desc.get(), + optional_arg(stat_desc), + memory::convert_to_c(scale_shift_data_type), + epsilon, convert_to_c(flags), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the layer normalization forward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + layer_normalization_forward() = default; + + /// Constructs a layer normalization forward propagation primitive. + /// @param pd Primitive descriptor for a layer normalization forward + /// propagation primitive. + layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a layer normalization forward propagation primitive from + /// a cache blob. + /// @param pd Primitive descriptor for a layer normalization forward + /// propagation primitive. + /// @param cache_blob Cache blob. 
+ layer_normalization_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Layer normalization backward propagation primitive. +struct layer_normalization_backward : public primitive { + /// Primitive descriptor for a layer normalization backward propagation + /// primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a layer normalization backward + /// propagation primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward + /// (diffs for all parameters are computed in this case). + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param src_desc Source memory descriptor. + /// @param stat_desc Statistics memory descriptors. + /// @param epsilon Layer normalization epsilon parameter. + /// @param flags Layer normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param hint_fwd_pd Primitive descriptor for a layer normalization + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::desc &src_desc, + const memory::desc &stat_desc, float epsilon, + normalization_flags flags, + const layer_normalization_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc, + src_desc, &stat_desc, memory::data_type::f32, + memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr, + allow_empty) {} + + /// Constructs a primitive descriptor for a layer normalization backward + /// propagation primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward + /// (diffs for all parameters are computed in this case). + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param src_desc Source memory descriptor. + /// @param epsilon Layer normalization epsilon parameter. + /// @param flags Layer normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param hint_fwd_pd Primitive descriptor for a layer normalization + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::desc &src_desc, + float epsilon, normalization_flags flags, + const layer_normalization_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc, + src_desc, nullptr, memory::data_type::f32, + memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr, + allow_empty) {} + + /// Constructs a primitive descriptor for a layer normalization backward + /// propagation primitive with a user-provided data type for the scale + /// and shift memory objects. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward + /// (diffs for all parameters are computed in this case). + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param src_desc Source memory descriptor. + /// @param stat_desc Statistics memory descriptors. + /// @param diff_scale_shift_data_type Data type of diff scale and shift + /// memory. If neither scale nor shift flag are specified the + /// parameter is ignored. + /// @param scale_shift_data_type Data type of scale and shift memory. + /// If neither scale nor shift flag are specified the parameter + /// is ignored. + /// @param epsilon Layer normalization epsilon parameter. + /// @param flags Layer normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param hint_fwd_pd Primitive descriptor for a layer normalization + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. 
+ /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::desc &src_desc, + const memory::desc &stat_desc, + memory::data_type diff_scale_shift_data_type, + memory::data_type scale_shift_data_type, float epsilon, + normalization_flags flags, + const layer_normalization_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc, + src_desc, &stat_desc, diff_scale_shift_data_type, + scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr, + allow_empty) {} + + /// Constructs a primitive descriptor for a layer normalization backward + /// propagation primitive with a user-provided data type for the scale + /// and shift memory objects. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward + /// (diffs for all parameters are computed in this case). + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param src_desc Source memory descriptor. + /// @param diff_scale_shift_data_type Data type of diff scale and shift + /// memory. If neither scale nor shift flag are specified the + /// parameter is ignored. + /// @param scale_shift_data_type Data type of scale and shift memory. + /// If neither scale nor shift flag are specified the parameter + /// is ignored. + /// @param epsilon Layer normalization epsilon parameter. + /// @param flags Layer normalization flags (@ref + /// dnnl::normalization_flags). + /// @param attr Primitive attributes to use. 
Attributes are optional + /// and default to empty attributes. + /// @param hint_fwd_pd Primitive descriptor for a layer normalization + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::desc &src_desc, + memory::data_type diff_scale_shift_data_type, + memory::data_type scale_shift_data_type, float epsilon, + normalization_flags flags, + const layer_normalization_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc, + src_desc, nullptr, diff_scale_shift_data_type, + scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr, + allow_empty) {} + + /// Constructs a primitive descriptor for a layer normalization + /// backward propagation primitive from a C API primitive descriptor + /// that must have a matching kind. + /// + /// @param pd C API primitive descriptor for a layer normalization + /// backward propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, + dnnl::primitive::kind::layer_normalization, + dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) { + } + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const + memory::desc diff_weights_desc() const { + return base::diff_weights_desc(0); + } + + /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const + memory::desc mean_desc() const { return query_md(query::src_md, 1); } + + /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const + memory::desc variance_desc() const { + return query_md(query::src_md, 2); + } + + /// @copydoc dnnl::primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { return base::workspace_desc(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_epsilon()const + float get_epsilon() const { return base::get_epsilon(); } + + /// Returns normalization flags. + /// @return Normalization flags. 
+ normalization_flags get_flags() const { + return base::get_flags(); + } + + private: + primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::desc &src_desc, + const memory::desc *stat_desc, + memory::data_type diff_scale_shift_data_type, + memory::data_type scale_shift_data_type, float epsilon, + normalization_flags flags, + const layer_normalization_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr, bool allow_empty) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_layer_normalization_backward_primitive_desc_create_v2( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + diff_src_desc.get(), diff_dst_desc.get(), + src_desc.get(), optional_arg(stat_desc), + memory::convert_to_c(diff_scale_shift_data_type), + memory::convert_to_c(scale_shift_data_type), + epsilon, convert_to_c(flags), hint_fwd_pd.get(), + attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the layer normalization backward propagation " + "primitive. Run workload with environment variable " + "ONEDNN_VERBOSE=all to get additional diagnostic " + "information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + layer_normalization_backward() = default; + + /// Constructs a layer normalization backward propagation primitive. + /// @param pd Primitive descriptor for a layer normalization backward + /// propagation primitive. + layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a layer normalization backward propagation primitive from + /// a cache blob. + /// @param pd Primitive descriptor for a layer normalization backward + /// propagation primitive. + /// @param cache_blob Cache blob. 
+ layer_normalization_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_layer_normalization + +/// @addtogroup dnnl_api_inner_product Inner Product +/// +/// A primitive to compute an inner product. +/// +/// @sa @ref dev_guide_inner_product in developer guide +/// +/// @{ + +/// Inner product forward propagation primitive. +struct inner_product_forward : public primitive { + /// Primitive descriptor for an inner product forward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an inner product forward + /// propagation primitive with bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Memory descriptor for src. + /// @param weights_desc Memory descriptor for weights. + /// @param bias_desc Memory descriptor for bias. + /// @param dst_desc Memory descriptor for dst. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &weights_desc, + const memory::desc &bias_desc, const memory::desc &dst_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, src_desc, weights_desc, + &bias_desc, dst_desc, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an inner product forward + /// propagation primitive. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Memory descriptor for src. + /// @param weights_desc Memory descriptor for weights. + /// @param dst_desc Memory descriptor for dst. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &weights_desc, + const memory::desc &dst_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, src_desc, weights_desc, + nullptr, dst_desc, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an inner product forward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for an inner product forward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const + memory::desc bias_desc() const { return base::weights_desc(1); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + private: + primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &weights_desc, + const memory::desc *bias_desc, const memory::desc &dst_desc, + const primitive_attr &attr, bool allow_empty) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_inner_product_forward_primitive_desc_create(&pd, + aengine.get(), dnnl::convert_to_c(aprop_kind), + src_desc.get(), weights_desc.get(), + optional_arg(bias_desc), dst_desc.get(), + attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the inner product forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + inner_product_forward() = default; + + /// Constructs an inner product forward propagation primitive. + /// @param pd Primitive descriptor for an inner product forward + /// propagation primitive. 
+ inner_product_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an inner product forward propagation primitive from + /// a cache blob. + /// @param pd Primitive descriptor for an inner product forward + /// propagation primitive. + /// @param cache_blob Cache blob. + inner_product_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Inner product backward propagation primitive. +struct inner_product_backward_data : public primitive { + /// Primitive descriptor for an inner product backward propagation + /// primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an inner product backward + /// propagation primitive. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param diff_src_desc Memory descriptor for diff src. + /// @param weights_desc Memory descriptor for weights. + /// @param diff_dst_desc Memory descriptor for diff dst. + /// @param hint_fwd_pd Primitive descriptor for an inner product + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const inner_product_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_inner_product_backward_data_primitive_desc_create( + &pd, aengine.get(), diff_src_desc.get(), + weights_desc.get(), diff_dst_desc.get(), + hint_fwd_pd.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the inner product backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for an inner product backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for an inner product backward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product, + dnnl::prop_kind::backward_data) {} + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { return base::weights_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + }; + + /// Default constructor. Produces an empty object. + inner_product_backward_data() = default; + + /// Constructs an inner product backward propagation primitive. 
+ /// @param pd Primitive descriptor for an inner product backward + /// propagation primitive. + inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an inner product backward propagation primitive from + /// a cache blob. + /// @param pd Primitive descriptor for an inner product backward + /// propagation primitive. + /// @param cache_blob Cache blob. + inner_product_backward_data( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Inner product weights gradient primitive. +struct inner_product_backward_weights : public primitive { + /// Primitive descriptor for an inner product weights gradient primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an inner product weights + /// update primitive with bias. + /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param src_desc Memory descriptor for src. + /// @param diff_weights_desc Memory descriptor for diff weights. + /// @param diff_bias_desc Memory descriptor for diff bias. + /// @param diff_dst_desc Memory descriptor for diff dst. + /// @param hint_fwd_pd Primitive descriptor for an inner product + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+        // With-bias overload: forwards the address of diff_bias_desc to the
+        // private implementation ctor, marking the bias as present.
+        primitive_desc(const engine &aengine, const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_bias_desc,
+                const memory::desc &diff_dst_desc,
+                const inner_product_forward::primitive_desc &hint_fwd_pd,
+                const primitive_attr &attr = default_attr(),
+                bool allow_empty = false)
+            : primitive_desc(aengine, src_desc, diff_weights_desc,
+                    &diff_bias_desc, diff_dst_desc, hint_fwd_pd, attr,
+                    allow_empty) {}
+
+        /// Constructs a primitive descriptor for an inner product weights
+        /// update primitive.
+        ///
+        /// @note
+        ///     All the memory descriptors may be initialized with the
+        ///     #dnnl::memory::format_tag::any value of @p format_tag.
+        ///
+        /// @param aengine Engine to use.
+        /// @param src_desc Memory descriptor for src.
+        /// @param diff_weights_desc Memory descriptor for diff weights.
+        /// @param diff_dst_desc Memory descriptor for diff dst.
+        /// @param attr Primitive attributes to use. Attributes are optional
+        ///     and default to empty attributes.
+        /// @param hint_fwd_pd Primitive descriptor for an inner product
+        ///     forward propagation primitive. It is used as a hint for
+        ///     deciding which memory format to use.
+        /// @param allow_empty A flag signifying whether construction is
+        ///     allowed to fail without throwing an exception. In this case an
+        ///     empty object will be produced. This flag is optional and
+        ///     defaults to false.
+        // No-bias overload: passes nullptr for the bias descriptor.
+        primitive_desc(const engine &aengine, const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_dst_desc,
+                const inner_product_forward::primitive_desc &hint_fwd_pd,
+                const primitive_attr &attr = default_attr(),
+                bool allow_empty = false)
+            : primitive_desc(aengine, src_desc, diff_weights_desc, nullptr,
+                    diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
+
+        /// Constructs a primitive descriptor for an inner product weights
+        /// update primitive from a C API primitive descriptor that must
+        /// have a matching kind.
+        ///
+        /// @param pd C API primitive descriptor for an inner product weights
+        ///     gradient primitive.
+        primitive_desc(dnnl_primitive_desc_t pd)
+            : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
+                    dnnl::prop_kind::backward_weights) {}
+
+        /// @copydoc dnnl::primitive_desc_base::src_desc()const
+        memory::desc src_desc() const { return base::src_desc(0); }
+
+        /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
+        memory::desc diff_weights_desc() const {
+            return base::diff_weights_desc(0);
+        }
+
+        /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
+        memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
+
+        /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
+        // The diff bias is stored as the second diff-weights argument
+        // (index 1), hence the query below.
+        memory::desc diff_bias_desc() const {
+            return base::diff_weights_desc(1);
+        }
+
+        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
+        prop_kind get_prop_kind() const { return base::get_prop_kind(); }
+
+    private:
+        // Shared implementation behind the with-bias and no-bias public
+        // constructors; diff_bias_desc may be null to omit the bias.
+        primitive_desc(const engine &aengine, const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc *diff_bias_desc,
+                const memory::desc &diff_dst_desc,
+                const inner_product_forward::primitive_desc &hint_fwd_pd,
+                const primitive_attr &attr, bool allow_empty) {
+
+            dnnl_primitive_desc_t pd = nullptr;
+            dnnl_status_t status
+                    = dnnl_inner_product_backward_weights_primitive_desc_create(
+                            &pd, aengine.get(), src_desc.get(),
+                            diff_weights_desc.get(),
+                            optional_arg(diff_bias_desc), diff_dst_desc.get(),
+                            hint_fwd_pd.get(), attr.get());
+
+            if (!allow_empty)
+                error::wrap_c_api(status,
+                        "could not create a primitive descriptor for "
+                        "the inner product weights gradient primitive. Run "
+                        "workload with environment variable ONEDNN_VERBOSE=all "
+                        "to get additional diagnostic information.");
+            reset(pd);
+        }
+    };
+
+    /// Default constructor. Produces an empty object.
+    inner_product_backward_weights() = default;
+
+    /// Constructs an inner product weights gradient primitive.
+    /// @param pd Primitive descriptor for an inner product weights gradient
+    ///     primitive.
+    inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {}
+
+    /// Constructs an inner product weights gradient primitive from a cache
+    /// blob.
+    /// @param pd Primitive descriptor for an inner product weights gradient
+    ///     primitive.
+    /// @param cache_blob Cache blob.
+    inner_product_backward_weights(
+            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
+        : primitive(pd, cache_blob) {}
+};
+
+/// @} dnnl_api_inner_product
+
+/// @addtogroup dnnl_api_rnn RNN
+///
+/// A primitive to compute recurrent neural network layers.
+///
+/// @sa @ref dev_guide_rnn in developer guide
+///
+/// @{
+
+/// Base class for primitive descriptors for RNN primitives.
+struct rnn_primitive_desc_base : public primitive_desc {
+    using primitive_desc::primitive_desc;
+
+    /// Default constructor. Produces an empty object.
+    rnn_primitive_desc_base() = default;
+
+    /// Constructs an RNN primitive descriptor base from a C API primitive
+    /// descriptor while checking that it actually describes the expected
+    /// primitive by comparing propagation and primitive kinds.
+    ///
+    /// @param pd C API primitive descriptor.
+    /// @param aprop_kind Expected propagation kind.
+    /// @param cell_kind Expected cell kind.
+    // Delegates to the protected two-prop-kind ctor, passing the same
+    // expected propagation kind twice (exactly one kind is accepted).
+    rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
+            dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
+        : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {}
+
+    /// Returns source layer memory descriptor.
+    /// @returns Source layer memory descriptor.
+    memory::desc src_layer_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER);
+    }
+
+    /// Returns AUGRU attention memory descriptor.
+    /// @returns AUGRU attention memory descriptor.
+    // Each query below resolves the execution-argument memory descriptor
+    // for the corresponding DNNL_ARG_* id.
+    memory::desc augru_attention_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_AUGRU_ATTENTION);
+    }
+
+    /// Returns source iteration memory descriptor.
+    /// @returns Source iteration memory descriptor.
+    /// @returns A zero memory descriptor if the primitive does not have a
+    ///     source iteration parameter.
+    memory::desc src_iter_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER);
+    }
+
+    /// Returns source recurrent cell state memory descriptor.
+    /// @returns Source recurrent cell state memory descriptor.
+    memory::desc src_iter_c_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C);
+    }
+
+    /// Returns weights layer memory descriptor.
+    /// @returns Weights layer memory descriptor.
+    memory::desc weights_layer_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER);
+    }
+
+    /// Returns weights iteration memory descriptor.
+    /// @returns Weights iteration memory descriptor.
+    memory::desc weights_iter_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER);
+    }
+
+    /// Returns weights peephole memory descriptor.
+    /// @returns Weights peephole memory descriptor.
+    memory::desc weights_peephole_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE);
+    }
+
+    /// Returns weights projection memory descriptor.
+    /// @returns Weights projection memory descriptor.
+    memory::desc weights_projection_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION);
+    }
+
+    /// Returns bias memory descriptor.
+    /// @returns Bias memory descriptor.
+    /// @returns A zero memory descriptor if the primitive does not have a
+    ///     bias parameter.
+    memory::desc bias_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_BIAS);
+    }
+
+    /// Returns destination layer memory descriptor.
+    /// @returns Destination layer memory descriptor.
+    memory::desc dst_layer_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER);
+    }
+
+    /// Returns destination iteration memory descriptor.
+    /// @returns Destination iteration memory descriptor.
+    /// @returns A zero memory descriptor if the primitive does not have a
+    ///     destination iteration parameter.
+    memory::desc dst_iter_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER);
+    }
+
+    /// Returns destination recurrent cell state memory descriptor.
+    /// @returns Destination recurrent cell state memory descriptor.
+    memory::desc dst_iter_c_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C);
+    }
+
+    // Diff (gradient) counterparts of the queries above; meaningful for
+    // backward-propagation primitive descriptors.
+    /// Returns diff source layer memory descriptor.
+    /// @returns Diff source layer memory descriptor.
+    memory::desc diff_src_layer_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_LAYER);
+    }
+
+    /// Returns diff AUGRU attention memory descriptor.
+    /// @returns Diff AUGRU attention memory descriptor.
+    memory::desc diff_augru_attention_desc() const {
+        return base::query_md(
+                query::exec_arg_md, DNNL_ARG_DIFF_AUGRU_ATTENTION);
+    }
+
+    /// Returns diff source iteration memory descriptor.
+    /// @returns Diff source iteration memory descriptor.
+    /// @returns A zero memory descriptor if the primitive does not have a
+    ///     diff source iteration parameter.
+    memory::desc diff_src_iter_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER);
+    }
+
+    /// Returns diff source recurrent cell state memory descriptor.
+    /// @returns Diff source recurrent cell state memory descriptor.
+    memory::desc diff_src_iter_c_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER_C);
+    }
+
+    /// Returns diff weights layer memory descriptor.
+    /// @returns Diff weights layer memory descriptor.
+    memory::desc diff_weights_layer_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_LAYER);
+    }
+
+    /// Returns diff weights iteration memory descriptor.
+    /// @returns Diff weights iteration memory descriptor.
+    memory::desc diff_weights_iter_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_ITER);
+    }
+
+    /// Returns diff weights peephole memory descriptor.
+    /// @returns Diff weights peephole memory descriptor.
+    memory::desc diff_weights_peephole_desc() const {
+        return base::query_md(
+                query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
+    }
+
+    /// Returns diff weights projection memory descriptor.
+    /// @returns Diff weights projection memory descriptor.
+    memory::desc diff_weights_projection_desc() const {
+        return base::query_md(
+                query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
+    }
+
+    /// Returns diff bias memory descriptor.
+    /// @returns Diff bias memory descriptor.
+    /// @returns A zero memory descriptor if the primitive does not have a
+    ///     diff bias parameter.
+    memory::desc diff_bias_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_BIAS);
+    }
+
+    /// Returns diff destination layer memory descriptor.
+    /// @returns Diff destination layer memory descriptor.
+    memory::desc diff_dst_layer_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_LAYER);
+    }
+
+    /// Returns diff destination iteration memory descriptor.
+    /// @returns Diff destination iteration memory descriptor.
+    /// @returns A zero memory descriptor if the primitive does not have a
+    ///     diff destination iteration parameter.
+    memory::desc diff_dst_iter_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER);
+    }
+
+    /// Returns diff destination recurrent cell state memory descriptor.
+    /// @returns Diff destination recurrent cell state memory descriptor.
+    memory::desc diff_dst_iter_c_desc() const {
+        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER_C);
+    }
+
+protected:
+    using rnn_base = rnn_primitive_desc_base;
+
+    // (Deliberately not using doxygen comments)
+    //
+    // Constructs an RNN primitive descriptor base from a C API primitive
+    // descriptor while checking that it actually describes the expected
+    // primitive by comparing propagation and primitive kinds. Caller can
+    // pass two options propagation kinds. This is typically used to check
+    // that propagation kind is inference or training forward propagation.
+    //
+    // @param pd C API primitive descriptor.
+    // @param prop_kind1 Expected propagation kind.
+    // @param prop_kind2 Expected propagation kind.
+    // @param cell_kind Expected cell kind.
+    rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
+            dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
+            dnnl::algorithm cell_kind) {
+
+        dnnl_status_t rc;
+
+        // Query the C descriptor for its primitive, propagation, and cell
+        // kinds so they can be validated against the caller's expectations.
+        dnnl_primitive_kind_t q_primitive_kind;
+        rc = dnnl_primitive_desc_query(
+                pd, dnnl_query_primitive_kind, 0, &q_primitive_kind);
+        error::wrap_c_api(rc,
+                "could not retrieve a primitive kind from a primitive "
+                "descriptor for an RNN primitive");
+
+        dnnl_prop_kind_t q_prop_kind;
+        rc = dnnl_primitive_desc_query(
+                pd, dnnl_query_prop_kind, 0, &q_prop_kind);
+        error::wrap_c_api(rc,
+                "could not retrieve a propagation kind from a primitive "
+                "descriptor for an RNN primitive");
+
+        dnnl_alg_kind_t q_cell_kind;
+        rc = dnnl_primitive_desc_query(
+                pd, dnnl_query_cell_kind, 0, &q_cell_kind);
+        error::wrap_c_api(rc,
+                "could not retrieve a cell kind from a primitive descriptor "
+                "for an RNN primitive");
+
+        dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
+        dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
+        dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
+
+        // Either of the two expected propagation kinds is accepted.
+        bool ok = q_primitive_kind == dnnl_rnn
+                && (q_prop_kind == c_prop_kind1 || q_prop_kind == c_prop_kind2)
+                && q_cell_kind == c_cell_kind;
+
+        if (!ok)
+            DNNL_THROW_ERROR(dnnl_invalid_arguments,
+                    "mismatch between expected and provided descriptors for an "
+                    "RNN primitive");
+
+        reset_with_clone(pd);
+    }
+
+    // Constructs an RNN forward propagation primitive descriptor base for
+    // any cell kind.
+    // Pointer-typed descriptors (src_iter_c, attention, peephole,
+    // projection, dst_iter_c) are optional and forwarded through
+    // optional_arg(); unsupported cell kinds yield dnnl_unimplemented.
+    rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
+            prop_kind aprop_kind, algorithm activation, rnn_direction direction,
+            const memory::desc &src_layer_desc,
+            const memory::desc &src_iter_desc,
+            const memory::desc *src_iter_c_desc,
+            const memory::desc *attention_desc,
+            const memory::desc &weights_layer_desc,
+            const memory::desc &weights_iter_desc,
+            const memory::desc *weights_peephole_desc,
+            const memory::desc *weights_projection_desc,
+            const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
+            const memory::desc &dst_iter_desc,
+            const memory::desc *dst_iter_c_desc, rnn_flags flags, float alpha,
+            float beta, const primitive_attr &attr, bool allow_empty) {
+
+        dnnl_status_t status = dnnl_success;
+        const char *msg
+                = "could not create a primitive descriptor for a requested "
+                  "cell kind";
+
+        dnnl_primitive_desc_t pd = nullptr;
+        switch (cell_kind) {
+            case algorithm::vanilla_rnn:
+                status = dnnl_vanilla_rnn_forward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(activation),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), weights_layer_desc.get(),
+                        weights_iter_desc.get(), bias_desc.get(),
+                        dst_layer_desc.get(), dst_iter_desc.get(),
+                        convert_to_c(flags), alpha, beta, attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the vanilla RNN forward propagation primitive. Run "
+                      "workload with environment variable ONEDNN_VERBOSE=all "
+                      "to get additional diagnostic information.";
+                break;
+            case algorithm::vanilla_lstm:
+                status = dnnl_lstm_forward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), optional_arg(src_iter_c_desc),
+                        weights_layer_desc.get(), weights_iter_desc.get(),
+                        optional_arg(weights_peephole_desc),
+                        optional_arg(weights_projection_desc), bias_desc.get(),
+                        dst_layer_desc.get(), dst_iter_desc.get(),
+                        optional_arg(dst_iter_c_desc), convert_to_c(flags),
+                        attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the LSTM forward propagation primitive. Run workload "
+                      "with environment variable ONEDNN_VERBOSE=all to get "
+                      "additional diagnostic information.";
+                break;
+            case algorithm::vanilla_gru:
+                status = dnnl_gru_forward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), weights_layer_desc.get(),
+                        weights_iter_desc.get(), bias_desc.get(),
+                        dst_layer_desc.get(), dst_iter_desc.get(),
+                        convert_to_c(flags), attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the GRU forward propagation primitive. Run workload "
+                      "with environment variable ONEDNN_VERBOSE=all to get "
+                      "additional diagnostic information.";
+                break;
+            case algorithm::lbr_gru:
+                status = dnnl_lbr_gru_forward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), weights_layer_desc.get(),
+                        weights_iter_desc.get(), bias_desc.get(),
+                        dst_layer_desc.get(), dst_iter_desc.get(),
+                        convert_to_c(flags), attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the LBR GRU forward propagation primitive. Run workload "
+                      "with environment variable ONEDNN_VERBOSE=all to get "
+                      "additional diagnostic information.";
+                break;
+            case algorithm::vanilla_augru:
+                status = dnnl_augru_forward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), optional_arg(attention_desc),
+                        weights_layer_desc.get(), weights_iter_desc.get(),
+                        bias_desc.get(), dst_layer_desc.get(),
+                        dst_iter_desc.get(), convert_to_c(flags), attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the AUGRU forward propagation primitive. Run workload "
+                      "with environment variable ONEDNN_VERBOSE=all to get "
+                      "additional diagnostic information.";
+                break;
+            case algorithm::lbr_augru:
+                status = dnnl_lbr_augru_forward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), optional_arg(attention_desc),
+                        weights_layer_desc.get(), weights_iter_desc.get(),
+                        bias_desc.get(), dst_layer_desc.get(),
+                        dst_iter_desc.get(), convert_to_c(flags), attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the LBR AUGRU forward propagation primitive. Run "
+                      "workload with environment variable ONEDNN_VERBOSE=all "
+                      "to get additional diagnostic information.";
+                break;
+            default: status = dnnl_unimplemented;
+        }
+
+        if (!allow_empty) error::wrap_c_api(status, msg);
+        reset(pd);
+    }
+
+    // Constructs an RNN backward propagation primitive descriptor base for
+    // any cell kind.
+    // Like the forward factory above, but additionally wires the diff
+    // (gradient) descriptors and the forward hint descriptor into the
+    // cell-kind-specific C creation call.
+    rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
+            prop_kind aprop_kind, algorithm activation, rnn_direction direction,
+            const memory::desc &src_layer_desc,
+            const memory::desc &src_iter_desc,
+            const memory::desc *src_iter_c_desc,
+            const memory::desc *attention_desc,
+            const memory::desc &weights_layer_desc,
+            const memory::desc &weights_iter_desc,
+            const memory::desc *weights_peephole_desc,
+            const memory::desc *weights_projection_desc,
+            const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
+            const memory::desc &dst_iter_desc,
+            const memory::desc *dst_iter_c_desc,
+            const memory::desc &diff_src_layer_desc,
+            const memory::desc &diff_src_iter_desc,
+            const memory::desc *diff_src_iter_c_desc,
+            const memory::desc *diff_attention_desc,
+            const memory::desc &diff_weights_layer_desc,
+            const memory::desc &diff_weights_iter_desc,
+            const memory::desc *diff_weights_peephole_desc,
+            const memory::desc *diff_weights_projection_desc,
+            const memory::desc &diff_bias_desc,
+            const memory::desc &diff_dst_layer_desc,
+            const memory::desc &diff_dst_iter_desc,
+            const memory::desc *diff_dst_iter_c_desc, rnn_flags flags,
+            float alpha, float beta, const rnn_primitive_desc_base &hint_fwd_pd,
+            const primitive_attr &attr, bool allow_empty) {
+
+        dnnl_status_t status = dnnl_success;
+        const char *msg = "";
+
+        dnnl_primitive_desc_t pd = nullptr;
+        switch (cell_kind) {
+            case algorithm::vanilla_rnn:
+                status = dnnl_vanilla_rnn_backward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(activation),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), weights_layer_desc.get(),
+                        weights_iter_desc.get(), bias_desc.get(),
+                        dst_layer_desc.get(), dst_iter_desc.get(),
+                        diff_src_layer_desc.get(), diff_src_iter_desc.get(),
+                        diff_weights_layer_desc.get(),
+                        diff_weights_iter_desc.get(), diff_bias_desc.get(),
+                        diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
+                        convert_to_c(flags), alpha, beta, hint_fwd_pd.get(),
+                        attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the vanilla RNN backward propagation primitive. Run "
+                      "workload with environment variable ONEDNN_VERBOSE=all "
+                      "to get additional diagnostic information.";
+                break;
+            case algorithm::vanilla_lstm:
+                status = dnnl_lstm_backward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), optional_arg(src_iter_c_desc),
+                        weights_layer_desc.get(), weights_iter_desc.get(),
+                        optional_arg(weights_peephole_desc),
+                        optional_arg(weights_projection_desc), bias_desc.get(),
+                        dst_layer_desc.get(), dst_iter_desc.get(),
+                        optional_arg(dst_iter_c_desc),
+                        diff_src_layer_desc.get(), diff_src_iter_desc.get(),
+                        optional_arg(diff_src_iter_c_desc),
+                        diff_weights_layer_desc.get(),
+                        diff_weights_iter_desc.get(),
+                        optional_arg(diff_weights_peephole_desc),
+                        optional_arg(diff_weights_projection_desc),
+                        diff_bias_desc.get(), diff_dst_layer_desc.get(),
+                        diff_dst_iter_desc.get(),
+                        optional_arg(diff_dst_iter_c_desc), convert_to_c(flags),
+                        hint_fwd_pd.get(), attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the LSTM backward propagation primitive. Run workload "
+                      "with environment variable ONEDNN_VERBOSE=all to get "
+                      "additional diagnostic information.";
+                break;
+            case algorithm::vanilla_gru:
+                status = dnnl_gru_backward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), weights_layer_desc.get(),
+                        weights_iter_desc.get(), bias_desc.get(),
+                        dst_layer_desc.get(), dst_iter_desc.get(),
+                        diff_src_layer_desc.get(), diff_src_iter_desc.get(),
+                        diff_weights_layer_desc.get(),
+                        diff_weights_iter_desc.get(), diff_bias_desc.get(),
+                        diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
+                        convert_to_c(flags), hint_fwd_pd.get(), attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the GRU backward propagation primitive. Run workload "
+                      "with environment variable ONEDNN_VERBOSE=all to get "
+                      "additional diagnostic information.";
+                break;
+            case algorithm::lbr_gru:
+                status = dnnl_lbr_gru_backward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), weights_layer_desc.get(),
+                        weights_iter_desc.get(), bias_desc.get(),
+                        dst_layer_desc.get(), dst_iter_desc.get(),
+                        diff_src_layer_desc.get(), diff_src_iter_desc.get(),
+                        diff_weights_layer_desc.get(),
+                        diff_weights_iter_desc.get(), diff_bias_desc.get(),
+                        diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
+                        convert_to_c(flags), hint_fwd_pd.get(), attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the LBR GRU backward propagation primitive. Run "
+                      "workload with environment variable ONEDNN_VERBOSE=all "
+                      "to get additional diagnostic information.";
+                break;
+            case algorithm::vanilla_augru:
+                status = dnnl_augru_backward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), optional_arg(attention_desc),
+                        weights_layer_desc.get(), weights_iter_desc.get(),
+                        bias_desc.get(), dst_layer_desc.get(),
+                        dst_iter_desc.get(), diff_src_layer_desc.get(),
+                        diff_src_iter_desc.get(),
+                        optional_arg(diff_attention_desc),
+                        diff_weights_layer_desc.get(),
+                        diff_weights_iter_desc.get(), diff_bias_desc.get(),
+                        diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
+                        convert_to_c(flags), hint_fwd_pd.get(), attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the AUGRU backward propagation primitive. Run workload "
+                      "with environment variable ONEDNN_VERBOSE=all to get "
+                      "additional diagnostic information.";
+                break;
+            case algorithm::lbr_augru:
+                status = dnnl_lbr_augru_backward_primitive_desc_create(&pd,
+                        aengine.get(), dnnl::convert_to_c(aprop_kind),
+                        dnnl::convert_to_c(direction), src_layer_desc.get(),
+                        src_iter_desc.get(), optional_arg(attention_desc),
+                        weights_layer_desc.get(), weights_iter_desc.get(),
+                        bias_desc.get(), dst_layer_desc.get(),
+                        dst_iter_desc.get(), diff_src_layer_desc.get(),
+                        diff_src_iter_desc.get(),
+                        optional_arg(diff_attention_desc),
+                        diff_weights_layer_desc.get(),
+                        diff_weights_iter_desc.get(), diff_bias_desc.get(),
+                        diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
+                        convert_to_c(flags), hint_fwd_pd.get(), attr.get());
+                msg = "could not create a primitive descriptor for "
+                      "the LBR AUGRU backward propagation primitive. Run "
+                      "workload with environment variable ONEDNN_VERBOSE=all "
+                      "to get additional diagnostic information.";
+                break;
+            default: status = dnnl_unimplemented;
+        }
+        if (!allow_empty) error::wrap_c_api(status, msg);
+        reset(pd);
+    }
+};
+
+/// Vanilla RNN forward propagation primitive.
+struct vanilla_rnn_forward : public primitive {
+    /// Primitive descriptor for a vanilla RNN forward propagation primitive.
+    struct primitive_desc : public rnn_primitive_desc_base {
+        /// Default constructor. Produces an empty object.
+        primitive_desc() = default;
+
+        /// Constructs a primitive descriptor for a vanilla RNN forward
+        /// propagation primitive.
+        ///
+        /// The following arguments may point to a zero memory descriptor:
+        /// - @p src_iter_desc,
+        /// - @p bias_desc,
+        /// - @p dst_iter_desc.
+        ///
+        /// This would then indicate that the RNN forward propagation primitive
+        /// should not use them and should default to zero values instead.
+        ///
+        /// @note
+        ///     All memory descriptors except @p src_iter_desc can be
+        ///     initialized with an #dnnl::memory::format_tag::any value of @p
+        ///     format_tag.
+        ///
+        /// @param aengine Engine to use.
+        /// @param aprop_kind Propagation kind. Possible values are
+        ///     #dnnl::prop_kind::forward_training, and
+        ///     #dnnl::prop_kind::forward_inference.
+        /// @param activation Activation kind. Possible values are
+        ///     #dnnl::algorithm::eltwise_relu,
+        ///     #dnnl::algorithm::eltwise_tanh, or
+        ///     #dnnl::algorithm::eltwise_logistic.
+        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
+        ///     more info.
+        /// @param src_layer_desc Memory descriptor for the input vector.
+        /// @param src_iter_desc Memory descriptor for the input recurrent
+        ///     hidden state vector.
+        /// @param weights_layer_desc Memory descriptor for the weights
+        ///     applied to the layer input.
+        /// @param weights_iter_desc Memory descriptor for the weights applied
+        ///     to the recurrent input.
+        /// @param bias_desc Bias memory descriptor.
+        /// @param dst_layer_desc Memory descriptor for the output vector.
+        /// @param dst_iter_desc Memory descriptor for the output recurrent
+        ///     hidden state vector.
+        /// @param attr Primitive attributes to use. Attributes are optional
+        ///     and default to empty attributes.
+        /// @param allow_empty A flag signifying whether construction is
+        ///     allowed to fail without throwing an exception. In this case an
+        ///     empty object will be produced. This flag is optional and
+        ///     defaults to false.
+        // alpha and beta are fixed at 0 here; the alpha overload exists for
+        // the eltwise_relu negative-slope case.
+        primitive_desc(const engine &aengine, prop_kind aprop_kind,
+                algorithm activation, rnn_direction direction,
+                const memory::desc &src_layer_desc,
+                const memory::desc &src_iter_desc,
+                const memory::desc &weights_layer_desc,
+                const memory::desc &weights_iter_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_layer_desc,
+                const memory::desc &dst_iter_desc,
+                const primitive_attr &attr = default_attr(),
+                bool allow_empty = false)
+            : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
+                    aprop_kind, activation, direction, src_layer_desc,
+                    src_iter_desc, nullptr, nullptr, weights_layer_desc,
+                    weights_iter_desc, nullptr, nullptr, bias_desc,
+                    dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
+                    0.0f, 0.0f, attr, allow_empty) {}
+
+        /// Constructs a primitive descriptor for a vanilla RNN forward
+        /// propagation primitive with alpha parameter.
+        ///
+        /// The following arguments may point to a zero memory descriptor:
+        /// - @p src_iter_desc,
+        /// - @p bias_desc,
+        /// - @p dst_iter_desc.
+        ///
+        /// This would then indicate that the RNN forward propagation primitive
+        /// should not use them and should default to zero values instead.
+        ///
+        /// @note
+        ///     All memory descriptors except @p src_iter_desc can be
+        ///     initialized with an #dnnl::memory::format_tag::any value of @p
+        ///     format_tag.
+        ///
+        /// @param aengine Engine to use.
+        /// @param aprop_kind Propagation kind. Possible values are
+        ///     #dnnl::prop_kind::forward_training, and
+        ///     #dnnl::prop_kind::forward_inference.
+        /// @param activation Activation kind. Possible values are
+        ///     #dnnl::algorithm::eltwise_relu,
+        ///     #dnnl::algorithm::eltwise_tanh, or
+        ///     #dnnl::algorithm::eltwise_logistic.
+        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
+        ///     more info.
+        /// @param src_layer_desc Memory descriptor for the input vector.
+        /// @param src_iter_desc Memory descriptor for the input recurrent
+        ///     hidden state vector.
+        /// @param weights_layer_desc Memory descriptor for the weights
+        ///     applied to the layer input.
+        /// @param weights_iter_desc Memory descriptor for the weights applied
+        ///     to the recurrent input.
+        /// @param bias_desc Bias memory descriptor.
+        /// @param dst_layer_desc Memory descriptor for the output vector.
+        /// @param dst_iter_desc Memory descriptor for the output recurrent
+        ///     hidden state vector.
+        /// @param alpha Negative slope if activation is
+        ///     #dnnl::algorithm::eltwise_relu.
+        /// @param attr Primitive attributes to use. Attributes are optional
+        ///     and default to empty attributes.
+        /// @param allow_empty A flag signifying whether construction is
+        ///     allowed to fail without throwing an exception. In this case an
+        ///     empty object will be produced. This flag is optional and
+        ///     defaults to false.
+        primitive_desc(const engine &aengine, prop_kind aprop_kind,
+                algorithm activation, rnn_direction direction,
+                const memory::desc &src_layer_desc,
+                const memory::desc &src_iter_desc,
+                const memory::desc &weights_layer_desc,
+                const memory::desc &weights_iter_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_layer_desc,
+                const memory::desc &dst_iter_desc, float alpha,
+                const primitive_attr &attr = default_attr(),
+                bool allow_empty = false)
+            : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
+                    aprop_kind, activation, direction, src_layer_desc,
+                    src_iter_desc, nullptr, nullptr, weights_layer_desc,
+                    weights_iter_desc, nullptr, nullptr, bias_desc,
+                    dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
+                    alpha, 0.0f, attr, allow_empty) {}
+
+        /// Constructs a primitive descriptor for a vanilla RNN forward
+        /// propagation primitive from a C API primitive descriptor that must
+        /// have a matching kind.
+        ///
+        /// @param pd C API primitive descriptor for a vanilla RNN forward
+        ///     propagation primitive.
+        // Accepts either forward_training or forward_inference.
+        primitive_desc(dnnl_primitive_desc_t pd)
+            : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
+                    dnnl::prop_kind::forward_inference,
+                    dnnl::algorithm::vanilla_rnn) {}
+
+        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
+        memory::desc src_layer_desc() const {
+            return rnn_base::src_layer_desc();
+        }
+
+        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
+        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
+
+        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
+        memory::desc weights_layer_desc() const {
+            return rnn_base::weights_layer_desc();
+        }
+
+        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
+        memory::desc weights_iter_desc() const {
+            return rnn_base::weights_iter_desc();
+        }
+
+        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
+        memory::desc bias_desc() const { return rnn_base::bias_desc(); }
+
+        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
+        memory::desc dst_layer_desc() const {
+            return rnn_base::dst_layer_desc();
+        }
+
+        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
+        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
+
+        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
+        memory::desc workspace_desc() const {
+            return rnn_base::workspace_desc();
+        }
+
+        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
+        algorithm get_cell_kind() const { return base::get_cell_kind(); }
+
+        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
+        prop_kind get_prop_kind() const { return base::get_prop_kind(); }
+
+        /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
+        algorithm get_activation_kind() const {
+            return base::get_activation_kind();
+        }
+
+        /// @copydoc dnnl::primitive_desc_base::get_direction()const
+        rnn_direction get_direction() const { return base::get_direction(); }
+
+        /// @copydoc dnnl::primitive_desc_base::get_alpha()const
+        float get_alpha() const { return base::get_alpha(); }
+
+        /// @copydoc dnnl::primitive_desc_base::get_beta()const
+        float get_beta() const { return base::get_beta(); }
+    };
+
+    /// Default constructor. Produces an empty object.
+    vanilla_rnn_forward() = default;
+
+    /// Constructs a vanilla RNN forward propagation primitive.
+    /// @param pd Primitive descriptor for a vanilla RNN forward
+    ///     propagation primitive.
+    vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
+
+    /// Constructs a vanilla RNN forward propagation primitive from
+    /// a cache blob.
+    /// @param pd Primitive descriptor for a vanilla RNN forward
+    ///     propagation primitive.
+    /// @param cache_blob Cache blob.
+    vanilla_rnn_forward(
+            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
+        : primitive(pd, cache_blob) {}
+};
+
+/// Vanilla RNN backward propagation primitive.
+struct vanilla_rnn_backward : public primitive {
+    /// Primitive descriptor for an RNN backward propagation primitive.
+    struct primitive_desc : public rnn_primitive_desc_base {
+        /// Default constructor. Produces an empty object.
+        primitive_desc() = default;
+
+        /// Constructs a primitive descriptor for a vanilla RNN backward
+        /// propagation primitive.
+        ///
+        /// The following arguments may point to a zero memory descriptor:
+        /// - @p src_iter_desc together with @p diff_src_iter_desc,
+        /// - @p bias_desc together with @p diff_bias_desc,
+        /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
+        ///
+        /// This would then indicate that the RNN backward propagation
+        /// primitive should not use the respective data and should use zero
+        /// values instead.
+        ///
+        /// @note
+        ///     All the memory descriptors may be initialized with the
+        ///     #dnnl::memory::format_tag::any value of @p format_tag.
+        ///
+        /// @param aengine Engine to use.
+        /// @param aprop_kind Propagation kind. Must be
+        ///     #dnnl::prop_kind::backward.
+        /// @param activation Activation kind.
Possible values are + /// #dnnl::algorithm::eltwise_relu, + /// #dnnl::algorithm::eltwise_tanh, or + /// #dnnl::algorithm::eltwise_logistic. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param diff_src_layer_desc Memory descriptor for the diff of input + /// vector. + /// @param diff_src_iter_desc Memory descriptor for the diff of input + /// recurrent hidden state vector. + /// @param diff_weights_layer_desc Memory descriptor for the diff of + /// weights applied to the layer input. + /// @param diff_weights_iter_desc Memory descriptor for the diff of + /// weights applied to the recurrent input. + /// @param diff_bias_desc Diff bias memory descriptor. + /// @param diff_dst_layer_desc Memory descriptor for the diff of + /// output vector. + /// @param diff_dst_iter_desc Memory descriptor for the diff of output + /// recurrent hidden state vector. + /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm activation, rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc, + const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn, + aprop_kind, activation, direction, src_layer_desc, + src_iter_desc, nullptr, nullptr, weights_layer_desc, + weights_iter_desc, nullptr, nullptr, bias_desc, + dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, + diff_src_iter_desc, nullptr, nullptr, + diff_weights_layer_desc, diff_weights_iter_desc, nullptr, + nullptr, diff_bias_desc, diff_dst_layer_desc, + diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a vanilla RNN backward + /// propagation primitive with an alpha parameter. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p diff_src_iter_desc, + /// - @p bias_desc together with @p diff_bias_desc, + /// - @p dst_iter_desc together with @p diff_dst_iter_desc. + /// + /// This would then indicate that the RNN backward propagation + /// primitive should not use the respective data and should use zero + /// values instead. 
+ /// + /// @note + /// All the memory descriptors may be initialized with the + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Must be + /// #dnnl::prop_kind::backward. + /// @param activation Activation kind. Possible values are + /// #dnnl::algorithm::eltwise_relu, + /// #dnnl::algorithm::eltwise_tanh, or + /// #dnnl::algorithm::eltwise_logistic. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param diff_src_layer_desc Memory descriptor for the diff of input + /// vector. + /// @param diff_src_iter_desc Memory descriptor for the diff of input + /// recurrent hidden state vector. + /// @param diff_weights_layer_desc Memory descriptor for the diff of + /// weights applied to the layer input. + /// @param diff_weights_iter_desc Memory descriptor for the diff of + /// weights applied to the recurrent input. + /// @param diff_bias_desc Diff bias memory descriptor. + /// @param diff_dst_layer_desc Memory descriptor for the diff of + /// output vector. + /// @param diff_dst_iter_desc Memory descriptor for the diff of output + /// recurrent hidden state vector. + /// @param alpha Negative slope if activation is + /// #dnnl::algorithm::eltwise_relu. + /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN + /// forward propagation primitive. 
It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm activation, rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc, float alpha, + const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn, + aprop_kind, activation, direction, src_layer_desc, + src_iter_desc, nullptr, nullptr, weights_layer_desc, + weights_iter_desc, nullptr, nullptr, bias_desc, + dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, + diff_src_iter_desc, nullptr, nullptr, + diff_weights_layer_desc, diff_weights_iter_desc, nullptr, + nullptr, diff_bias_desc, diff_dst_layer_desc, + diff_dst_iter_desc, nullptr, rnn_flags::undef, alpha, 0.0f, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a vanilla RNN backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. 
+ /// + /// @param pd C API primitive descriptor for a vanilla RNN backward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, + dnnl::algorithm::vanilla_rnn) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const + memory::desc diff_src_layer_desc() const { + return rnn_base::diff_src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const + memory::desc diff_src_iter_desc() const { + return rnn_base::diff_src_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const + memory::desc diff_weights_layer_desc() const { + return rnn_base::diff_weights_layer_desc(); + } + + /// @copydoc 
dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const + memory::desc diff_weights_iter_desc() const { + return rnn_base::diff_weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const + memory::desc diff_bias_desc() const { + return rnn_base::diff_bias_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const + memory::desc diff_dst_layer_desc() const { + return rnn_base::diff_dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const + memory::desc diff_dst_iter_desc() const { + return rnn_base::diff_dst_iter_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const + algorithm get_activation_kind() const { + return base::get_activation_kind(); + } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + + /// @copydoc dnnl::primitive_desc_base::get_alpha()const + float get_alpha() const { return base::get_alpha(); } + + /// @copydoc dnnl::primitive_desc_base::get_beta()const + float get_beta() const { return base::get_beta(); } + }; + + /// Default constructor. Produces an empty object. + vanilla_rnn_backward() = default; + + /// Constructs a vanilla RNN backward propagation primitive. + /// @param pd Primitive descriptor for a vanilla RNN backward + /// propagation primitive. + vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a vanilla RNN backward propagation primitive from + /// a cache blob. + /// @param pd Primitive descriptor for a vanilla RNN backward + /// propagation primitive. + /// @param cache_blob Cache blob. 
+ vanilla_rnn_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// LSTM forward propagation primitive. +struct lstm_forward : public primitive { + /// Primitive descriptor for an LSTM forward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an LSTM (with or without + /// peephole and with or without projection) forward propagation + /// primitive. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p src_iter_c_desc, + /// - @p weights_peephole_desc, + /// - @p bias_desc, + /// - @p dst_iter_desc together with @p dst_iter_c_desc. + /// + /// This would then indicate that the LSTM forward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// The @p weights_projection_desc may point to a zero memory + /// descriptor. This would then indicate that the LSTM doesn't have + /// recurrent projection layer. + /// + /// @note + /// All memory descriptors can be initialized with an + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param src_iter_c_desc Memory descriptor for the input recurrent + /// cell state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. 
+ /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param weights_peephole_desc Memory descriptor for the weights + /// applied to the cell states (according to the Peephole LSTM + /// formula). + /// @param weights_projection_desc Memory descriptor for the weights + /// applied to the hidden states to get the recurrent projection + /// (according to the Projection LSTM formula). + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param dst_iter_c_desc Memory descriptor for the output recurrent + /// cell state vector. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &src_iter_c_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &weights_peephole_desc, + const memory::desc &weights_projection_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &dst_iter_c_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, &src_iter_c_desc, nullptr, + weights_layer_desc, weights_iter_desc, + &weights_peephole_desc, &weights_projection_desc, bias_desc, + dst_layer_desc, dst_iter_desc, &dst_iter_c_desc, + rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an LSTM (with or without + /// peephole) forward propagation primitive. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p src_iter_c_desc, + /// - @p weights_peephole_desc, + /// - @p bias_desc, + /// - @p dst_iter_desc together with @p dst_iter_c_desc. + /// + /// This would then indicate that the LSTM forward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors can be initialized with an + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. 
+ /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param src_iter_c_desc Memory descriptor for the input recurrent + /// cell state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param weights_peephole_desc Memory descriptor for the weights + /// applied to the cell states (according to the Peephole LSTM + /// formula). + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param dst_iter_c_desc Memory descriptor for the output recurrent + /// cell state vector. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &src_iter_c_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &weights_peephole_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &dst_iter_c_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, &src_iter_c_desc, nullptr, + weights_layer_desc, weights_iter_desc, + &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc, + dst_iter_desc, &dst_iter_c_desc, rnn_flags::undef, 0.0f, + 0.0f, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an LSTM forward propagation + /// primitive. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p src_iter_c_desc, + /// - @p bias_desc, + /// - @p dst_iter_desc together with @p dst_iter_c_desc. + /// + /// This would then indicate that the LSTM forward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors can be initialized with an + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. 
+ /// @param src_iter_c_desc Memory descriptor for the input recurrent + /// cell state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param dst_iter_c_desc Memory descriptor for the output recurrent + /// cell state vector. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &src_iter_c_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &dst_iter_c_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, &src_iter_c_desc, nullptr, + weights_layer_desc, weights_iter_desc, nullptr, nullptr, + bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc, + rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an LSTM forward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. 
+ /// + /// @param pd C API primitive descriptor for an LSTM forward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference, + dnnl::algorithm::vanilla_lstm) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_c_desc() const { + return rnn_base::src_iter_c_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const + memory::desc weights_peephole_desc() const { + return rnn_base::weights_peephole_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const + memory::desc weights_projection_desc() const { + return rnn_base::weights_projection_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc dst_iter_c_desc() const { + return rnn_base::dst_iter_c_desc(); + } + + /// @copydoc 
dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. Produces an empty object. + lstm_forward() = default; + + /// Constructs an LSTM forward propagation primitive. + /// @param pd Primitive descriptor for an LSTM forward propagation + /// primitive. + lstm_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an LSTM forward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an LSTM forward propagation + /// primitive. + /// @param cache_blob Cache blob. + lstm_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// LSTM backward propagation primitive. +struct lstm_backward : public primitive { + /// Primitive descriptor for an LSTM backward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs an LSTM (with or without peephole and with or without + /// projection) primitive descriptor for backward propagation + /// using @p prop_kind, @p direction, and memory descriptors. 
+ /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p src_iter_c_desc, + /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc, + /// - @p weights_peephole_desc together with + /// @p diff_weights_peephole_desc + /// - @p bias_desc together with @p diff_bias_desc, + /// - @p dst_iter_desc together with @p dst_iter_c_desc, + /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc. + /// + /// This would then indicate that the LSTM backward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// The @p weights_projection_desc together with @p + /// diff_weights_projection_desc may point to a zero memory descriptor. + /// This would then indicate that the LSTM doesn't have recurrent + /// projection layer. + /// + /// @note + /// All memory descriptors can be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Must be + /// #dnnl::prop_kind::backward. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param src_iter_c_desc Memory descriptor for the input recurrent + /// cell state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param weights_peephole_desc Memory descriptor for the weights + /// applied to the cell states (according to the Peephole LSTM + /// formula). + /// @param weights_projection_desc Memory descriptor for the weights + /// applied to the hidden states to get the recurrent projection + /// (according to the Projection LSTM formula). 
+ /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param dst_iter_c_desc Memory descriptor for the output recurrent + /// cell state vector. + /// @param diff_src_layer_desc Memory descriptor for the diff of input + /// vector. + /// @param diff_src_iter_desc Memory descriptor for the diff of input + /// recurrent hidden state vector. + /// @param diff_src_iter_c_desc Memory descriptor for the diff of + /// input recurrent cell state vector. + /// @param diff_weights_layer_desc Memory descriptor for the diff of + /// weights applied to the layer input. + /// @param diff_weights_iter_desc Memory descriptor for the diff of + /// weights applied to the recurrent input. + /// @param diff_weights_peephole_desc Memory descriptor for the diff of + /// weights applied to the cell states (according to the Peephole + /// LSTM formula). + /// @param diff_weights_projection_desc Memory descriptor for the diff + /// of weights applied to the hidden states to get the recurrent + /// projection (according to the Projection LSTM formula). + /// @param diff_bias_desc Diff bias memory descriptor. + /// @param diff_dst_layer_desc Memory descriptor for the diff of + /// output vector. + /// @param diff_dst_iter_desc Memory descriptor for the diff of output + /// recurrent hidden state vector. + /// @param diff_dst_iter_c_desc Memory descriptor for the diff of + /// output recurrent cell state vector. + /// @param hint_fwd_pd Primitive descriptor for an LSTM + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. 
In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &src_iter_c_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &weights_peephole_desc, + const memory::desc &weights_projection_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &dst_iter_c_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_src_iter_c_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_weights_peephole_desc, + const memory::desc &diff_weights_projection_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc, + const memory::desc &diff_dst_iter_c_desc, + const lstm_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, &src_iter_c_desc, nullptr, + weights_layer_desc, weights_iter_desc, + &weights_peephole_desc, &weights_projection_desc, bias_desc, + dst_layer_desc, dst_iter_desc, &dst_iter_c_desc, + diff_src_layer_desc, diff_src_iter_desc, + &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc, + diff_weights_iter_desc, &diff_weights_peephole_desc, + &diff_weights_projection_desc, diff_bias_desc, + diff_dst_layer_desc, diff_dst_iter_desc, + &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs an LSTM (with or without peephole) primitive descriptor + /// for backward 
propagation using @p prop_kind, @p direction, + /// and memory descriptors. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p src_iter_c_desc, + /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc, + /// - @p weights_peephole_desc together with + /// @p diff_weights_peephole_desc + /// - @p bias_desc together with @p diff_bias_desc, + /// - @p dst_iter_desc together with @p dst_iter_c_desc, + /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc. + /// + /// This would then indicate that the LSTM backward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors may be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Must be + /// #dnnl::prop_kind::backward. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param src_iter_c_desc Memory descriptor for the input recurrent + /// cell state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param weights_peephole_desc Memory descriptor for the weights + /// applied to the cell states (according to the Peephole LSTM + /// formula). + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param dst_iter_c_desc Memory descriptor for the output recurrent + /// cell state vector. 
+ /// @param diff_src_layer_desc Memory descriptor for the diff of input + /// vector. + /// @param diff_src_iter_desc Memory descriptor for the diff of input + /// recurrent hidden state vector. + /// @param diff_src_iter_c_desc Memory descriptor for the diff of + /// input recurrent cell state vector. + /// @param diff_weights_layer_desc Memory descriptor for the diff of + /// weights applied to the layer input. + /// @param diff_weights_iter_desc Memory descriptor for the diff of + /// weights applied to the recurrent input. + /// @param diff_weights_peephole_desc Memory descriptor for the diff of + /// weights applied to the cell states (according to the Peephole + /// LSTM formula). + /// @param diff_bias_desc Diff bias memory descriptor. + /// @param diff_dst_layer_desc Memory descriptor for the diff of + /// output vector. + /// @param diff_dst_iter_desc Memory descriptor for the diff of output + /// recurrent hidden state vector. + /// @param diff_dst_iter_c_desc Memory descriptor for the diff of + /// output recurrent cell state vector. + /// @param hint_fwd_pd Primitive descriptor for an LSTM + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &src_iter_c_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &weights_peephole_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &dst_iter_c_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_src_iter_c_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_weights_peephole_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc, + const memory::desc &diff_dst_iter_c_desc, + const lstm_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, &src_iter_c_desc, nullptr, + weights_layer_desc, weights_iter_desc, + &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc, + dst_iter_desc, &dst_iter_c_desc, diff_src_layer_desc, + diff_src_iter_desc, &diff_src_iter_c_desc, nullptr, + diff_weights_layer_desc, diff_weights_iter_desc, + &diff_weights_peephole_desc, nullptr, diff_bias_desc, + diff_dst_layer_desc, diff_dst_iter_desc, + &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs an LSTM primitive descriptor for backward propagation + /// using @p prop_kind, @p direction, and memory descriptors. 
+ /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p src_iter_c_desc, + /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc, + /// - @p bias_desc together with @p diff_bias_desc, + /// - @p dst_iter_desc together with @p dst_iter_c_desc, + /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc. + /// + /// This would then indicate that the LSTM backward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors may be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Must be + /// #dnnl::prop_kind::backward. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param src_iter_c_desc Memory descriptor for the input recurrent + /// cell state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param dst_iter_c_desc Memory descriptor for the output recurrent + /// cell state vector. + /// @param diff_src_layer_desc Memory descriptor for the diff of input + /// vector. + /// @param diff_src_iter_desc Memory descriptor for the diff of input + /// recurrent hidden state vector. + /// @param diff_src_iter_c_desc Memory descriptor for the diff of + /// input recurrent cell state vector. 
+ /// @param diff_weights_layer_desc Memory descriptor for the diff of + /// weights applied to the layer input. + /// @param diff_weights_iter_desc Memory descriptor for the diff of + /// weights applied to the recurrent input. + /// @param diff_bias_desc Diff bias memory descriptor. + /// @param diff_dst_layer_desc Memory descriptor for the diff of + /// output vector. + /// @param diff_dst_iter_desc Memory descriptor for the diff of output + /// recurrent hidden state vector. + /// @param diff_dst_iter_c_desc Memory descriptor for the diff of + /// output recurrent cell state vector. + /// @param hint_fwd_pd Primitive descriptor for a convolution + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &src_iter_c_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &dst_iter_c_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_src_iter_c_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc, + const memory::desc &diff_dst_iter_c_desc, + const lstm_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, &src_iter_c_desc, nullptr, + weights_layer_desc, weights_iter_desc, nullptr, nullptr, + bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc, + diff_src_layer_desc, diff_src_iter_desc, + &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc, + diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc, + diff_dst_layer_desc, diff_dst_iter_desc, + &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an LSTM backward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for an LSTM backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, + dnnl::algorithm::vanilla_lstm) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_c_desc() const { + return rnn_base::src_iter_c_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const + memory::desc weights_peephole_desc() const { + return rnn_base::weights_peephole_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const + memory::desc weights_projection_desc() const { + return rnn_base::weights_projection_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc dst_iter_c_desc() const { + return rnn_base::dst_iter_c_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// 
@copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const + memory::desc diff_src_layer_desc() const { + return rnn_base::diff_src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const + memory::desc diff_src_iter_desc() const { + return rnn_base::diff_src_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_c_desc()const + memory::desc diff_src_iter_c_desc() const { + return rnn_base::diff_src_iter_c_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const + memory::desc diff_weights_layer_desc() const { + return rnn_base::diff_weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const + memory::desc diff_weights_iter_desc() const { + return rnn_base::diff_weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_peephole_desc()const + memory::desc diff_weights_peephole_desc() const { + return rnn_base::diff_weights_peephole_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_projection_desc()const + memory::desc diff_weights_projection_desc() const { + return rnn_base::diff_weights_projection_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const + memory::desc diff_bias_desc() const { + return rnn_base::diff_bias_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const + memory::desc diff_dst_layer_desc() const { + return rnn_base::diff_dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const + memory::desc diff_dst_iter_desc() const { + return rnn_base::diff_dst_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_c_desc()const + memory::desc diff_dst_iter_c_desc() const { + return rnn_base::diff_dst_iter_c_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return 
base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. Produces an empty object. + lstm_backward() = default; + + /// Constructs an LSTM backward propagation primitive. + /// @param pd Primitive descriptor for an LSTM backward propagation + /// primitive. + lstm_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an LSTM backward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an LSTM backward propagation + /// primitive. + /// @param cache_blob Cache blob. + lstm_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// GRU forward propagation primitive. +struct gru_forward : public primitive { + /// Primitive descriptor for a GRU forward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a GRU forward propagation + /// primitive. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc, + /// - @p bias_desc, + /// - @p dst_iter_desc. + /// + /// This would then indicate that the GRU forward propagation primitive + /// should not use them and should default to zero values instead. + /// + /// @note + /// All memory descriptors except @p src_iter_desc may be + /// initialized with an #dnnl::memory::format_tag::any value of @p + /// format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param direction RNN direction. 
See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, nullptr, nullptr, weights_layer_desc, + weights_iter_desc, nullptr, nullptr, bias_desc, + dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef, + 0.0f, 0.0f, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a GRU forward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for a GRU forward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference, + dnnl::algorithm::vanilla_gru) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. Produces an empty object. + gru_forward() = default; + + /// Constructs a GRU forward propagation primitive. 
+ /// @param pd Primitive descriptor for a GRU forward propagation + /// primitive. + gru_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a GRU forward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for a GRU forward propagation + /// primitive. + /// @param cache_blob Cache blob. + gru_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// GRU backward propagation primitive. +struct gru_backward : public primitive { + /// Primitive descriptor for a GRU backward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a GRU backward propagation + /// primitive. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p diff_src_iter_desc, + /// - @p bias_desc together with @p diff_bias_desc, + /// - @p dst_iter_desc together with @p diff_dst_iter_desc. + /// + /// This would then indicate that the GRU backward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors may be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Must be + /// #dnnl::prop_kind::backward. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. 
+ /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param diff_src_layer_desc Memory descriptor for the diff of input + /// vector. + /// @param diff_src_iter_desc Memory descriptor for the diff of input + /// recurrent hidden state vector. + /// @param diff_weights_layer_desc Memory descriptor for the diff of + /// weights applied to the layer input. + /// @param diff_weights_iter_desc Memory descriptor for the diff of + /// weights applied to the recurrent input. + /// @param diff_bias_desc Diff bias memory descriptor. + /// @param diff_dst_layer_desc Memory descriptor for the diff of + /// output vector. + /// @param diff_dst_iter_desc Memory descriptor for the diff of output + /// recurrent hidden state vector. + /// @param hint_fwd_pd Primitive descriptor for a GRU + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc, + const gru_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, nullptr, nullptr, weights_layer_desc, + weights_iter_desc, nullptr, nullptr, bias_desc, + dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, + diff_src_iter_desc, nullptr, nullptr, + diff_weights_layer_desc, diff_weights_iter_desc, nullptr, + nullptr, diff_bias_desc, diff_dst_layer_desc, + diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a GRU backward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for a GRU backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, + dnnl::algorithm::vanilla_gru) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const + memory::desc diff_src_layer_desc() const { + return rnn_base::diff_src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const + memory::desc diff_src_iter_desc() const { + return rnn_base::diff_src_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const + memory::desc diff_weights_layer_desc() const { + return rnn_base::diff_weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const + memory::desc diff_weights_iter_desc() const { + return 
rnn_base::diff_weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const + memory::desc diff_bias_desc() const { + return rnn_base::diff_bias_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const + memory::desc diff_dst_layer_desc() const { + return rnn_base::diff_dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const + memory::desc diff_dst_iter_desc() const { + return rnn_base::diff_dst_iter_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. Produces an empty object. + gru_backward() = default; + + /// Constructs a GRU backward propagation primitive. + /// @param pd Primitive descriptor for a GRU backward propagation + /// primitive. + gru_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a GRU backward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for a GRU backward propagation + /// primitive. + /// @param cache_blob Cache blob. + gru_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// LBR GRU forward propagation primitive. +struct lbr_gru_forward : public primitive { + /// Primitive descriptor for an LBR GRU forward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for LBR GRU forward propagation + /// primitive. 
+ /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc, + /// - @p bias_desc, + /// - @p dst_iter_desc. + /// + /// This would then indicate that the LBR GRU forward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors except @p src_iter_desc may be + /// initialized with an #dnnl::memory::format_tag::any value of @p + /// format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind, + algorithm::undef, direction, src_layer_desc, src_iter_desc, + nullptr, nullptr, weights_layer_desc, weights_iter_desc, + nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc, + nullptr, rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a LBR GRU forward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for a LBR GRU forward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference, + dnnl::algorithm::lbr_gru) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. Produces an empty object. + lbr_gru_forward() = default; + + /// Constructs an LBR GRU forward propagation primitive. 
+ /// @param pd Primitive descriptor for an LBR GRU forward propagation + /// primitive. + lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an LBR GRU forward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an LBR GRU forward propagation + /// primitive. + /// @param cache_blob Cache blob. + lbr_gru_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// LBR GRU backward propagation primitive. +struct lbr_gru_backward : public primitive { + /// Primitive descriptor for an LBR GRU backward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for LBR GRU backward propagation + /// primitive. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p diff_src_iter_desc, + /// - @p bias_desc together with @p diff_bias_desc, + /// - @p dst_iter_desc together with @p diff_dst_iter_desc. + /// + /// This would then indicate that the LBR GRU backward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors may be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Must be + /// #dnnl::prop_kind::backward. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. 
+ /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param diff_src_layer_desc Memory descriptor for the diff of input + /// vector. + /// @param diff_src_iter_desc Memory descriptor for the diff of input + /// recurrent hidden state vector. + /// @param diff_weights_layer_desc Memory descriptor for the diff of + /// weights applied to the layer input. + /// @param diff_weights_iter_desc Memory descriptor for the diff of + /// weights applied to the recurrent input. + /// @param diff_bias_desc Diff bias memory descriptor. + /// @param diff_dst_layer_desc Memory descriptor for the diff of + /// output vector. + /// @param diff_dst_iter_desc Memory descriptor for the diff of output + /// recurrent hidden state vector. + /// @param hint_fwd_pd Primitive descriptor for an LBR GRU + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc, + const lbr_gru_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind, + algorithm::undef, direction, src_layer_desc, src_iter_desc, + nullptr, nullptr, weights_layer_desc, weights_iter_desc, + nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc, + nullptr, diff_src_layer_desc, diff_src_iter_desc, nullptr, + nullptr, diff_weights_layer_desc, diff_weights_iter_desc, + nullptr, nullptr, diff_bias_desc, diff_dst_layer_desc, + diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a LBR GRU backward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for a LBR GRU backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base( + pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const + memory::desc diff_src_layer_desc() const { + return rnn_base::diff_src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const + memory::desc diff_src_iter_desc() const { + return rnn_base::diff_src_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const + memory::desc diff_weights_layer_desc() const { + return rnn_base::diff_weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const + memory::desc diff_weights_iter_desc() const { + return 
rnn_base::diff_weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const + memory::desc diff_bias_desc() const { + return rnn_base::diff_bias_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const + memory::desc diff_dst_layer_desc() const { + return rnn_base::diff_dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const + memory::desc diff_dst_iter_desc() const { + return rnn_base::diff_dst_iter_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. Produces an empty object. + lbr_gru_backward() = default; + + /// Constructs an LBR GRU backward propagation primitive. + /// @param pd Primitive descriptor for an LBR GRU backward propagation + /// primitive. + lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an LBR GRU backward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an LBR GRU backward propagation + /// primitive. + /// @param cache_blob Cache blob. + lbr_gru_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// AUGRU forward propagation primitive. +struct augru_forward : public primitive { + /// Primitive descriptor for an AUGRU forward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an AUGRU forward propagation + /// primitive. 
+ /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc, + /// - @p bias_desc, + /// - @p dst_iter_desc. + /// + /// This would then indicate that the AUGRU forward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors except @p src_iter_desc may be + /// initialized with an #dnnl::memory::format_tag::any value of @p + /// format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param attention_desc Memory descriptor for the attention vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &attention_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, nullptr, &attention_desc, weights_layer_desc, + weights_iter_desc, nullptr, nullptr, bias_desc, + dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef, + 0.0f, 0.0f, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an AUGRU forward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for an AUGRU forward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference, + dnnl::algorithm::vanilla_augru) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const + memory::desc attention_desc() const { + return rnn_base::augru_attention_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. 
Produces an empty object. + augru_forward() = default; + + /// Constructs an AUGRU forward propagation primitive. + /// @param pd Primitive descriptor for an AUGRU forward propagation + /// primitive. + augru_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an AUGRU forward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an AUGRU forward propagation + /// primitive. + /// @param cache_blob Cache blob. + augru_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// AUGRU backward propagation primitive. +struct augru_backward : public primitive { + /// Primitive descriptor for an AUGRU backward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an AUGRU backward propagation + /// primitive. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p diff_src_iter_desc, + /// - @p bias_desc together with @p diff_bias_desc, + /// - @p dst_iter_desc together with @p diff_dst_iter_desc. + /// + /// This would then indicate that the AUGRU backward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors may be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Must be + /// #dnnl::prop_kind::backward. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. 
+ /// @param attention_desc Memory descriptor for the attention vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param diff_src_layer_desc Memory descriptor for the diff of input + /// vector. + /// @param diff_src_iter_desc Memory descriptor for the diff of input + /// recurrent hidden state vector. + /// @param diff_attention_desc Memory descriptor for the diff of + /// attention vector. + /// @param diff_weights_layer_desc Memory descriptor for the diff of + /// weights applied to the layer input. + /// @param diff_weights_iter_desc Memory descriptor for the diff of + /// weights applied to the recurrent input. + /// @param diff_bias_desc Diff bias memory descriptor. + /// @param diff_dst_layer_desc Memory descriptor for the diff of + /// output vector. + /// @param diff_dst_iter_desc Memory descriptor for the diff of output + /// recurrent hidden state vector. + /// @param hint_fwd_pd Primitive descriptor for an AUGRU + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &attention_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_attention_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc, + const augru_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru, + aprop_kind, algorithm::undef, direction, src_layer_desc, + src_iter_desc, nullptr, &attention_desc, weights_layer_desc, + weights_iter_desc, nullptr, nullptr, bias_desc, + dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, + diff_src_iter_desc, nullptr, &diff_attention_desc, + diff_weights_layer_desc, diff_weights_iter_desc, nullptr, + nullptr, diff_bias_desc, diff_dst_layer_desc, + diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an AUGRU backward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for an AUGRU backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, + dnnl::algorithm::vanilla_augru) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const + memory::desc attention_desc() const { + return rnn_base::augru_attention_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const + memory::desc diff_src_layer_desc() const { + return rnn_base::diff_src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const + memory::desc diff_src_iter_desc() const { + return rnn_base::diff_src_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const + memory::desc diff_attention_desc() const { + return 
rnn_base::diff_augru_attention_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const + memory::desc diff_weights_layer_desc() const { + return rnn_base::diff_weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const + memory::desc diff_weights_iter_desc() const { + return rnn_base::diff_weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const + memory::desc diff_bias_desc() const { + return rnn_base::diff_bias_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const + memory::desc diff_dst_layer_desc() const { + return rnn_base::diff_dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const + memory::desc diff_dst_iter_desc() const { + return rnn_base::diff_dst_iter_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. Produces an empty object. + augru_backward() = default; + + /// Constructs an AUGRU backward propagation primitive. + /// @param pd Primitive descriptor for an AUGRU backward propagation + /// primitive. + augru_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an AUGRU backward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an AUGRU backward propagation + /// primitive. + /// @param cache_blob Cache blob. + augru_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// LBR AUGRU forward propagation primitive. 
+struct lbr_augru_forward : public primitive { + /// Primitive descriptor for an LBR AUGRU forward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for LBR AUGRU forward propagation + /// primitive. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc, + /// - @p bias_desc, + /// - @p dst_iter_desc. + /// + /// This would then indicate that the LBR AUGRU forward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors except @p src_iter_desc may be + /// initialized with an #dnnl::memory::format_tag::any value of @p + /// format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. + /// @param attention_desc Memory descriptor for the attention vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. 
+ /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &attention_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind, + algorithm::undef, direction, src_layer_desc, src_iter_desc, + nullptr, &attention_desc, weights_layer_desc, + weights_iter_desc, nullptr, nullptr, bias_desc, + dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef, + 0.0f, 0.0f, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an LBR AUGRU forward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for an LBR AUGRU forward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference, + dnnl::algorithm::lbr_augru) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const + memory::desc attention_desc() const { + return rnn_base::augru_attention_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. 
Produces an empty object. + lbr_augru_forward() = default; + + /// Constructs an LBR AUGRU forward propagation primitive. + /// @param pd Primitive descriptor for an LBR AUGRU forward propagation + /// primitive. + lbr_augru_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an LBR AUGRU forward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an LBR AUGRU forward propagation + /// primitive. + /// @param cache_blob Cache blob. + lbr_augru_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// LBR AUGRU backward propagation primitive. +struct lbr_augru_backward : public primitive { + /// Primitive descriptor for an LBR AUGRU backward propagation primitive. + struct primitive_desc : public rnn_primitive_desc_base { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for LBR AUGRU backward propagation + /// primitive. + /// + /// The following arguments may point to a zero memory descriptor: + /// - @p src_iter_desc together with @p diff_src_iter_desc, + /// - @p bias_desc together with @p diff_bias_desc, + /// - @p dst_iter_desc together with @p diff_dst_iter_desc. + /// + /// This would then indicate that the LBR AUGRU backward propagation + /// primitive should not use them and should default to zero values + /// instead. + /// + /// @note + /// All memory descriptors may be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Must be + /// #dnnl::prop_kind::backward. + /// @param direction RNN direction. See @ref dnnl::rnn_direction for + /// more info. + /// @param src_layer_desc Memory descriptor for the input vector. + /// @param src_iter_desc Memory descriptor for the input recurrent + /// hidden state vector. 
+ /// @param attention_desc Memory descriptor for the attention vector. + /// @param weights_layer_desc Memory descriptor for the weights + /// applied to the layer input. + /// @param weights_iter_desc Memory descriptor for the weights applied + /// to the recurrent input. + /// @param bias_desc Bias memory descriptor. + /// @param dst_layer_desc Memory descriptor for the output vector. + /// @param dst_iter_desc Memory descriptor for the output recurrent + /// hidden state vector. + /// @param diff_src_layer_desc Memory descriptor for the diff of input + /// vector. + /// @param diff_src_iter_desc Memory descriptor for the diff of input + /// recurrent hidden state vector. + /// @param diff_attention_desc Memory descriptor for the diff of + /// attention vector. + /// @param diff_weights_layer_desc Memory descriptor for the diff of + /// weights applied to the layer input. + /// @param diff_weights_iter_desc Memory descriptor for the diff of + /// weights applied to the recurrent input. + /// @param diff_bias_desc Diff bias memory descriptor. + /// @param diff_dst_layer_desc Memory descriptor for the diff of + /// output vector. + /// @param diff_dst_iter_desc Memory descriptor for the diff of output + /// recurrent hidden state vector. + /// @param hint_fwd_pd Primitive descriptor for an LBR AUGRU + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + rnn_direction direction, const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &attention_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_attention_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc, + const lbr_augru_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind, + algorithm::undef, direction, src_layer_desc, src_iter_desc, + nullptr, &attention_desc, weights_layer_desc, + weights_iter_desc, nullptr, nullptr, bias_desc, + dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc, + diff_src_iter_desc, nullptr, &diff_attention_desc, + diff_weights_layer_desc, diff_weights_iter_desc, nullptr, + nullptr, diff_bias_desc, diff_dst_layer_desc, + diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f, + hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for an LBR AUGRU backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for an LBR AUGRU backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward, + dnnl::algorithm::lbr_augru) {} + + /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const + memory::desc src_layer_desc() const { + return rnn_base::src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const + memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const + memory::desc attention_desc() const { + return rnn_base::augru_attention_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const + memory::desc weights_layer_desc() const { + return rnn_base::weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const + memory::desc weights_iter_desc() const { + return rnn_base::weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const + memory::desc bias_desc() const { return rnn_base::bias_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const + memory::desc dst_layer_desc() const { + return rnn_base::dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const + memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); } + + /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { + return rnn_base::workspace_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const + memory::desc diff_src_layer_desc() const { + return rnn_base::diff_src_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const + memory::desc diff_src_iter_desc() const { + return rnn_base::diff_src_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const + memory::desc diff_attention_desc() const { + return rnn_base::diff_augru_attention_desc(); + 
} + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const + memory::desc diff_weights_layer_desc() const { + return rnn_base::diff_weights_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const + memory::desc diff_weights_iter_desc() const { + return rnn_base::diff_weights_iter_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const + memory::desc diff_bias_desc() const { + return rnn_base::diff_bias_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const + memory::desc diff_dst_layer_desc() const { + return rnn_base::diff_dst_layer_desc(); + } + + /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const + memory::desc diff_dst_iter_desc() const { + return rnn_base::diff_dst_iter_desc(); + } + + /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const + algorithm get_cell_kind() const { return base::get_cell_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_direction()const + rnn_direction get_direction() const { return base::get_direction(); } + }; + + /// Default constructor. Produces an empty object. + lbr_augru_backward() = default; + + /// Constructs an LBR AUGRU backward propagation primitive. + /// @param pd Primitive descriptor for an LBR AUGRU backward propagation + /// primitive. + lbr_augru_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an LBR AUGRU backward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for an LBR AUGRU backward propagation + /// primitive. + /// @param cache_blob Cache blob. 
+ lbr_augru_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_rnn + +/// @addtogroup dnnl_api_shuffle Shuffle +/// +/// A primitive to shuffle tensor data along an axis. +/// +/// @sa @ref dev_guide_shuffle in developer guide +/// +/// @{ + +/// Shuffle forward propagation primitive. +struct shuffle_forward : public primitive { + /// Primitive descriptor for a shuffle forward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a shuffle forward propagation + /// primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param axis The axis along which the data is shuffled. + /// @param group_size Shuffle group size. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &dst_desc, + int axis, int group_size, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_shuffle_forward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + src_desc.get(), dst_desc.get(), axis, group_size, + attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the shuffle forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a shuffle forward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for a shuffle forward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_axis()const + int get_axis() const { return base::get_axis(); } + + /// @copydoc dnnl::primitive_desc_base::get_group_size()const + memory::dim get_group_size() const { return base::get_group_size(); } + }; + + /// Default constructor. Produces an empty object. + shuffle_forward() = default; + + /// Constructs a shuffle forward propagation primitive. 
+ /// @param pd Primitive descriptor for a shuffle forward propagation + /// primitive. + shuffle_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a shuffle forward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for a shuffle forward propagation + /// primitive. + /// @param cache_blob Cache blob. + shuffle_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Shuffle backward propagation primitive. +struct shuffle_backward : public primitive { + /// Primitive descriptor for a shuffle backward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a shuffle backward propagation + /// primitive. + /// + /// @param aengine Engine to use. + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param axis The axis along which the data is shuffled. + /// @param group_size Shuffle group size. + /// @param hint_fwd_pd Primitive descriptor for a shuffle forward + /// propagation primitive. It is used as a hint for deciding which + /// memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, int axis, int group_size, + const shuffle_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_shuffle_backward_primitive_desc_create( + &pd, aengine.get(), diff_src_desc.get(), + diff_dst_desc.get(), axis, group_size, hint_fwd_pd.get(), + attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the shuffle backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a shuffle backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a shuffle backward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle, + dnnl::prop_kind::backward_data) {} + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_axis()const + int get_axis() const { return base::get_axis(); } + + /// @copydoc dnnl::primitive_desc_base::get_group_size()const + memory::dim get_group_size() const { return base::get_group_size(); } + }; + + /// Default constructor. Produces an empty object. + shuffle_backward() = default; + + /// Constructs a shuffle backward propagation primitive. 
+ /// @param pd Primitive descriptor for a shuffle backward propagation + /// primitive. + shuffle_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a shuffle backward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for a shuffle backward propagation + /// primitive. + /// @param cache_blob Cache blob. + shuffle_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_shuffle + +/// @addtogroup dnnl_api_binary Binary +/// +/// A primitive to perform tensor operations over two tensors. +/// +/// @sa @ref dev_guide_binary in developer guide +/// +/// @{ + +/// Elementwise binary operator primitive. +struct binary : public primitive { + /// Primitive descriptor for an elementwise binary operator primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for an elementwise binary operator + /// primitive. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Elementwise binary algorithm. + /// @param src0 Memory descriptor for source tensor #0. + /// @param src1 Memory descriptor for source tensor #1. + /// @param dst Memory descriptor for destination tensor. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src0, const memory::desc &src1, + const memory::desc &dst, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_binary_primitive_desc_create(&pd, + aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(), + src1.get(), dst.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the binary operation primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for an elementwise binary operator + /// primitive with support of ternary operators. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Elementwise binary algorithm. + /// @param src0 Memory descriptor for source tensor #0. + /// @param src1 Memory descriptor for source tensor #1. + /// @param src2 Memory descriptor for source tensor #2 for ternary + /// operations. Might be empty. + /// @param dst Memory descriptor for destination tensor. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src0, const memory::desc &src1, + const memory::desc &src2, const memory::desc &dst, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_binary_primitive_desc_create_v2(&pd, + aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(), + src1.get(), src2.get(), dst.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the binary v2 operation primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a binary primitive from a C + /// API primitive descriptor that must have a matching kind. + /// + /// @param pd C API primitive descriptor for a binary primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc(int)const + memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); } + + /// Returns the memory descriptor for source #0. + memory::desc src0_desc() const { return base::src_desc(0); } + + /// Returns the memory descriptor for source #1. + memory::desc src1_desc() const { return base::src_desc(1); } + + /// Returns the memory descriptor for source #2. + memory::desc src2_desc() const { return base::src_desc(2); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + }; + + /// Default constructor. Produces an empty object. + binary() = default; + + /// Constructs an elementwise binary operation primitive. 
+ /// @param pd Primitive descriptor for an elementwise binary operation + /// primitive. + binary(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs an elementwise binary operation primitive from a cache blob. + /// @param pd Primitive descriptor for an elementwise binary operation + /// primitive. + /// @param cache_blob Cache blob. + binary(const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_binary + +/// @addtogroup dnnl_api_matmul Matrix Multiplication +/// +/// A primitive to perform matrix-matrix multiplication. The batched mode +/// is supported with 3D tensors. +/// +/// @sa @ref dev_guide_matmul in developer guide +/// +/// +/// @{ + +/// Matrix multiplication (matmul) primitive. +struct matmul : public primitive { + /// Primitive descriptor for a matmul primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a matmul primitive + /// without bias. + /// + /// @param aengine Engine to use. + /// @param src_desc Memory descriptor for source (matrix A). + /// @param weights_desc Memory descriptor for weights (matrix B). + /// @param dst_desc Memory descriptor for destination (matrix C). + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, const memory::desc &src_desc, + const memory::desc &weights_desc, const memory::desc &dst_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, src_desc, weights_desc, nullptr, dst_desc, + attr, allow_empty) {} + + /// Constructs a primitive descriptor for a matmul primitive with bias. + /// + /// @param aengine Engine to use. + /// @param src_desc Memory descriptor for source (matrix A). + /// @param weights_desc Memory descriptor for weights (matrix B). + /// @param dst_desc Memory descriptor for destination (matrix C). + /// @param bias_desc Memory descriptor for bias. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, const memory::desc &src_desc, + const memory::desc &weights_desc, const memory::desc &bias_desc, + const memory::desc &dst_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, src_desc, weights_desc, &bias_desc, + dst_desc, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a matmul primitive from a C + /// API primitive descriptor that must have a matching kind. + /// + /// @param pd C API primitive descriptor for a matmul primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::matmul) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return query_md(query::src_md, 0); } + + /// @copydoc dnnl::primitive_desc_base::weights_desc()const + memory::desc weights_desc() const { + return query_md(query::weights_md, 0); + } + + /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const + memory::desc bias_desc() const { + return query_md(query::weights_md, 1); + } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return query_md(query::dst_md, 0); } + + private: + primitive_desc(const engine &aengine, const memory::desc &src_desc, + const memory::desc &weights_desc, const memory::desc *bias_desc, + const memory::desc &dst_desc, const primitive_attr &attr, + bool allow_empty) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_matmul_primitive_desc_create(&pd, + aengine.get(), src_desc.get(), weights_desc.get(), + optional_arg(bias_desc), dst_desc.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the matmul primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + matmul() = default; + + /// Constructs a matmul primitive. + /// @param pd Primitive descriptor for a matmul primitive. + matmul(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a matmul primitive from a cache blob. + /// @param pd Primitive descriptor for a matmul primitive. + /// @param cache_blob Cache blob. 
+ matmul(const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_matmul + +/// @addtogroup dnnl_api_resampling Resampling +/// +/// A primitive to compute resampling operation on 1D, 2D or 3D data tensor +/// using Nearest Neighbor, or Linear (Bilinear, Trilinear) interpolation +/// method. +/// +/// @sa @ref dev_guide_resampling in developer guide +/// +/// @{ + +/// Resampling forward propagation. +struct resampling_forward : public primitive { + /// Primitive descriptor for a resampling forward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a resampling forward + /// propagation primitive using source and destination memory + /// descriptors. + /// + /// @note + /// Destination memory descriptor may be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm resampling algorithm kind: either + /// #dnnl::algorithm::resampling_nearest, or + /// #dnnl::algorithm::resampling_linear + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &dst_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, nullptr, src_desc, + &dst_desc, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a resampling forward + /// propagation primitive using source memory descriptor and + /// factors. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm resampling algorithm kind: either + /// #dnnl::algorithm::resampling_nearest, or + /// #dnnl::algorithm::resampling_linear + /// @param factors Vector of scaling factors for spatial dimension. + /// @param src_desc Source memory descriptor. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const std::vector &factors, + const memory::desc &src_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, &factors, + src_desc, nullptr, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a resampling forward + /// propagation primitive. + /// + /// @note + /// The destination memory descriptor may be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. 
Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm resampling algorithm kind: either + /// #dnnl::algorithm::resampling_nearest, or + /// #dnnl::algorithm::resampling_linear + /// @param factors Vector of scaling factors for spatial dimension. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const std::vector &factors, + const memory::desc &src_desc, const memory::desc &dst_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aprop_kind, aalgorithm, &factors, + src_desc, &dst_desc, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a resampling forward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a resampling forward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + private: + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const std::vector *factors, + const memory::desc &src_desc, const memory::desc *dst_desc, + const primitive_attr &attr, bool allow_empty) { + + if (factors) + memory::validate_dims(*factors, src_desc.get_ndims() - 2); + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_resampling_forward_primitive_desc_create(&pd, + aengine.get(), dnnl::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), optional_arg(factors), + src_desc.get(), optional_arg(dst_desc), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the resampling forward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + resampling_forward() = default; + + /// Constructs a resampling forward propagation primitive. + /// @param pd Primitive descriptor for a resampling forward propagation + /// primitive. + resampling_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a resampling forward propagation primitive from a cache + /// blob. + /// @param pd Primitive descriptor for a resampling forward propagation + /// primitive. + /// @param cache_blob Cache blob. + resampling_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Resampling backward propagation primitive. 
+struct resampling_backward : public primitive { + /// Primitive descriptor for resampling backward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a resampling backward + /// propagation primitive using source and destination memory + /// descriptors. + /// + /// @param aengine Engine to use. + /// @param aalgorithm resampling algorithm kind: either + /// #dnnl::algorithm::resampling_nearest, or + /// #dnnl::algorithm::resampling_linear + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param hint_fwd_pd Primitive descriptor for a resampling + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const resampling_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, nullptr, diff_src_desc, + diff_dst_desc, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for resampling backward + /// propagation primitive. + /// + /// @param aengine Engine to use. + /// @param aalgorithm resampling algorithm kind: either + /// #dnnl::algorithm::resampling_nearest, or + /// #dnnl::algorithm::resampling_linear + /// @param factors Vector of scaling factors for spatial dimension. 
+ /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param hint_fwd_pd Primitive descriptor for a resampling + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const std::vector &factors, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const resampling_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) + : primitive_desc(aengine, aalgorithm, &factors, diff_src_desc, + diff_dst_desc, hint_fwd_pd, attr, allow_empty) {} + + /// Constructs a primitive descriptor for a resampling backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a resampling backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling, + dnnl::prop_kind::backward_data) {} + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + private: + primitive_desc(const engine &aengine, algorithm aalgorithm, + const std::vector *factors, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const resampling_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr, bool allow_empty) { + + if (factors) + memory::validate_dims(*factors, diff_src_desc.get_ndims() - 2); + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status + = dnnl_resampling_backward_primitive_desc_create(&pd, + aengine.get(), convert_to_c(aalgorithm), + optional_arg(factors), diff_src_desc.get(), + diff_dst_desc.get(), hint_fwd_pd.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the resampling backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + resampling_backward() = default; + + /// Constructs a resampling backward propagation primitive. + /// @param pd Primitive descriptor for a resampling backward propagation + /// primitive. + resampling_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a resampling backward propagation primitive from a cache + /// blob. + /// @param pd Primitive descriptor for a resampling backward propagation + /// primitive. + /// @param cache_blob Cache blob. 
+ resampling_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_resampling + +/// @addtogroup dnnl_api_pooling Pooling +/// +/// A primitive to perform max or average pooling with dilation. +/// +/// @sa @ref dev_guide_pooling in developer guide +/// +/// @{ + +/// Pooling forward propagation primitive. +struct pooling_forward : public primitive { + /// Primitive descriptor for a pooling forward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for pooling forward propagation + /// primitive. + /// + /// Arrays @p strides, @p kernel, @p dilation, @p padding_l + /// and @p padding_r contain values for spatial dimensions only and + /// hence must have the same number of elements as there are spatial + /// dimensions. The order of values is the same as in the tensor: + /// depth (for 3D tensors), height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param aalgorithm Pooling algorithm kind: either + /// #dnnl::algorithm::pooling_max, + /// #dnnl::algorithm::pooling_avg_include_padding, + /// or #dnnl::algorithm::pooling_avg_exclude_padding. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param strides Vector of strides for spatial dimension. + /// @param kernel Vector of kernel spatial dimensions. + /// @param dilation Array of dilations for spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. 
+ /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, prop_kind aprop_kind, + algorithm aalgorithm, const memory::desc &src_desc, + const memory::desc &dst_desc, const memory::dims &strides, + const memory::dims &kernel, const memory::dims &dilation, + const memory::dims &padding_l, const memory::dims &padding_r, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + memory::validate_dims(strides, src_desc.get_ndims() - 2); + memory::validate_dims(kernel, src_desc.get_ndims() - 2); + memory::validate_dims(padding_l, src_desc.get_ndims() - 2); + memory::validate_dims(padding_r, src_desc.get_ndims() - 2); + memory::validate_dims(dilation, src_desc.get_ndims() - 2); + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_pooling_forward_primitive_desc_create( + &pd, aengine.get(), dnnl::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(), + &strides[0], &kernel[0], &dilation[0], &padding_l[0], + &padding_r[0], attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a descriptor for a pooling forward " + "propagation primitive"); + reset(pd); + } + + /// Constructs a primitive descriptor for a pooling forward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for a pooling forward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { return base::workspace_desc(); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_strides()const + memory::dims get_strides() const { return base::get_strides(); } + + /// @copydoc dnnl::primitive_desc_base::get_kernel()const + memory::dims get_kernel() const { return base::get_kernel(); } + + /// @copydoc dnnl::primitive_desc_base::get_dilations()const + memory::dims get_dilations() const { return base::get_dilations(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_l()const + memory::dims get_padding_l() const { return base::get_padding_l(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_r()const + memory::dims get_padding_r() const { return base::get_padding_r(); } + }; + + /// Default constructor. Produces an empty object. + pooling_forward() = default; + + /// Constructs a pooling forward propagation primitive. + /// + /// @param pd Primitive descriptor for a pooling forward propagation + /// primitive. + pooling_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a pooling forward propagation primitive from a cache blob. + /// + /// @param pd Primitive descriptor for a pooling forward propagation + /// primitive. 
+ /// @param cache_blob Cache blob. + pooling_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// Pooling backward propagation primitive. +struct pooling_backward : public primitive { + /// Primitive descriptor for a pooling backward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a pooling backward propagation + /// primitive. + /// + /// Arrays @p strides, @p kernel, @p dilation, @p padding_l + /// and @p padding_r contain values for spatial dimensions only and + /// hence must have the same number of elements as there are spatial + /// dimensions. The order of values is the same as in the tensor: + /// depth (for 3D tensors), height (for 3D and 2D tensors), and width. + /// + /// @param aengine Engine to use. + /// @param aalgorithm Pooling algorithm kind: either + /// #dnnl::algorithm::pooling_max, + /// #dnnl::algorithm::pooling_avg_include_padding, + /// or #dnnl::algorithm::pooling_avg_exclude_padding. + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param strides Vector of strides for spatial dimension. + /// @param kernel Vector of kernel spatial dimensions. + /// @param dilation Array of dilations for spatial dimension. + /// @param padding_l Vector of padding values for low indices for each + /// spatial dimension `([[front,] top,] left)`. + /// @param padding_r Vector of padding values for high indices for + /// each spatial dimension `([[back,] bottom,] right)`. + /// @param hint_fwd_pd Primitive descriptor for a pooling + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. 
+ /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, const memory::dims &strides, + const memory::dims &kernel, const memory::dims &dilation, + const memory::dims &padding_l, const memory::dims &padding_r, + const pooling_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + memory::validate_dims(strides, diff_src_desc.get_ndims() - 2); + memory::validate_dims(kernel, diff_src_desc.get_ndims() - 2); + memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2); + memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2); + memory::validate_dims(dilation, diff_src_desc.get_ndims() - 2); + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_pooling_backward_primitive_desc_create( + &pd, aengine.get(), convert_to_c(aalgorithm), + diff_src_desc.get(), diff_dst_desc.get(), &strides[0], + &kernel[0], &dilation[0], &padding_l[0], &padding_r[0], + hint_fwd_pd.get(), attr.get()); + if (!allow_empty) + error::wrap_c_api(status, + "could not create a descriptor for a pooling backward " + "propagation primitive"); + reset(pd); + } + + /// Constructs a primitive descriptor for a pooling backward propagation + /// primitive from a C API primitive descriptor that must have a + /// matching kind. + /// + /// @param pd C API primitive descriptor for a pooling backward + /// propagation primitive. 
+ primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling, + dnnl::prop_kind::backward_data) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::workspace_desc()const + memory::desc workspace_desc() const { return base::workspace_desc(); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + + /// @copydoc dnnl::primitive_desc_base::get_strides()const + memory::dims get_strides() const { return base::get_strides(); } + + /// @copydoc dnnl::primitive_desc_base::get_kernel()const + memory::dims get_kernel() const { return base::get_kernel(); } + + /// @copydoc dnnl::primitive_desc_base::get_dilations()const + memory::dims get_dilations() const { return base::get_dilations(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_l()const + memory::dims get_padding_l() const { return base::get_padding_l(); } + + /// @copydoc dnnl::primitive_desc_base::get_padding_r()const + memory::dims get_padding_r() const { return base::get_padding_r(); } + }; + + /// Default constructor. Produces an empty object. + pooling_backward() = default; + + /// Constructs a pooling backward propagation primitive. + /// + /// @param pd Primitive descriptor for a pooling backward propagation + /// primitive. + pooling_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a pooling backward propagation primitive from a cache blob. + /// + /// @param pd Primitive descriptor for a pooling backward propagation + /// primitive. 
+ /// @param cache_blob Cache blob. + pooling_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_pooling + +/// @addtogroup dnnl_api_prelu PReLU +/// +/// PReLU primitive +/// A primitive to perform PReLU (leaky ReLU with trainable alpha parameter) +/// +/// @sa @ref dev_guide_prelu in developer guide +/// +/// @{ + +/// PReLU forward propagation primitive. +struct prelu_forward : public primitive { + /// Primitive descriptor for a PReLU forward propagation primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a PReLU forward propagation + /// primitive. + /// + /// @param aengine Engine to use. + /// @param aprop_kind Propagation kind. Possible values are + /// #dnnl::prop_kind::forward_training, and + /// #dnnl::prop_kind::forward_inference. + /// @param src_desc Source memory descriptor. + /// @param weight_desc Alpha parameters memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, prop_kind aprop_kind, + const memory::desc &src_desc, const memory::desc &weight_desc, + const memory::desc &dst_desc, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_prelu_forward_primitive_desc_create(&pd, + aengine.get(), dnnl::convert_to_c(aprop_kind), + src_desc.get(), weight_desc.get(), dst_desc.get(), + attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the prelu forward propagation primitive. Run workload " + "with environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a prelu forward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a prelu forward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu, + dnnl::prop_kind::forward_training, + dnnl::prop_kind::forward_inference) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + }; + + /// Default constructor. Produces an empty object. + prelu_forward() = default; + + /// Constructs a prelu forward propagation primitive. + /// @param pd Primitive descriptor for a prelu forward propagation + /// primitive. + prelu_forward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a prelu forward propagation primitive from a cache blob. 
+ /// @param pd Primitive descriptor for a prelu forward propagation + /// primitive. + /// @param cache_blob Cache blob. + prelu_forward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// PReLU backward propagation primitive. +struct prelu_backward : public primitive { + /// Primitive descriptor for prelu backward propagation. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a descriptor for a PReLU backward propagation + /// primitive. + /// + /// @param aengine Engine to use. + /// @param src_desc Source memory descriptor. + /// @param weight_desc Alpha parameters memory descriptor. + /// @param diff_src_desc Diff source memory descriptor. + /// @param diff_weights_desc Diff alpha parameters memory descriptor. + /// @param diff_dst_desc Diff destination memory descriptor. + /// @param hint_fwd_pd Primitive descriptor for a PReLU + /// forward propagation primitive. It is used as a hint for + /// deciding which memory format to use. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. 
+ primitive_desc(const engine &aengine, const memory::desc &src_desc, + const memory::desc &weight_desc, + const memory::desc &diff_src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const prelu_forward::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_prelu_backward_primitive_desc_create( + &pd, aengine.get(), src_desc.get(), weight_desc.get(), + diff_src_desc.get(), diff_weights_desc.get(), + diff_dst_desc.get(), hint_fwd_pd.get(), attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the prelu backward propagation primitive. Run " + "workload with environment variable ONEDNN_VERBOSE=all " + "to get additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a prelu backward + /// propagation primitive from a C API primitive descriptor that must + /// have a matching kind. + /// + /// @param pd C API primitive descriptor for a prelu backward + /// propagation primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu, + dnnl::prop_kind::backward) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const + memory::desc diff_src_desc() const { return base::diff_src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const + memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const + prop_kind get_prop_kind() const { return base::get_prop_kind(); } + }; + + /// Default constructor. Produces an empty object. + prelu_backward() = default; + + /// Constructs a prelu backward propagation primitive. 
+ /// @param pd Primitive descriptor for a prelu backward propagation + /// primitive. + prelu_backward(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a prelu backward propagation primitive from a cache blob. + /// @param pd Primitive descriptor for a prelu backward propagation + /// primitive. + /// @param cache_blob Cache blob. + prelu_backward( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_prelu + +/// @addtogroup dnnl_api_reduction Reduction +/// +/// A primitive to compute reduction operation on data tensor +/// using min, max, mul, sum, mean and norm_lp operations. +/// +/// @sa @ref dev_guide_reduction in developer guide +/// +/// @{ + +/// Reduction. +struct reduction : public primitive { + /// Primitive descriptor for a reduction primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. + primitive_desc() = default; + + /// Constructs a primitive descriptor for a reduction primitive using + /// algorithm specific parameters, source and destination memory + /// descriptors. + /// + /// @note + /// Destination memory descriptor may be initialized with + /// #dnnl::memory::format_tag::any value of @p format_tag. + /// + /// @param aengine Engine to use. + /// @param aalgorithm reduction algorithm kind. Possible values: + /// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum, + /// #dnnl_reduction_mul, #dnnl_reduction_mean, + /// #dnnl_reduction_norm_lp_max, #dnnl_reduction_norm_lp_sum, + /// #dnnl_reduction_norm_lp_power_p_max, + /// #dnnl_reduction_norm_lp_power_p_sum. + /// @param p algorithm specific parameter. + /// @param eps algorithm specific parameter. + /// @param src_desc Source memory descriptor. + /// @param dst_desc Destination memory descriptor. + /// @param attr Primitive attributes to use. Attributes are optional + /// and default to empty attributes. 
+ /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + primitive_desc(const engine &aengine, algorithm aalgorithm, + const memory::desc &src_desc, const memory::desc &dst_desc, + float p, float eps, const primitive_attr &attr = default_attr(), + bool allow_empty = false) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = dnnl_reduction_primitive_desc_create(&pd, + aengine.get(), convert_to_c(aalgorithm), src_desc.get(), + dst_desc.get(), p, eps, attr.get()); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a primitive descriptor for " + "the reduction primitive. Run workload with " + "environment variable ONEDNN_VERBOSE=all to get " + "additional diagnostic information."); + reset(pd); + } + + /// Constructs a primitive descriptor for a reduction primitive from a C + /// API primitive descriptor that must have a matching kind. + /// + /// @param pd C API primitive descriptor for a reduction primitive. + primitive_desc(dnnl_primitive_desc_t pd) + : dnnl::primitive_desc(pd, dnnl::primitive::kind::reduction) {} + + /// @copydoc dnnl::primitive_desc_base::src_desc()const + memory::desc src_desc() const { return base::src_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::dst_desc()const + memory::desc dst_desc() const { return base::dst_desc(0); } + + /// @copydoc dnnl::primitive_desc_base::get_p()const + float get_p() const { return base::get_p(); } + + /// @copydoc dnnl::primitive_desc_base::get_epsilon()const + float get_epsilon() const { return base::get_epsilon(); } + + /// @copydoc dnnl::primitive_desc_base::get_algorithm()const + algorithm get_algorithm() const { return base::get_algorithm(); } + }; + + /// Default constructor. Produces an empty object. + reduction() = default; + + /// Constructs a reduction primitive. 
+ /// @param pd Primitive descriptor for a reduction primitive. + reduction(const primitive_desc &pd) : primitive(pd) {} + + /// Constructs a reduction primitive from a cache blob. + /// @param pd Primitive descriptor for a reduction primitive. + /// @param cache_blob Cache blob. + reduction(const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd, cache_blob) {} +}; + +/// @} dnnl_api_reduction + +/// @} dnnl_api_primitives + +/// @addtogroup dnnl_api_service Service +/// +/// A set of functions that aid in oneDNN debugging and profiling. +/// +/// @{ + +/// @copydoc dnnl_version_t +using version_t = dnnl_version_t; + +/// Status values returned by the library functions. +enum class status { + /// @copydoc dnnl_success + success = dnnl_success, + /// @copydoc dnnl_out_of_memory + out_of_memory = dnnl_out_of_memory, + /// @copydoc dnnl_invalid_arguments + invalid_arguments = dnnl_invalid_arguments, + /// @copydoc dnnl_unimplemented + unimplemented = dnnl_unimplemented, + /// @copydoc dnnl_last_impl_reached + last_impl_reached = dnnl_last_impl_reached, + /// @copydoc dnnl_runtime_error + runtime_error = dnnl_runtime_error, + /// @copydoc dnnl_not_required + not_required = dnnl_not_required, +}; + +/// @copydoc dnnl_set_verbose() +inline status set_verbose(int level) { + return static_cast(dnnl_set_verbose(level)); +} + +/// @copydoc dnnl_version() +inline const version_t *version() { + return dnnl_version(); +} + +/// Returns the floating-point math mode that will be used by default +/// for all subsequently created primitives. +/// +/// @returns Output FP math mode. 
+inline fpmath_mode get_default_fpmath_mode() { + dnnl_fpmath_mode_t mode; + error::wrap_c_api(dnnl_get_default_fpmath_mode(&mode), + "could not get a default fpmath mode"); + return static_cast(mode); +} + +/// @copydoc dnnl_set_default_fpmath_mode() +inline status set_default_fpmath_mode(fpmath_mode mode) { + return static_cast( + dnnl_set_default_fpmath_mode(convert_to_c(mode))); +} + +/// @copydoc dnnl_set_jit_dump() +inline status set_jit_dump(int enable) { + return static_cast(dnnl_set_jit_dump(enable)); +} + +/// @copydoc dnnl_set_jit_profiling_flags() +inline status set_jit_profiling_flags(unsigned flags) { + return static_cast(dnnl_set_jit_profiling_flags(flags)); +} + +/// @copydoc dnnl_set_jit_profiling_jitdumpdir() +inline status set_jit_profiling_jitdumpdir(const std::string &dir) { + return static_cast(dnnl_set_jit_profiling_jitdumpdir(dir.c_str())); +} + +/// @copydoc dnnl_cpu_isa_t +enum class cpu_isa { + /// @copydoc dnnl_cpu_isa_default + isa_default = dnnl_cpu_isa_default, + /// @copydoc dnnl_cpu_isa_sse41 + sse41 = dnnl_cpu_isa_sse41, + /// @copydoc dnnl_cpu_isa_avx + avx = dnnl_cpu_isa_avx, + /// @copydoc dnnl_cpu_isa_avx2 + avx2 = dnnl_cpu_isa_avx2, + /// @copydoc dnnl_cpu_isa_avx2_vnni + avx2_vnni = dnnl_cpu_isa_avx2_vnni, + /// @copydoc dnnl_cpu_isa_avx2_vnni_2 + avx2_vnni_2 = dnnl_cpu_isa_avx2_vnni_2, + /// @copydoc dnnl_cpu_isa_avx512_core + avx512_core = dnnl_cpu_isa_avx512_core, + /// @copydoc dnnl_cpu_isa_avx512_core_vnni + avx512_core_vnni = dnnl_cpu_isa_avx512_core_vnni, + /// @copydoc dnnl_cpu_isa_avx512_core_bf16 + avx512_core_bf16 = dnnl_cpu_isa_avx512_core_bf16, + /// @copydoc dnnl_cpu_isa_avx10_1_512 + avx10_1_512 = dnnl_cpu_isa_avx10_1_512, + /// @copydoc dnnl_cpu_isa_avx512_core_fp16 + avx512_core_fp16 = dnnl_cpu_isa_avx512_core_fp16, + /// @copydoc dnnl_cpu_isa_avx10_1_512_amx + avx10_1_512_amx = dnnl_cpu_isa_avx10_1_512_amx, + /// @copydoc dnnl_cpu_isa_avx512_core_amx + avx512_core_amx = dnnl_cpu_isa_avx512_core_amx, + /// 
@copydoc dnnl_cpu_isa_avx10_1_512_amx_fp16 + avx10_1_512_amx_fp16 = dnnl_cpu_isa_avx10_1_512_amx_fp16, + /// @copydoc dnnl_cpu_isa_avx512_core_amx_fp16 + avx512_core_amx_fp16 = dnnl_cpu_isa_avx512_core_amx_fp16, +}; + +/// @copydoc dnnl_set_max_cpu_isa() +inline status set_max_cpu_isa(cpu_isa isa) { + return static_cast( + dnnl_set_max_cpu_isa(static_cast(isa))); +} + +/// @copydoc dnnl_get_effective_cpu_isa() +inline cpu_isa get_effective_cpu_isa() { + return static_cast(dnnl_get_effective_cpu_isa()); +} + +/// @copydoc dnnl_cpu_isa_hints_t +enum class cpu_isa_hints { + /// @copydoc dnnl_cpu_isa_no_hints + no_hints = dnnl_cpu_isa_no_hints, + /// @copydoc dnnl_cpu_isa_prefer_ymm + prefer_ymm = dnnl_cpu_isa_prefer_ymm, +}; + +/// @copydoc dnnl_set_cpu_isa_hints() +inline status set_cpu_isa_hints(cpu_isa_hints isa_hints) { + return static_cast(dnnl_set_cpu_isa_hints( + static_cast(isa_hints))); +} + +/// @copydoc dnnl_get_cpu_isa_hints() +inline cpu_isa_hints get_cpu_isa_hints() { + return static_cast(dnnl_get_cpu_isa_hints()); +} + +/// @} dnnl_api_service + +#ifdef DNNL_EXPERIMENTAL_PROFILING +/// @addtogroup dnnl_api_profiling Profiling +/// @{ + +/// Profiling data kind. +enum class profiling_data_kind { + /// Undefined profiling data kind. + undef = dnnl_profiling_data_kind_undef, + /// Data kind to query an execution time in nanoseconds. + time = dnnl_profiling_data_kind_time, +}; + +/// Resets a profiler's state. +/// +/// @param stream Stream associated with the profiler. +inline void reset_profiling(stream &stream) { + error::wrap_c_api( + dnnl_reset_profiling(stream.get()), "could not reset profiling"); +} + +/// Returns requested profiling data. The profiling data accumulates for each +/// primitive execution. The size of the vector will be equal to the number +/// of executions since the last `dnnl::reset_profiling` call. +/// +/// The profiling data can be reset by calling #dnnl::reset_profiling. 
+/// +/// @note +/// It is required to wait for all submitted primitives to complete +/// using #dnnl::stream::wait prior to querying profiling data. +/// +/// @param stream Stream that was used for executing a primitive that +/// is being profiled. +/// @param data_kind Profiling data kind to query. +/// +/// @returns A vector with the requested profiling data. +inline std::vector get_profiling_data( + stream &stream, profiling_data_kind data_kind) { + int num_entries = 0; + error::wrap_c_api( + dnnl_query_profiling_data(stream.get(), + static_cast(data_kind), + &num_entries, nullptr), + "could not get number of entries for profiling data"); + + if (num_entries == 0) return {}; + + std::vector data(num_entries); + error::wrap_c_api( + dnnl_query_profiling_data(stream.get(), + static_cast(data_kind), + &num_entries, data.data()), + "could not get profiling data"); + return data; +} + +/// @} dnnl_api_profiling +#endif + +/// @addtogroup dnnl_api_primitive_cache Primitive Cache +/// +/// A set of functions that provide primitive cache control. +/// +/// @{ + +/// Returns the number of primitives that can be held in the primitive cache +/// at the same time. +inline int get_primitive_cache_capacity() { + int result = 0; + error::wrap_c_api(dnnl_get_primitive_cache_capacity(&result), + "could not get primitive cache capacity"); + return result; +} + +/// @copydoc dnnl_set_primitive_cache_capacity(int capacity) +inline void set_primitive_cache_capacity(int capacity) { + error::wrap_c_api(dnnl_set_primitive_cache_capacity(capacity), + "could not set primitive cache capacity"); +} + +/// @} dnnl_api_primitive_cache + +/// @addtogroup dnnl_api_blas BLAS functions +/// +/// A subset of Basic Linear Algebra (BLAS) functions that perform +/// matrix-matrix multiplication. 
+/// +/// @{ + +/// @copydoc dnnl_sgemm() +inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, + dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, + const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) { + return static_cast(dnnl_sgemm( + transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc)); +} + +/// @copydoc dnnl_gemm_u8s8s32() +inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, + dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, + dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, + float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) { + return static_cast(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N, + K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co)); +} + +/// @copydoc dnnl_gemm_s8s8s32() +inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, + dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, + dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, + float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) { + return static_cast(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N, + K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co)); +} + +/// @} dnnl_api_blas + +// implementation section + +/// @cond DO_NOT_DOCUMENT_THIS +inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) { + dnnl_primitive_t result; + error::wrap_c_api(dnnl_primitive_create(&result, c_pd), + "could not create a primitive"); + reset(result); +} + +inline primitive::primitive(const_dnnl_primitive_desc_t c_pd, + const std::vector &cache_blob) { + dnnl_primitive_t result; + size_t size = cache_blob.size(); + const uint8_t *cache_blob_data = cache_blob.data(); + error::wrap_c_api(dnnl_primitive_create_from_cache_blob( + &result, c_pd, size, cache_blob_data), + "could not create a primitive from a cache blob"); + reset(result); +} + +inline primitive::primitive(const primitive_desc &pd) : 
primitive(pd.get()) {} +inline primitive::primitive( + const primitive_desc &pd, const std::vector &cache_blob) + : primitive(pd.get(), cache_blob) {} + +inline void primitive::execute(const stream &astream, + const std::unordered_map &args) const { + std::vector c_args; + c_args.reserve(args.size()); + for (const auto &a : args) + c_args.push_back({a.first, a.second.get(true)}); + + error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(), + (int)c_args.size(), c_args.data()), + "could not execute a primitive"); +} + +/// @endcond + +#undef DNNL_DEFINE_BITMASK_OPS + +} // namespace dnnl + +/// oneAPI namespace + +/// The oneAPI namespace. +/// Contains the oneapi::dnnl namespace as an alias to the ::dnnl namespace. +namespace oneapi { +// Note: without this guard, doxygen warns of potentially recursive namespace +#ifndef DOXYGEN_SHOULD_SKIP_THIS +/// oneDNN alias namespace +namespace dnnl = ::dnnl; +#endif +} // namespace oneapi + +/// @} dnnl_api + +#endif /* ONEAPI_DNNL_DNNL_HPP */ diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_common.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_common.h new file mode 100644 index 0000000000000000000000000000000000000000..ff58a578cad363e337bea8909aa08303dec8c2d7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_common.h @@ -0,0 +1,175 @@ +/******************************************************************************* +* Copyright 2022-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @file +/// C common API + +#ifndef ONEAPI_DNNL_DNNL_COMMON_H +#define ONEAPI_DNNL_DNNL_COMMON_H + +#include "oneapi/dnnl/dnnl_common_types.h" +#include "oneapi/dnnl/dnnl_config.h" +#include "oneapi/dnnl/dnnl_version.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api oneDNN API +/// @{ + +/// @addtogroup dnnl_api_common Common API +/// @{ + +/// @addtogroup dnnl_api_engine Engine +/// @{ + +/// Returns the number of engines of a particular kind. +/// +/// @param kind Kind of engines to count. +/// @returns Count of the engines. +size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind); + +/// Creates an engine. +/// +/// @param engine Output engine. +/// @param kind Engine kind. +/// @param index Engine index that should be between 0 and the count of +/// engines of the requested kind. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_engine_create( + dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index); + +/// Returns the kind of an engine. +/// +/// @param engine Engine to query. +/// @param kind Output engine kind. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_engine_get_kind( + dnnl_engine_t engine, dnnl_engine_kind_t *kind); + +/// Destroys an engine. +/// +/// @param engine Engine to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine); + +/// @} dnnl_api_engine + +/// @addtogroup dnnl_api_stream Stream +/// @{ + +/// Creates an execution stream. +/// +/// @param stream Output execution stream. +/// @param engine Engine to create the execution stream on. 
+/// @param flags Stream behavior flags (@sa dnnl_stream_flags_t). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_stream_create( + dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags); + +/// Returns the engine of a stream object. +/// +/// @param stream Stream object. +/// @param engine Output engine on which the stream is created. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_stream_get_engine( + const_dnnl_stream_t stream, dnnl_engine_t *engine); + +/// Waits for all primitives in the execution stream to finish computations. +/// +/// @param stream Execution stream. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream); + +/// Destroys an execution stream. +/// +/// @param stream Execution stream to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream); + +/// @} dnnl_api_stream + +/// @addtogroup dnnl_api_fpmath_mode Floating-point Math Mode +/// @{ + +/// Returns the floating-point math mode that will be used by default +/// for all subsequently created primitives. +/// +/// @param mode Output FP math mode. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_get_default_fpmath_mode(dnnl_fpmath_mode_t *mode); + +/// Sets the floating-point math mode that will be used by default +/// for all subsequently created primitives. +/// +/// @param mode FP math mode. The possible values are: +/// #dnnl_fpmath_mode_strict, +/// #dnnl_fpmath_mode_bf16, +/// #dnnl_fpmath_mode_f16, +/// #dnnl_fpmath_mode_tf32, +/// #dnnl_fpmath_mode_any. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_set_default_fpmath_mode(dnnl_fpmath_mode_t mode); + +/// @} dnnl_api_fpmath_mode + +/// @addtogroup dnnl_api_service +/// @{ + +/// Configures verbose output to stdout. +/// +/// @note +/// Enabling verbose output affects performance. +/// This setting overrides the ONEDNN_VERBOSE environment variable. +/// +/// @param level Verbosity level: +/// - 0: no verbose output (default), +/// - 1: primitive and graph information at execution, +/// - 2: primitive and graph information at creation/compilation and execution. +/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the +/// @p level value is invalid, and #dnnl_success/#dnnl::status::success on +/// success. +dnnl_status_t DNNL_API dnnl_set_verbose(int level); + +/// Returns library version information. +/// @returns Pointer to a constant structure containing +/// - major: major version number, +/// - minor: minor version number, +/// - patch: patch release number, +/// - hash: git commit hash. +const dnnl_version_t DNNL_API *dnnl_version(void); + +/// @} dnnl_api_service + +/// @} dnnl_api_common + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif /* ONEAPI_DNNL_DNNL_COMMON_H */ diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_common.hpp b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..589be07382a3e581b142f7aac02042ebd198b9a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_common.hpp @@ -0,0 +1,479 @@ +/******************************************************************************* +* Copyright 2022-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @file +/// C++ common API + +#ifndef ONEAPI_DNNL_DNNL_COMMON_HPP +#define ONEAPI_DNNL_DNNL_COMMON_HPP + +/// @cond DO_NOT_DOCUMENT_THIS +#include +#include +#include +#include +#include +#include +#include + +#include "oneapi/dnnl/dnnl_common.h" + +/// @endcond + +// __cpp_exceptions is referred from +// https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_exceptions.html +// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, +// Microsoft C++ Compiler does not provide an option to disable exceptions +#ifndef DNNL_ENABLE_EXCEPTIONS +#if __cpp_exceptions || __EXCEPTIONS \ + || (defined(_MSC_VER) && !defined(__clang__)) +#define DNNL_ENABLE_EXCEPTIONS 1 +#else +#define DNNL_ENABLE_EXCEPTIONS 0 +#endif +#endif + +#if defined(__GNUC__) || defined(__clang__) +#define DNNL_TRAP() __builtin_trap() +#elif defined(__INTEL_COMPILER) || defined(_MSC_VER) +#define DNNL_TRAP() __debugbreak() +#else +#error "unknown compiler" +#endif + +#if DNNL_ENABLE_EXCEPTIONS +#define DNNL_THROW_ERROR(status, msg) throw error(status, msg) +#else +#include +#define DNNL_THROW_ERROR(status, msg) \ + do { \ + fputs(msg, stderr); \ + DNNL_TRAP(); \ + } while (0) +#endif + +/// @addtogroup dnnl_api oneDNN API +/// @{ + +/// oneDNN namespace +namespace dnnl { + +/// @addtogroup dnnl_api_common Common API +/// @{ + +/// @addtogroup dnnl_api_utils Utilities +/// Utility types and definitions. +/// @{ + +/// oneDNN exception class. 
+/// +/// This class captures the status returned by a failed C API function and +/// the error message from the call site. +struct error : public std::exception { + dnnl_status_t status; + const char *message; + + /// Constructs an instance of an exception class. + /// + /// @param status The error status returned by a C API function. + /// @param message The error message. + error(dnnl_status_t status, const char *message) + : status(status), message(message) {} + + /// Returns the explanatory string. + const char *what() const noexcept override { return message; } + + /// A convenience function for wrapping calls to C API functions. Checks + /// the return status and throws an dnnl::error in case of failure. + /// + /// @param status The error status returned by a C API function. + /// @param message The error message. + static void wrap_c_api(dnnl_status_t status, const char *message) { + if (status != dnnl_success) DNNL_THROW_ERROR(status, message); + } +}; + +/// A class that provides the destructor for a oneDNN C API handle. +template +struct handle_traits {}; + +/// oneDNN C API handle wrapper class. +/// +/// This class is used as the base class for primitive (dnnl::primitive), +/// engine (dnnl::engine), and stream (dnnl::stream) classes, as well as +/// others. An object of the dnnl::handle class can be passed by value. +/// +/// A handle can be weak, in which case it follows std::weak_ptr semantics. +/// Otherwise, it follows `std::shared_ptr` semantics. +/// +/// @note +/// The implementation stores oneDNN C API handles in a `std::shared_ptr` +/// with deleter set to a dummy function in the weak mode. +/// +template > +struct handle { +private: + static dnnl_status_t dummy_destructor(T) { return dnnl_success; } + std::shared_ptr::type> data_ {0}; + +protected: + bool operator==(const T other) const { return other == data_.get(); } + bool operator!=(const T other) const { return !(*this == other); } + +public: + /// Constructs an empty handle object. 
+ /// + /// @warning + /// Uninitialized object cannot be used in most library calls and is + /// equivalent to a null pointer. Any attempt to use its methods, or + /// passing it to the other library function, will cause an exception + /// to be thrown. + handle() = default; + + /// Copy constructor. + handle(const handle &) = default; + /// Assignment operator. + handle &operator=(const handle &) = default; + /// Move constructor. + handle(handle &&) = default; + /// Move assignment operator. + handle &operator=(handle &&) = default; + + /// Constructs a handle wrapper object from a C API handle. + /// + /// @param t The C API handle to wrap. + /// @param weak A flag specifying whether to construct a weak wrapper; + /// defaults to @c false. + explicit handle(T t, bool weak = false) { reset(t, weak); } + + /// Resets the handle wrapper objects to wrap a new C API handle. + /// + /// @param t The new value of the C API handle. + /// @param weak A flag specifying whether the wrapper should be weak; + /// defaults to @c false. + void reset(T t, bool weak = false) { + data_.reset(t, weak ? &dummy_destructor : traits::destructor); + } + + /// Returns the underlying C API handle. + /// + /// @param allow_empty A flag signifying whether the method is allowed to + /// return an empty (null) object without throwing an exception. + /// @returns The underlying C API handle. + T get(bool allow_empty = false) const { + T result = data_.get(); + if (allow_empty == false && result == nullptr) + DNNL_THROW_ERROR( + dnnl_invalid_arguments, "object is not initialized"); + return result; + } + + /// Converts a handle to the underlying C API handle type. Does not throw + /// and returns `nullptr` if the object is empty. + /// + /// @returns The underlying C API handle. + explicit operator T() const { return get(true); } + + /// Checks whether the object is not empty. + /// + /// @returns Whether the object is not empty. 
+ explicit operator bool() const { return get(true) != nullptr; } + + /// Equality operator. + /// + /// @param other Another handle wrapper. + /// @returns @c true if this and the other handle wrapper manage the same + /// underlying C API handle, and @c false otherwise. Empty handle + /// objects are considered to be equal. + bool operator==(const handle &other) const { + return other.data_.get() == data_.get(); + } + + /// Inequality operator. + /// + /// @param other Another handle wrapper. + /// @returns @c true if this and the other handle wrapper manage different + /// underlying C API handles, and @c false otherwise. Empty handle + /// objects are considered to be equal. + bool operator!=(const handle &other) const { return !(*this == other); } +}; + +/// @} dnnl_api_utils + +/// @addtogroup dnnl_api_engine Engine +/// +/// An abstraction of a computational device: a CPU, a specific GPU +/// card in the system, etc. Most primitives are created to execute +/// computations on one specific engine. The only exceptions are reorder +/// primitives that transfer data between two different engines. +/// +/// @sa @ref dev_guide_basic_concepts +/// +/// @{ + +/// @cond DO_NOT_DOCUMENT_THIS +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_engine_t p) { + return dnnl_engine_destroy(p); + } +}; +/// @endcond + +/// An execution engine. +struct engine : public handle { + friend struct primitive; + friend struct reorder; + + /// Kinds of engines. + enum class kind { + /// An unspecified engine + any = dnnl_any_engine, + /// CPU engine + cpu = dnnl_cpu, + /// GPU engine + gpu = dnnl_gpu, + }; + + using handle::handle; + + /// Constructs an empty engine. An empty engine cannot be used in any + /// operations. + engine() = default; + + /// Returns the number of engines of a certain kind. + /// + /// @param akind The kind of engines to count. + /// @returns The number of engines of the specified kind. 
+ static size_t get_count(kind akind) { + return dnnl_engine_get_count(convert_to_c(akind)); + } + + /// Constructs an engine. + /// + /// @param akind The kind of engine to construct. + /// @param index The index of the engine. Must be less than the value + /// returned by #get_count() for this particular kind of engine. + engine(kind akind, size_t index) { + dnnl_engine_t engine; + error::wrap_c_api( + dnnl_engine_create(&engine, convert_to_c(akind), index), + "could not create an engine"); + reset(engine); + } + + /// Returns the kind of the engine. + /// @returns The kind of the engine. + kind get_kind() const { + dnnl_engine_kind_t kind; + error::wrap_c_api(dnnl_engine_get_kind(get(), &kind), + "could not get kind of an engine"); + return static_cast(kind); + } + +private: + static dnnl_engine_kind_t convert_to_c(kind akind) { + return static_cast(akind); + } +}; + +/// Converts engine kind enum value from C++ API to C API type. +/// +/// @param akind C++ API engine kind enum value. +/// @returns Corresponding C API engine kind enum value. +inline dnnl_engine_kind_t convert_to_c(engine::kind akind) { + return static_cast(akind); +} + +/// @} dnnl_api_engine + +/// @addtogroup dnnl_api_stream Stream +/// +/// An encapsulation of execution context tied to a particular engine. +/// +/// @sa @ref dev_guide_basic_concepts +/// +/// @{ + +/// @cond DO_NOT_DOCUMENT_THIS +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_stream_t p) { + return dnnl_stream_destroy(p); + } +}; +/// @endcond + +/// An execution stream. +struct stream : public handle { + using handle::handle; + + /// Stream flags. Can be combined using the bitwise OR operator. + enum class flags : unsigned { + /// In-order execution. + in_order = dnnl_stream_in_order, + /// Out-of-order execution. + out_of_order = dnnl_stream_out_of_order, + /// Default stream configuration. 
+ default_flags = dnnl_stream_default_flags, +#ifdef DNNL_EXPERIMENTAL_PROFILING + /// Enables profiling capabilities. + profiling = dnnl_stream_profiling, +#endif + }; + + /// Constructs an empty stream. An empty stream cannot be used in any + /// operations. + stream() = default; + + /// Constructs a stream for the specified engine and with behavior + /// controlled by the specified flags. + /// + /// @param aengine Engine to create the stream on. + /// @param aflags Flags controlling stream behavior. + explicit stream( + const engine &aengine, flags aflags = flags::default_flags) { + dnnl_stream_t stream; + error::wrap_c_api(dnnl_stream_create(&stream, aengine.get(), + static_cast(aflags)), + "could not create a stream"); + reset(stream); + } + + /// Returns the associated engine. + engine get_engine() const { + dnnl_engine_t c_engine; + error::wrap_c_api(dnnl_stream_get_engine(get(), &c_engine), + "could not get an engine from a stream object"); + return engine(c_engine, true); + } + + /// Waits for all primitives executing in the stream to finish. + /// @returns The stream itself. 
+ stream &wait() { + error::wrap_c_api( + dnnl_stream_wait(get()), "could not wait on a stream"); + return *this; + } +}; + +#define DNNL_DEFINE_BITMASK_OPS(enum_name) \ + inline enum_name operator|(enum_name lhs, enum_name rhs) { \ + return static_cast( \ + static_cast(lhs) | static_cast(rhs)); \ + } \ +\ + inline enum_name operator&(enum_name lhs, enum_name rhs) { \ + return static_cast( \ + static_cast(lhs) & static_cast(rhs)); \ + } \ +\ + inline enum_name operator^(enum_name lhs, enum_name rhs) { \ + return static_cast( \ + static_cast(lhs) ^ static_cast(rhs)); \ + } \ +\ + inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \ + lhs = static_cast( \ + static_cast(lhs) | static_cast(rhs)); \ + return lhs; \ + } \ +\ + inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \ + lhs = static_cast( \ + static_cast(lhs) & static_cast(rhs)); \ + return lhs; \ + } \ +\ + inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \ + lhs = static_cast( \ + static_cast(lhs) ^ static_cast(rhs)); \ + return lhs; \ + } \ +\ + inline enum_name operator~(enum_name rhs) { \ + return static_cast(~static_cast(rhs)); \ + } + +DNNL_DEFINE_BITMASK_OPS(stream::flags) + +/// @} dnnl_api_stream + +/// @addtogroup dnnl_api_fpmath_mode Floating-point Math Mode +/// @{ + +/// Floating-point math mode +enum class fpmath_mode { + /// Default behavior, no downconversions allowed + strict = dnnl_fpmath_mode_strict, + /// Implicit f32->bf16 conversions allowed + bf16 = dnnl_fpmath_mode_bf16, + /// Implicit f32->f16 conversions allowed + f16 = dnnl_fpmath_mode_f16, + /// Implicit f32->tf32 conversions allowed + tf32 = dnnl_fpmath_mode_tf32, + /// Implicit f32->f16, f32->tf32 or f32->bf16 conversions allowed + any = dnnl_fpmath_mode_any +}; + +/// Converts an fpmath mode enum value from C++ API to C API type. +/// +/// @param mode C++ API fpmath mode enum value. +/// @returns Corresponding C API fpmath mode enum value. 
+inline dnnl_fpmath_mode_t convert_to_c(fpmath_mode mode) { + return static_cast(mode); +} + +/// @} dnnl_api_fpmath_mode + +/// @addtogroup dnnl_api_accumulation_mode Accumulation Mode +/// @{ + +/// Accumulation mode +enum class accumulation_mode { + /// Default behavior, f32 for floating point computation, s32 for integer + strict = dnnl_accumulation_mode_strict, + /// same as strict except some partial accumulators can be rounded to + /// src/dst datatype in memory. + relaxed = dnnl_accumulation_mode_relaxed, + /// uses fastest implementation, could use src/dst datatype or + /// wider datatype for accumulators + any = dnnl_accumulation_mode_any, + /// use s32 accumulators during computation + s32 = dnnl_accumulation_mode_s32, + /// use f32 accumulators during computation + f32 = dnnl_accumulation_mode_f32, + /// use f16 accumulators during computation + f16 = dnnl_accumulation_mode_f16 +}; + +/// Converts an accumulation mode enum value from C++ API to C API type. +/// +/// @param mode C++ API accumulation mode enum value. +/// @returns Corresponding C API accumulation mode enum value. +inline dnnl_accumulation_mode_t convert_to_c(accumulation_mode mode) { + return static_cast(mode); +} + +/// @} dnnl_api_accumulation_mode + +/// @} dnnl_api_common + +} // namespace dnnl + +/// @} dnnl_api + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_common_types.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_common_types.h new file mode 100644 index 0000000000000000000000000000000000000000..d80a2e7d0a658821497da3d4ebf445b74b2bb0dd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_common_types.h @@ -0,0 +1,263 @@ +/******************************************************************************* +* Copyright 2022-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @file +/// C API common types definitions + +#ifndef ONEAPI_DNNL_DNNL_COMMON_TYPES_H +#define ONEAPI_DNNL_DNNL_COMMON_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +/// @cond DO_NOT_DOCUMENT_THIS +#include +#include + +#include "oneapi/dnnl/dnnl_config.h" + +/// @endcond + +/// @addtogroup dnnl_api oneDNN API +/// @{ + +/// @addtogroup dnnl_api_common Common API +/// @{ + +/// @addtogroup dnnl_api_utils +/// @{ + +/// Status values returned by the library functions. +typedef enum { + /// The operation was successful + dnnl_success = 0, + /// The operation failed due to an out-of-memory condition + dnnl_out_of_memory = 1, + /// The operation failed because of incorrect function arguments + dnnl_invalid_arguments = 2, + /// The operation failed because requested functionality is not implemented + dnnl_unimplemented = 3, + /// The last available implementation is reached + dnnl_last_impl_reached = 4, + /// Primitive or engine failed on execution + dnnl_runtime_error = 5, + /// Queried element is not required for given primitive + dnnl_not_required = 6, + /// The graph is not legitimate + dnnl_invalid_graph = 7, + /// The operation is not legitimate according to op schema + dnnl_invalid_graph_op = 8, + /// The shape cannot be inferred or compiled + dnnl_invalid_shape = 9, + /// The data type cannot be inferred or compiled + dnnl_invalid_data_type = 10, +} dnnl_status_t; + +/// @} dnnl_api_utils + +/// @addtogroup dnnl_api_data_types Data types +/// @{ + +/// Data 
type specification +typedef enum { + /// Undefined data type, used for empty memory descriptors. + dnnl_data_type_undef = 0, + /// 16-bit/half-precision floating point. + dnnl_f16 = 1, + /// non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point. + dnnl_bf16 = 2, + /// 32-bit/single-precision floating point. + dnnl_f32 = 3, + /// 32-bit signed integer. + dnnl_s32 = 4, + /// 8-bit signed integer. + dnnl_s8 = 5, + /// 8-bit unsigned integer. + dnnl_u8 = 6, + /// 64-bit/double-precision floating point. + dnnl_f64 = 7, + /// Boolean data type. Size is C++ implementation defined. + dnnl_boolean = 8, + /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) + /// with a 5-bit exponent and a 2-bit mantissa. + dnnl_f8_e5m2 = 9, + /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) + /// with a 4-bit exponent and a 3-bit mantissa. + dnnl_f8_e4m3 = 10, + /// 4-bit signed integer. + dnnl_s4 = 11, + /// 4-bit unsigned integer. + dnnl_u4 = 12, + /// [MX-compliant 8-bit compliant scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent. + dnnl_e8m0 = 13, + /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa. + dnnl_f4_e2m1 = 14, + /// 4-bit float data type with 3-bit exponent and 0 bit mantissa. + dnnl_f4_e3m0 = 15, + + /// Parameter to allow internal only data_types without undefined behavior. + /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. + dnnl_data_type_max = 0x7fff, +} dnnl_data_type_t; + +/// Maximum number of dimensions a tensor can have. Only restricts the amount +/// of space used for the tensor description. 
Individual computational +/// primitives may support only tensors of certain dimensions. +#define DNNL_MAX_NDIMS 12 + +/// A type to describe tensor dimension. +typedef int64_t dnnl_dim_t; + +/// A type to describe tensor dimensions. +typedef dnnl_dim_t dnnl_dims_t[DNNL_MAX_NDIMS]; + +/// @} dnnl_api_data_types + +/// @addtogroup dnnl_api_fpmath_mode Floating-point Math Mode +/// @{ + +/// Floating-point math mode +typedef enum { + /// Default behavior, no downconversions allowed + dnnl_fpmath_mode_strict, + /// Implicit f32->bf16 conversions allowed + dnnl_fpmath_mode_bf16, + /// Implicit f32->f16 conversions allowed + dnnl_fpmath_mode_f16, + /// Implicit f32->f16, f32->tf32 or f32->bf16 conversions allowed + dnnl_fpmath_mode_any, + /// Implicit f32->tf32 conversions allowed + dnnl_fpmath_mode_tf32, +} dnnl_fpmath_mode_t; + +/// @} dnnl_api_fpmath_mode + +/// @addtogroup dnnl_api_accumulation_mode Accumulation Mode +/// @{ + +/// Accumulation mode +typedef enum { + /// Default behavior, f32/f64 for floating point computation, s32 + /// for integer + dnnl_accumulation_mode_strict, + /// Same as strict but allows some partial accumulators to be + /// rounded to src/dst datatype in memory. + dnnl_accumulation_mode_relaxed, + /// uses fastest implementation, could use src/dst datatype or + /// wider datatype for accumulators + dnnl_accumulation_mode_any, + /// use s32 accumulators during computation + dnnl_accumulation_mode_s32, + /// use f32 accumulators during computation + dnnl_accumulation_mode_f32, + /// use f16 accumulators during computation + dnnl_accumulation_mode_f16 +} dnnl_accumulation_mode_t; + +/// @} dnnl_api_accumulation_mode + +/// @addtogroup dnnl_api_engine Engine +/// @{ + +/// @brief Kinds of engines. +typedef enum { + /// An unspecified engine. + dnnl_any_engine, + /// CPU engine. + dnnl_cpu, + /// GPU engine. + dnnl_gpu, +} dnnl_engine_kind_t; + +/// @struct dnnl_engine +/// @brief An opaque structure to describe an engine. 
+struct dnnl_engine; +/// @brief An engine handle. +typedef struct dnnl_engine *dnnl_engine_t; +#if 0 +// FIXME: looks like this never happens +/// @brief A constant engine handle. +typedef const struct dnnl_engine *const_dnnl_engine_t; +#endif + +/// @} dnnl_api_engine + +/// @addtogroup dnnl_api_stream Stream +/// @{ + +/// @brief Stream flags. +typedef enum { + // In-order execution. + dnnl_stream_in_order = 0x1U, + /// Out-of-order execution. + dnnl_stream_out_of_order = 0x2U, + /// Default stream configuration. + dnnl_stream_default_flags = dnnl_stream_in_order, +#ifdef DNNL_EXPERIMENTAL_PROFILING + /// Enables profiling capabilities. + dnnl_stream_profiling = 0x4U, +#endif +} dnnl_stream_flags_t; + +/// @struct dnnl_stream +/// An opaque structure to describe an execution stream. +struct dnnl_stream; +/// An execution stream handle. +typedef struct dnnl_stream *dnnl_stream_t; +/// A constant execution stream handle. +typedef const struct dnnl_stream *const_dnnl_stream_t; + +/// @} dnnl_api_stream + +/// @addtogroup dnnl_api_service +/// @{ + +/// Structure containing version information as per [Semantic +/// Versioning](https://semver.org) +typedef struct { + int major; ///< Major version + int minor; ///< Minor version + int patch; ///< Patch version + const char *hash; ///< Git hash of the sources (may be absent) + unsigned cpu_runtime; ///< CPU runtime + unsigned gpu_runtime; ///< GPU runtime +} dnnl_version_t; + +/// @} dnnl_api_service + +/// @addtogroup dnnl_api_memory +/// @{ + +/// Special pointer value that indicates that a memory object should not have +/// an underlying buffer. +#define DNNL_MEMORY_NONE (NULL) + +/// Special pointer value that indicates that the library needs to allocate an +/// underlying buffer for a memory object. 
+#define DNNL_MEMORY_ALLOCATE ((void *)(size_t)-1) + +/// @} dnnl_api_memory + +/// @} dnnl_api_common + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_config.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_config.h new file mode 100644 index 0000000000000000000000000000000000000000..dc58d8f1205cb4830dbb97c758547700db9797ea --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_config.h @@ -0,0 +1,232 @@ +/******************************************************************************* +* Copyright 2019-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_CONFIG_H +#define ONEAPI_DNNL_DNNL_CONFIG_H + +/// @cond DO_NOT_DOCUMENT_THIS + +// All symbols shall be internal unless marked as DNNL_API +#if defined _WIN32 || defined __CYGWIN__ +#define DNNL_HELPER_DLL_IMPORT __declspec(dllimport) +#define DNNL_HELPER_DLL_EXPORT __declspec(dllexport) +#else +#if __GNUC__ >= 4 +#define DNNL_HELPER_DLL_IMPORT __attribute__((visibility("default"))) +#define DNNL_HELPER_DLL_EXPORT __attribute__((visibility("default"))) +#else +#define DNNL_HELPER_DLL_IMPORT +#define DNNL_HELPER_DLL_EXPORT +#endif +#endif + +#ifdef DNNL_DLL +#ifdef DNNL_DLL_EXPORTS +#define DNNL_API DNNL_HELPER_DLL_EXPORT +#else +#define DNNL_API DNNL_HELPER_DLL_IMPORT +#endif +#else +#define DNNL_API +#endif + +#if defined(__GNUC__) +#define DNNL_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +#define DNNL_DEPRECATED __declspec(deprecated) +#else +#define DNNL_DEPRECATED +#endif + +/// @endcond + +// clang-format off + +/// @addtogroup dnnl_api_service +/// @{ + +/// No runtime (disabled) +#define DNNL_RUNTIME_NONE 0u + +/// Sequential runtime (CPU only) +#define DNNL_RUNTIME_SEQ 1u + +/// OpenMP runtime (CPU only) +#define DNNL_RUNTIME_OMP 2u + +/// TBB runtime (CPU only) +#define DNNL_RUNTIME_TBB 4u + +/// Threadpool runtime (CPU only) +#define DNNL_RUNTIME_THREADPOOL 8u + +/// OpenCL runtime +#define DNNL_RUNTIME_OCL 256u + +/// SYCL runtime +#define DNNL_RUNTIME_SYCL 512u + +/// DPC++ runtime +#define DNNL_RUNTIME_DPCPP DNNL_RUNTIME_SYCL + +/// No vendor (corresponding runtime is disabled) +#define DNNL_VENDOR_NONE 0u + +/// Intel vendor +#define DNNL_VENDOR_INTEL 1u + +/// NVIDIA vendor +#define DNNL_VENDOR_NVIDIA 2u + +/// AMD vendor +#define DNNL_VENDOR_AMD 4u + +/// Generic vendor +#define DNNL_VENDOR_GENERIC 8u + +/// @} dnnl_api_service + +// oneDNN CPU threading runtime +#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP + +// 
oneDNN CPU engine runtime +#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP + +// oneDNN GPU engine runtime +#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE + +// oneDNN GPU vendor +#define DNNL_GPU_VENDOR DNNL_VENDOR_NONE + +// clang-format on + +#if defined(DNNL_CPU_RUNTIME) && defined(DNNL_GPU_RUNTIME) +#if (DNNL_CPU_RUNTIME == DNNL_RUNTIME_OCL) +#error "Unexpected DNNL_CPU_RUNTIME" +#endif +#if (DNNL_GPU_RUNTIME != DNNL_RUNTIME_NONE) \ + && (DNNL_GPU_RUNTIME != DNNL_RUNTIME_OCL) \ + && (DNNL_GPU_RUNTIME != DNNL_RUNTIME_SYCL) +#error "Unexpected DNNL_GPU_RUNTIME" +#endif +#if (DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE \ + && DNNL_GPU_RUNTIME == DNNL_RUNTIME_NONE) +#error "At least one runtime must be specified" +#endif +#else +#error "BOTH DNNL_CPU_RUNTIME and DNNL_GPU_RUNTIME must be defined" +#endif + +// For SYCL CPU, a primitive may be created and executed in different threads +// hence the global scratchpad does not work. This enables concurrent execution +// when CPU runtime is SYCL to avoid the issue. +#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL +#ifndef DNNL_ENABLE_CONCURRENT_EXEC +#define DNNL_ENABLE_CONCURRENT_EXEC +#endif +#endif + +// When defined, primitive cache stores runtime objects. +/* #undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE */ + +// When defined, DPCPP is supported. +/* #undef DNNL_WITH_SYCL */ + +// When defined, Level Zero is supported. +/* #undef DNNL_WITH_LEVEL_ZERO */ + +// When defined, SYCL CUDA backend is used. +/* #undef DNNL_SYCL_CUDA */ + +// When defined, SYCL HIP backend is used. +/* #undef DNNL_SYCL_HIP */ + +// When defined, SYCL Generic backend is used. +/* #undef DNNL_SYCL_GENERIC */ + +// When defined, stack checker is enabled. +/* #undef DNNL_ENABLE_STACK_CHECKER */ + +// When defined, experimental features are enabled. +/* #undef DNNL_EXPERIMENTAL */ + +// When defined, experimental functionality for sparse domain is enabled. +/* #undef DNNL_EXPERIMENTAL_SPARSE */ + +// When defined, experimental functionality for ukernels is enabled. 
+#define DNNL_EXPERIMENTAL_UKERNEL + +// When defined, graph component is enabled. +#define ONEDNN_BUILD_GRAPH + +// When defined, experimental profiling capabilities are enabled. +/* #undef DNNL_EXPERIMENTAL_PROFILING */ + +// When defined, experimental logging capabilities are enabled. +/* #undef DNNL_EXPERIMENTAL_LOGGING */ +// When defined, it disables GPU compute reference kernels. +/* #undef DNNL_DISABLE_GPU_REF_KERNELS */ + +// List of configurating build controls +// Workload controls +#define BUILD_TRAINING 1 +#define BUILD_INFERENCE 0 +// Primitive controls +#define BUILD_PRIMITIVE_ALL 1 +#define BUILD_BATCH_NORMALIZATION 0 +#define BUILD_BINARY 0 +#define BUILD_CONCAT 0 +#define BUILD_CONVOLUTION 0 +#define BUILD_DECONVOLUTION 0 +#define BUILD_ELTWISE 0 +#define BUILD_GROUP_NORMALIZATION 0 +#define BUILD_INNER_PRODUCT 0 +#define BUILD_LAYER_NORMALIZATION 0 +#define BUILD_LRN 0 +#define BUILD_MATMUL 0 +#define BUILD_POOLING 0 +#define BUILD_PRELU 0 +#define BUILD_REDUCTION 0 +#define BUILD_REORDER 0 +#define BUILD_RESAMPLING 0 +#define BUILD_RNN 0 +#define BUILD_SDPA 0 +#define BUILD_SHUFFLE 0 +#define BUILD_SOFTMAX 0 +#define BUILD_SUM 0 +// Primitives CPU ISA controls +#define BUILD_PRIMITIVE_CPU_ISA_ALL 1 +#define BUILD_SSE41 0 +#define BUILD_AVX2 0 +#define BUILD_AVX512 0 +#define BUILD_AMX 0 +// Primitives GPU ISA controls +#define BUILD_PRIMITIVE_GPU_ISA_ALL 1 +#define BUILD_GEN9 0 +#define BUILD_GEN11 0 +#define BUILD_XELP 0 +#define BUILD_XEHP 0 +#define BUILD_XEHPG 0 +#define BUILD_XEHPC 0 +#define BUILD_XE2 0 +#define BUILD_XE3 0 +// GeMM kernels ISA controls +#define BUILD_GEMM_KERNELS_ALL 1 +#define BUILD_GEMM_KERNELS_NONE 0 +#define BUILD_GEMM_SSE41 0 +#define BUILD_GEMM_AVX2 0 +#define BUILD_GEMM_AVX512 0 +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_debug.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_debug.h new file mode 100644 index 
0000000000000000000000000000000000000000..9efa63dd61e65283e7c9f72a9fba0ee081ac8847 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_debug.h @@ -0,0 +1,61 @@ +/******************************************************************************* +* Copyright 2018-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +// DO NOT EDIT, AUTO-GENERATED +// Use this script to update the file: scripts/generate_dnnl_debug.py + +// clang-format off + +#ifndef ONEAPI_DNNL_DNNL_DEBUG_H +#define ONEAPI_DNNL_DNNL_DEBUG_H + +/// @file +/// Debug capabilities + +#include "oneapi/dnnl/dnnl_config.h" +#include "oneapi/dnnl/dnnl_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +const char DNNL_API *dnnl_status2str(dnnl_status_t v); +const char DNNL_API *dnnl_dt2str(dnnl_data_type_t v); +const char DNNL_API *dnnl_fpmath_mode2str(dnnl_fpmath_mode_t v); +const char DNNL_API *dnnl_accumulation_mode2str(dnnl_accumulation_mode_t v); +const char DNNL_API *dnnl_engine_kind2str(dnnl_engine_kind_t v); +#ifdef DNNL_EXPERIMENTAL_SPARSE +const char DNNL_API *dnnl_sparse_encoding2str(dnnl_sparse_encoding_t v); +#endif +const char DNNL_API *dnnl_fmt_tag2str(dnnl_format_tag_t v); +const char DNNL_API *dnnl_prop_kind2str(dnnl_prop_kind_t v); +const char DNNL_API *dnnl_prim_kind2str(dnnl_primitive_kind_t v); +const char DNNL_API *dnnl_alg_kind2str(dnnl_alg_kind_t v); 
+const char DNNL_API *dnnl_rnn_flags2str(dnnl_rnn_flags_t v); +const char DNNL_API *dnnl_rnn_direction2str(dnnl_rnn_direction_t v); +const char DNNL_API *dnnl_scratchpad_mode2str(dnnl_scratchpad_mode_t v); +const char DNNL_API *dnnl_rounding_mode2str(dnnl_rounding_mode_t v); +const char DNNL_API *dnnl_cpu_isa2str(dnnl_cpu_isa_t v); +const char DNNL_API *dnnl_cpu_isa_hints2str(dnnl_cpu_isa_hints_t v); + +const char DNNL_API *dnnl_runtime2str(unsigned v); +const char DNNL_API *dnnl_fmt_kind2str(dnnl_format_kind_t v); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..77f7b46b48ffbe7ef1adce358545309ee94a4fcb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph.h @@ -0,0 +1,772 @@ +/******************************************************************************* +* Copyright 2020-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/// @file +/// Graph C API + +#ifndef ONEAPI_DNNL_DNNL_GRAPH_H +#define ONEAPI_DNNL_DNNL_GRAPH_H + +#include "oneapi/dnnl/dnnl_common.h" +#include "oneapi/dnnl/dnnl_config.h" +#include "oneapi/dnnl/dnnl_graph_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_graph_api +/// @{ + +/// @addtogroup dnnl_graph_api_allocator +/// @{ + +/// Creates a host allocator with the given allocation and deallocation +/// call-back function pointers. +/// +/// @param allocator Output allocator. +/// @param host_malloc A pointer to malloc function for host. +/// @param host_free A pointer to free function for host. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_allocator_create( + dnnl_graph_allocator_t *allocator, + dnnl_graph_host_allocate_f host_malloc, + dnnl_graph_host_deallocate_f host_free); + +/// Destroys an allocator. +/// +/// @param allocator The allocator to be destroyed. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_allocator_destroy( + dnnl_graph_allocator_t allocator); + +/// @} dnnl_graph_api_allocator + +/// @addtogroup dnnl_graph_api_engine +/// @{ + +/// This API is a supplement for existing onednn engine API. +dnnl_status_t DNNL_API dnnl_graph_make_engine_with_allocator( + dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index, + const_dnnl_graph_allocator_t alloc); + +/// @} dnnl_graph_api_engine + +/// @addtogroup dnnl_graph_api_logical_tensor +/// @{ + +/// Initializes a logical tensor with id, data type, number of dimensions, +/// layout type, and property. The logical tensor's dims are unknown with this +/// interface. +/// +/// @param logical_tensor Output logical tensor. +/// @param tid The unique id of the output logical tensor. 
+/// @param dtype Elements data type. +/// @param ndims Number of dimensions. +/// @param ltype Layout type of the underlying tensor buffer. +/// @param ptype Tensor property type. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_logical_tensor_init( + dnnl_graph_logical_tensor_t *logical_tensor, size_t tid, + dnnl_data_type_t dtype, int32_t ndims, dnnl_graph_layout_type_t ltype, + dnnl_graph_tensor_property_t ptype); + +/// Initializes a logical tensor with basic information and dims. The logical +/// tensor's dimensions and layout will be initialized according to the input +/// arguments. +/// +/// @note +/// If dims contains all valid values and layout type is +/// #dnnl_graph_layout_type_strided. The strides field in +/// #dnnl_graph_logical_tensor_t will be calculated in a row major and +/// contiguous way. Otherwise, Accessing the strides field is an undefined +/// behavior. +/// +/// Eg. dims (2, 3, 4, 5) will get strides (60, 20, 5, 1) +/// +/// @param logical_tensor Output logical tensor. +/// @param tid The unique id of output logical tensor. +/// @param dtype Elements data type. +/// @param ndims Number of dimensions. +/// @param dims Array of dimensions. +/// @param ltype Layout type of the underlying tensor memory. +/// @param ptype Tensor property type. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_logical_tensor_init_with_dims( + dnnl_graph_logical_tensor_t *logical_tensor, size_t tid, + dnnl_data_type_t dtype, int32_t ndims, const dnnl_dims_t dims, + dnnl_graph_layout_type_t ltype, dnnl_graph_tensor_property_t ptype); + +/// Initializes a logical tensor with dimensions and strides provided by user. 
+/// +/// @note +/// Once strides are explicitly provided through the API, the `layout_type` +/// in #dnnl_graph_logical_tensor_t can only be +/// #dnnl_graph_layout_type_strided or #dnnl_graph_layout_type_any. +/// +/// @param logical_tensor Output logical tensor. +/// @param tid The unique id of output logical tensor. +/// @param dtype Elements data type. +/// @param ndims Number of dimensions. +/// @param dims Array of dimensions. +/// @param strides Array of strides. +/// @param ptype Tensor property type. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_logical_tensor_init_with_strides( + dnnl_graph_logical_tensor_t *logical_tensor, size_t tid, + dnnl_data_type_t dtype, int32_t ndims, const dnnl_dims_t dims, + const dnnl_dims_t strides, dnnl_graph_tensor_property_t ptype); + +/// Returns the memory size described by the logical tensor. If it's a strided +/// layout, the size will be calculated by `dims` and `strides`. If it's an +/// opaque layout, the size will be decided by `layout_id`. +/// +/// @param logical_tensor Logical tensor. +/// @param size Output memory size in bytes. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_logical_tensor_get_mem_size( + const dnnl_graph_logical_tensor_t *logical_tensor, size_t *size); + +/// Compares if two logical tenors are equal. Users can decide accordingly +/// if layout reordering is needed for two logical tensors. The method will +/// return true for below two circumstances: +/// +/// 1. the two logical tensors are equal regarding each field in the struct, +/// eg. id, ndims, dims, layout type, property, etc. +/// 2. If all other fields are equal but the layout types in two logical +/// tensors are different, the method will return true when the underlying +/// memory layout is the same. 
For example, one logical tensor has strided +/// layout type while the other one has opaque layout type, but underneath, +/// both layouts are NHWC, the method will still return true for this case. +/// +/// @param lt1 The handle of first logical tensor. +/// @param lt2 The handle of second logical tensor. +/// @param is_equal 1 if these two logical tensors are equal, 0 otherwise. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_logical_tensor_is_equal( + const dnnl_graph_logical_tensor_t *lt1, + const dnnl_graph_logical_tensor_t *lt2, uint8_t *is_equal); + +/// @} dnnl_graph_api_logical_tensor + +/// @addtogroup dnnl_graph_api_tensor +/// @{ + +/// Creates a tensor with logical tensor, engine, and data handle. +/// +/// @param tensor Output tensor. +/// @param logical_tensor Description for this tensor. +/// @param engine Engine to use. +/// @param handle Handle of the memory buffer to use as an underlying storage. +/// - A pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer for the tensor. In this case the library +/// owns the buffer. +/// - DNNL_MEMORY_NONE to create tensor without an underlying buffer. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_tensor_create(dnnl_graph_tensor_t *tensor, + const dnnl_graph_logical_tensor_t *logical_tensor, dnnl_engine_t engine, + void *handle); + +/// Destroys a tensor. +/// +/// @param tensor The tensor to be destroyed. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_tensor_destroy(dnnl_graph_tensor_t tensor); + +/// Gets the data handle of a tensor. +/// +/// @param tensor The input tensor. +/// @param handle Pointer to the data of input tensor. 
+/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_tensor_get_data_handle( + const_dnnl_graph_tensor_t tensor, void **handle); + +/// Set data handle for a tensor. +/// +/// @param tensor The input tensor. +/// @param handle New data handle for tensor. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_tensor_set_data_handle( + dnnl_graph_tensor_t tensor, void *handle); + +/// Returns the engine of a tensor object. +/// +/// @param tensor The input tensor. +/// @param engine Output engine on which the tensor is located. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_tensor_get_engine( + const_dnnl_graph_tensor_t tensor, dnnl_engine_t *engine); + +/// Returns the logical tensor of a tensor object. +/// +/// @param tensor The input tensor. +/// @param logical_tensor Output logical tensor of the tensor object. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_tensor_get_logical_tensor( + const_dnnl_graph_tensor_t tensor, + dnnl_graph_logical_tensor_t *logical_tensor); + +/// @} dnnl_graph_api_tensor + +/// @addtogroup dnnl_graph_api_op +/// @{ + +/// Initializes an op with unique id, kind, and name. +/// +/// @param op Output op +/// @param id The unique id of the output op. +/// @param kind The op kind. +/// @param verbose_name The string added as the op name. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_op_create(dnnl_graph_op_t *op, size_t id, + dnnl_graph_op_kind_t kind, const char *verbose_name); + +/// Destroys an op. +/// +/// @param op The op to be destroyed. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_graph_op_destroy(dnnl_graph_op_t op); + +/// Adds input logical tensor to the op. +/// +/// @param op Input op. +/// @param input The input logical tensor to be added. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_op_add_input( + dnnl_graph_op_t op, const dnnl_graph_logical_tensor_t *input); + +/// Adds output logical tensor to the op. +/// +/// @param op Input op. +/// @param output The output logical tensor to be added. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_op_add_output( + dnnl_graph_op_t op, const dnnl_graph_logical_tensor_t *output); + +/// Sets floating point attribute to an op. +/// +/// @param op Input op. +/// @param name The attribute's name. +/// @param value The attribute's value. +/// @param value_len The number of value element. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_op_set_attr_f32(dnnl_graph_op_t op, + dnnl_graph_op_attr_t name, const float *value, size_t value_len); + +/// Sets boolean attribute to an op. +/// +/// @param op Input op. +/// @param name The attribute's name. +/// @param value The attribute's value. +/// @param value_len The number of value element. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_op_set_attr_bool(dnnl_graph_op_t op, + dnnl_graph_op_attr_t name, const uint8_t *value, size_t value_len); + +/// Sets integer attribute to an op. +/// +/// @param op Input op. +/// @param name The attribute's name. +/// @param value The attribute's value. +/// @param value_len The number of value element. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_graph_op_set_attr_s64(dnnl_graph_op_t op, + dnnl_graph_op_attr_t name, const int64_t *value, size_t value_len); + +/// Sets string attribute to an op. +/// +/// @param op Input op. +/// @param name The attribute's name. +/// @param value The attribute's value. +/// @param value_len The length of the string value. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_op_set_attr_str(dnnl_graph_op_t op, + dnnl_graph_op_attr_t name, const char *value, size_t value_len); + +/// Returns the unique id of an op. +/// +/// @param op Input op. +/// @param id Output the unique id. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_op_get_id( + const_dnnl_graph_op_t op, size_t *id); + +/// Returns the kind of an op. +/// +/// @param op Input op. +/// @param kind Output op kind. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_op_get_kind( + const_dnnl_graph_op_t op, dnnl_graph_op_kind_t *kind); + +/// @} dnnl_graph_api_op + +/// @addtogroup dnnl_graph_api_partition +/// @{ + +/// Creates a new partition with a given operator and engine kind. The API is +/// used to create a partition from an operation directly without creating the +/// graph and calling `get_partitions()`. The output partition contains only one +/// operation specified by the parameter. The output partition instance should +/// be destroyed via #dnnl_graph_partition_destroy after use. +/// +/// @param partition The handle of output partition. +/// @param op The operation used to create partition. +/// @param ekind The engine kind used to create partition. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_graph_partition_create_with_op( + dnnl_graph_partition_t *partition, const_dnnl_graph_op_t op, + dnnl_engine_kind_t ekind); + +/// Destroys a partition. +/// +/// @param partition The partition to be destroyed. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_destroy( + dnnl_graph_partition_t partition); + +/// Returns the number of operations in a partition. +/// +/// @param partition The target partition. +/// @param num Output the number of operations. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_get_op_num( + const_dnnl_graph_partition_t partition, size_t *num); + +/// Returns the list of op IDs of the partition. +/// +/// @param partition The target partition. +/// @param num The number of ops. +/// @param ids Output the op IDs. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_get_ops( + dnnl_graph_partition_t partition, size_t num, size_t *ids); + +/// Returns the ID of a partition. +/// +/// @param partition The target partition. +/// @param id Output the ID of the partition. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_get_id( + const_dnnl_graph_partition_t partition, size_t *id); + +/// Compiles a partition with given input and output logical tensors. The output +/// logical tensors can contain unknown dimensions. For this case, the +/// compilation will deduce the output shapes according to input shapes. The +/// output logical tensors can also have layout type `any`. The compilation will +/// choose the optimal layout for output tensors. The optimal layout will be +/// represented as an opaque layout ID saved in the output logical tensor. +/// +/// @param partition The target partition. 
+/// @param compiled_partition Output compiled partition. +/// @param in_num The number of input logical tensors. +/// @param inputs A list of input logical tensors. +/// @param out_num The number of output logical tensors. +/// @param outputs A list of output logical tensors. +/// @param engine The target engine of the compilation. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_compile( + dnnl_graph_partition_t partition, + dnnl_graph_compiled_partition_t compiled_partition, size_t in_num, + const dnnl_graph_logical_tensor_t **inputs, size_t out_num, + const dnnl_graph_logical_tensor_t **outputs, dnnl_engine_t engine); + +/// Returns the number of input logical tensors of a partition. +/// +/// @param partition The target partition. +/// @param num Output the number of input logical tensors. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_get_input_ports_num( + const_dnnl_graph_partition_t partition, size_t *num); + +/// Returns a list of input logical tensors from a partition. +/// +/// @param partition The target partition. +/// @param num The number of input logical tensors. +/// @param inputs The list of input logical tensors. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_get_input_ports( + const_dnnl_graph_partition_t partition, size_t num, + dnnl_graph_logical_tensor_t *inputs); + +/// Returns the number of output logical tensors of a partition. +/// +/// @param partition The target partition. +/// @param num Output the number of output logical tensors. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_graph_partition_get_output_ports_num( + const_dnnl_graph_partition_t partition, size_t *num); + +/// Returns a list of output logical tensors from a partition. +/// +/// @param partition The target partition. +/// @param num The number of output logical tensors. +/// @param outputs The list of output logical tensors. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_get_output_ports( + const_dnnl_graph_partition_t partition, size_t num, + dnnl_graph_logical_tensor_t *outputs); + +/// Returns the supporting status of a partition. Some operations may not be +/// supported by the library under certain circumstances. During partitioning +/// stage, unsupported partitions will be returned to users with each containing +/// an unsupported operation. Users should check the supporting status of a +/// partition before transforming the computation graph or compiling the +/// partition. +/// +/// @param partition The target partition. +/// @param is_supported Output flag to indicate the supporting status. 0 means +/// unsupported while 1 means supported. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_is_supported( + const_dnnl_graph_partition_t partition, uint8_t *is_supported); + +/// Returns the engine kind of a partition. +/// +/// @param partition The target partition. +/// @param kind The output engine kind. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_partition_get_engine_kind( + const_dnnl_graph_partition_t partition, dnnl_engine_kind_t *kind); + +/// @} dnnl_graph_api_partition + +/// @addtogroup dnnl_graph_api_compiled_partition +/// @{ + +/// Creates a new compiled partition handle. +/// +/// @param compiled_partition The handle of output compiled partition. 
+/// @param partition The handle of input partition. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_compiled_partition_create( + dnnl_graph_compiled_partition_t *compiled_partition, + dnnl_graph_partition_t partition); + +/// Executes a compiled partition. +/// +/// @param compiled_partition The handle of target compiled partition. +/// @param stream The stream used for execution. +/// @param num_inputs The number of input tensors. +/// @param inputs A list of input tensors. +/// @param num_outputs The number of output tensors. +/// @param outputs A non-empty list of output tensors. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_compiled_partition_execute( + const_dnnl_graph_compiled_partition_t compiled_partition, + dnnl_stream_t stream, size_t num_inputs, + const_dnnl_graph_tensor_t *inputs, size_t num_outputs, + const_dnnl_graph_tensor_t *outputs); + +/// Destroys a compiled partition. +/// +/// @param compiled_partition The compiled partition to be destroyed. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_compiled_partition_destroy( + dnnl_graph_compiled_partition_t compiled_partition); + +/// Queries an input or output logical tensor according to tensor ID. If the +/// tensor ID doesn't belong to any input or output of the compiled partition, +/// an error status #dnnl_invalid_arguments will be returned by the API. +/// +/// @param compiled_partition The handle of target compiled_partition. +/// @param tid The unique id of required tensor. +/// @param lt The output logical tensor. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_graph_compiled_partition_query_logical_tensor( + const_dnnl_graph_compiled_partition_t compiled_partition, size_t tid, + dnnl_graph_logical_tensor_t *lt); + +/// Returns the hint of in-place pairs from a compiled partition. It indicates +/// that an input and an output of the partition can share the same memory +/// buffer for computation. In-place computation helps to reduce the memory +/// footprint and improves cache locality. But since the library may not have a +/// global view of user's application, it's possible that the tensor with +/// `input_id` is used at other places in user's computation graph. In this +/// case, the user should take the in-place pair as a hint and pass a different +/// memory buffer for output tensor to avoid overwriting the input memory buffer +/// which will probably cause unexpected incorrect results. +/// +/// @param compiled_partition The handle of target compiled_partition. +/// @param num_inplace_pairs The number of in-place pairs. +/// @param inplace_pairs The handle of in-place pairs. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_compiled_partition_get_inplace_ports( + const_dnnl_graph_compiled_partition_t compiled_partition, + size_t *num_inplace_pairs, + const dnnl_graph_inplace_pair_t **inplace_pairs); + +/// @} dnnl_graph_api_compiled_partition + +/// @addtogroup dnnl_graph_api_graph +/// @{ + +/// Creates a new empty graph. A graph is associated to a specific engine kind. +/// The partitions returned from the graph will inherit the engine kind of the +/// graph. +/// +/// @param graph The handle of output graph. +/// @param engine_kind The target engine kind. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_graph_graph_create( + dnnl_graph_graph_t *graph, dnnl_engine_kind_t engine_kind); + +/// Creates a new empty graph with an engine kind and a floating-point math +/// mode. All partitions returned from the graph will inherit the engine kind +/// and floating-point math mode. +/// +/// @param graph The handle of output graph. +/// @param engine_kind The kind for engine. +/// @param mode The floating-point math mode. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_create_with_fpmath_mode( + dnnl_graph_graph_t *graph, dnnl_engine_kind_t engine_kind, + dnnl_fpmath_mode_t mode); + +/// Destroys a graph. +/// +/// @param graph The graph to be destroyed. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_destroy(dnnl_graph_graph_t graph); + +/// Set the floating point math mode for a graph. +/// +/// @param graph The target graph. +/// @param mode The floating-point math mode. +/// @param apply_to_int The flag that controls whether to use floating-point +/// arithmetic for integral operations. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_set_fpmath_mode( + dnnl_graph_graph_t graph, dnnl_fpmath_mode_t mode, int apply_to_int); + +/// Get the floating point math mode for a graph. +/// +/// @param graph The target graph. +/// @param mode The floating-point math mode. +/// @param apply_to_int The flag that controls whether to use floating-point +/// arithmetic for integral operations. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_get_fpmath_mode( + dnnl_graph_graph_t graph, dnnl_fpmath_mode_t *mode, int *apply_to_int); + +/// Adds an operation into a graph. 
The API will return failure if the operator +/// has already been added to the graph or the operation cannot pass the schema +/// check in the library (eg. input and output numbers and data types, the +/// attributes of the operation, etc.). +/// +/// @param graph The target graph. +/// @param op The operation to be added. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_add_op( + dnnl_graph_graph_t graph, dnnl_graph_op_t op); + +/// Finalizes a graph. It means users have finished adding operations into the +/// graph and the graph is ready for partitioning. Adding a new operation into a +/// finalized graph will return failures. Similarly, partitioning on a +/// un-finalized graph will also return failures. +/// +/// @param graph The target graph to be finalized. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_finalize(dnnl_graph_graph_t graph); + +/// Checks if a graph is finalized. +/// +/// @param graph The target graph to be finalized. +/// @param finalized Output the finalization status. 0 means then graph is not +/// finalized. Other values means the graph is finalized. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_is_finalized( + dnnl_graph_graph_t graph, uint8_t *finalized); + +/// Filters a graph. Partitions will be claimed internally according to the +/// capability of the library, the engine kind, and the policy. +/// +/// @param graph The target graph. +/// @param policy The partition policy. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_filter( + dnnl_graph_graph_t graph, dnnl_graph_partition_policy_t policy); + +/// Returns the number of partitions of a graph. The API should be called after +/// a partition is already filtered. 
Otherwise, the output number is zero. +/// +/// @param graph The graph. +/// @param num Output the number of partitions. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_get_partition_num( + const_dnnl_graph_graph_t graph, size_t *num); + +/// Returns the partitions from a filtered graph. Output partition instances +/// will be written into the parameter `partitions`. Users need to make sure +/// `partitions` is valid and has enough space to accept the partition +/// instances. Each output partition instance should be destroyed via +/// #dnnl_graph_partition_destroy explicitly after use. +/// +/// @param graph The target graph. +/// @param num The number of partitions. +/// @param partitions Output the partitions. +/// @returns #dnnl_success on success or a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_graph_get_partitions(dnnl_graph_graph_t graph, + size_t num, dnnl_graph_partition_t *partitions); + +/// @} dnnl_graph_api_graph + +/// @addtogroup dnnl_graph_api_compiled_partition_cache +/// @{ + +/// Returns the number of compiled partitions that can be held in the compiled +/// partition cache at the same time. +/// +/// @param capacity Compiled partition cache capacity to query. Concurrently +/// accessing @p capacity is safe. +/// @returns #dnnl_invalid_arguments if the @p capacity value +/// is invalid, and #dnnl_success on success. +dnnl_status_t DNNL_API dnnl_graph_get_compiled_partition_cache_capacity( + int *capacity); + +/// Sets a number of compiled partitions that can be held in the compiled +/// partition cache at the same time. The default capacity of compiled partition +/// cache is 1024. +/// +/// @param capacity Compiled partition cache capacity to set. The default cache +/// capacity is 1024. 
If a new @p capacity is less than the number of compiled +/// partitions that the compiled partition cache already has, then the excess +/// entries will be evicted. Setting the @p capacity to 0 clears the compiled +/// partition cache and disables it. Concurrently modifying @p capacity is safe. +/// @returns #dnnl_invalid_arguments if the @p capacity value +/// is invalid, and #dnnl_success on success. +dnnl_status_t DNNL_API dnnl_graph_set_compiled_partition_cache_capacity( + int capacity); + +/// @} dnnl_graph_api_compiled_partition_cache + +/// @addtogroup dnnl_graph_api_constant_tensor_cache +/// @{ + +/// Control the enabling or disabling of constant tensor cache. This API must +/// be called once before compilation stage. By default, constant tensor cache is +/// disabled in the library. +/// +/// @param flag Set to positive value to enable the cache and set to 0 to +/// disable the cache. Negative values are invalid. +/// @returns #dnnl_invalid_arguments if the @p flag value is +/// invalid, and #dnnl_success on success. +/// @note This API is deprecated and will be removed in future release, please +/// use the dnnl_graph_set_constant_tensor_cache_capacity API to disable +/// constant tensor cache by setting its capacity to zero. +dnnl_status_t DNNL_API dnnl_graph_set_constant_tensor_cache(int flag); + +/// Return the enabling or disabling status of constant tensor cache. +/// +/// @param flag The constant tensor cache enabling status to query. +/// @returns #dnnl_invalid_arguments if the @p flag value is +/// nullptr, and #dnnl_success on success. +/// @note This API is deprecated and will be removed in future release, please +/// use the dnnl_graph_get_constant_tensor_cache_capacity API to check the +/// enabling status by checking its capacity. +dnnl_status_t DNNL_API dnnl_graph_get_constant_tensor_cache(int *flag); + +/// Control the capacity for the constant tensor cache that used for specific 
This API is thread safe and can be called multiple times at +/// runtime. The capacity is set to zero by default which means the cache is +/// disabled. When calling this API, the corresponding cache will be flushed. +/// Setting capacity to 0 means to clear all cached tensors and disable cache. +/// Once the capacity limit is reached, no new tensors will be cached. If there +/// are multiple devices for an engine kind, the capacity set here is for each +/// device. +/// +/// @param eng_kind The engine kind that the constant tensor cache is used for. +/// @param size The constant tensor cache capacity size to set. +/// @returns #dnnl_invalid_arguments if the @p eng_kind value is invalid, and +/// #dnnl_success on success. +dnnl_status_t DNNL_API dnnl_graph_set_constant_tensor_cache_capacity( + dnnl_engine_kind_t eng_kind, size_t size); + +/// Return the current capacity of constant tensor cache. +/// +/// @param eng_kind The engine kind that the constant tensor cache is used for. +/// @param size The constant tensor cache capacity size to query. +/// @returns #dnnl_invalid_arguments if the @p eng_kind value is +/// invalid or the @p size is nullptr, and #dnnl_success on success. 
+dnnl_status_t DNNL_API dnnl_graph_get_constant_tensor_cache_capacity( + dnnl_engine_kind_t eng_kind, size_t *size); + +/// @} dnnl_graph_api_constant_tensor_cache + +/// @} dnnl_graph_api + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph.hpp b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph.hpp new file mode 100644 index 0000000000000000000000000000000000000000..37f228532b8a83434b349b9fa48f3d93361da8fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph.hpp @@ -0,0 +1,1634 @@ +/******************************************************************************* +* Copyright 2020-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @file +/// Graph C++ API + +#ifndef ONEAPI_DNNL_DNNL_GRAPH_HPP +#define ONEAPI_DNNL_DNNL_GRAPH_HPP + +#include "oneapi/dnnl/dnnl_common.hpp" +#include "oneapi/dnnl/dnnl_graph.h" + +#include +#include +#include +#include +#include + +/// @addtogroup dnnl_api +/// @{ + +namespace dnnl { + +/// @addtogroup dnnl_graph_api Graph API +/// oneDNN Graph API +/// @{ + +/// oneDNN Graph namespace +namespace graph { + +/// @cond DO_NOT_DOCUMENT_THIS + +// Alias for common engine and stream API. 
+using engine = dnnl::engine; +using stream = dnnl::stream; +using fpmath_mode = dnnl::fpmath_mode; + +/// @endcond + +/// @addtogroup dnnl_graph_api_utils Utilities +/// Utility types and definitions +/// @{ + +/// @cond DO_NOT_DOCUMENT_THIS + +/// A class that provides the destructor for a oneDNN graph C API handle. +template +struct graph_handle_traits : public dnnl::handle_traits {}; + +template <> +struct graph_handle_traits { + static dnnl_status_t destructor(dnnl_graph_op_t p) { + return dnnl_graph_op_destroy(p); + } +}; + +template <> +struct graph_handle_traits { + static dnnl_status_t destructor(dnnl_graph_graph_t p) { + return dnnl_graph_graph_destroy(p); + } +}; + +template <> +struct graph_handle_traits { + static dnnl_status_t destructor(dnnl_graph_tensor_t p) { + return dnnl_graph_tensor_destroy(p); + } +}; + +template <> +struct graph_handle_traits { + static dnnl_status_t destructor(dnnl_graph_partition_t p) { + return dnnl_graph_partition_destroy(p); + } +}; + +template <> +struct graph_handle_traits { + static dnnl_status_t destructor(dnnl_graph_compiled_partition_t p) { + return dnnl_graph_compiled_partition_destroy(p); + } +}; + +template <> +struct graph_handle_traits { + static dnnl_status_t destructor(dnnl_graph_allocator_t p) { + return dnnl_graph_allocator_destroy(p); + } +}; + +#define DNNL_GRAPH_HANDLE_ALIAS(type) \ + using type##_handle = dnnl::handle> + +DNNL_GRAPH_HANDLE_ALIAS(allocator); +DNNL_GRAPH_HANDLE_ALIAS(graph); +DNNL_GRAPH_HANDLE_ALIAS(op); +DNNL_GRAPH_HANDLE_ALIAS(tensor); +DNNL_GRAPH_HANDLE_ALIAS(compiled_partition); +DNNL_GRAPH_HANDLE_ALIAS(partition); + +#undef DNNL_GRAPH_HANDLE_ALIAS + +template +using req = typename std::enable_if::type; + +/// @endcond + +/// @} dnnl_graph_api_utils + +/// @addtogroup dnnl_graph_api_status Status +/// Definitions of status values returned by the library functions. +/// @{ + +/// Status values returned by the library functions. 
+enum class status { + /// The operation was successful + success = dnnl_success, + /// The operation failed due to an out-of-memory condition + out_of_memory = dnnl_out_of_memory, + /// The operation failed because of incorrect function arguments + invalid_arguments = dnnl_invalid_arguments, + /// The operation failed because requested functionality is not implemented + unimplemented = dnnl_unimplemented, + /// The last available implementation is reached + last_impl_reached = dnnl_last_impl_reached, + /// Primitive or engine failed on execution + runtime_error = dnnl_runtime_error, + /// Queried element is not required for given primitive + not_required = dnnl_not_required, + /// The graph is not legitimate + invalid_graph = dnnl_invalid_graph, + /// The operation is not legitimate according to op schema + invalid_graph_op = dnnl_invalid_graph_op, + /// The shape cannot be inferred or compiled + invalid_shape = dnnl_invalid_shape, + /// The data type cannot be inferred or compiled + invalid_data_type = dnnl_invalid_data_type, +}; + +/// @} dnnl_graph_api_status + +/// @addtogroup dnnl_graph_api_allocator Allocator +/// +/// Definitions of allocator which is used to acquire memory resources in +/// partition compilation and execution. SYCL allocator +/// (#dnnl::graph::sycl_interop::make_allocator) should be used for SYCL runtime +/// and host allocator should be used for non-SYCL. 
+/// +/// @{ + +/// Allocator +class allocator : public allocator_handle { +public: + using allocator_handle::handle; + + /// Constructs an allocator according to given function pointers + /// + /// @param host_malloc A pointer to malloc function for CPU + /// @param host_free A pointer to free function for CPU + allocator(dnnl_graph_host_allocate_f host_malloc, + dnnl_graph_host_deallocate_f host_free) { + dnnl_graph_allocator_t a = nullptr; + error::wrap_c_api( + dnnl_graph_allocator_create(&a, host_malloc, host_free), + "could not create allocator for cpu"); + reset(a); + } + + /// Default constructor + allocator() { + dnnl_graph_allocator_t a = nullptr; + error::wrap_c_api(dnnl_graph_allocator_create(&a, nullptr, nullptr), + "could not create allocator"); + reset(a); + } +}; + +/// @} dnnl_graph_api_allocator + +/// @addtogroup dnnl_graph_api_engine Engine +/// @{ + +/// This API is a supplement for existing onednn engine API. +inline engine make_engine_with_allocator( + engine::kind kind, size_t index, const allocator &alloc) { + dnnl_engine_t c_engine; + error::wrap_c_api( + dnnl_graph_make_engine_with_allocator(&c_engine, + static_cast(kind), index, alloc.get()), + "could not make an engine with allocator"); + return engine(c_engine); +} + +/// @} dnnl_graph_api_engine + +/// @addtogroup dnnl_graph_api_logical_tensor Logical Tensor +/// +/// Logical tensor describes the meta-data of the input or output tensor, like +/// elements data type, number of dimensions, size for each dimension (shape), +/// layout, and the property of the tensor. +/// +/// Each logical tensor has an unique ID. The library uses logical tensor IDs to +/// build up the connections between operations if the output of one operation +/// has the same ID as the input of another operation. The meta-data in a +/// logical tensor may be enriched in the framework graph as it progresses +/// toward final execution. 
For example, the library doesn't require detailed +/// shape information at the operation and graph creation stage. But shape +/// information of input logical tensor will be required at partition +/// compilation stage. Logical tensor is not mutable. Users must create a new +/// logical tensor with the same ID to pass any new additional information to +/// oneDNN Graph API. Please note that the library also has unique IDs for +/// operations. The ID should be unique among different logical tensors, but it +/// can have the same value between a logical tensor and an operation. +/// +/// @{ + +/// Logical tensor object +class logical_tensor { + friend class op; + friend class tensor; + friend class partition; + friend class compiled_partition; + + dnnl_graph_logical_tensor_t data; + +public: + /// Integer type for representing dimension sizes and indices. + using dim = dnnl_dim_t; + /// Vector of dimensions. Implementations are free to force a limit on the + /// vector's length. + using dims = std::vector; + + /// Data Type + enum class data_type { + undef = dnnl_data_type_undef, + /// 16-bit/half-precision floating point. + f16 = dnnl_f16, + /// non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point. + bf16 = dnnl_bf16, + /// 32-bit/single-precision floating point. + f32 = dnnl_f32, + /// 32-bit signed integer. + s32 = dnnl_s32, + /// 8-bit signed integer. + s8 = dnnl_s8, + /// 8-bit unsigned integer. + u8 = dnnl_u8, + /// Boolean data type. Size is C++ implementation defined. + boolean = dnnl_boolean, + /// [OFP8 standard 8-bit + /// floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) + /// with a 5-bit exponent and a 2-bit mantissa. + f8_e5m2 = dnnl_f8_e5m2, + /// [OFP8 standard 8-bit + /// floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) + /// with a 4-bit exponent and a 3-bit mantissa. 
+ f8_e4m3 = dnnl_f8_e4m3, + /// 4-bit signed integer. + s4 = dnnl_s4, + /// 4-bit unsigned integer. + u4 = dnnl_u4, + }; + + /// Layout type + enum class layout_type { + /// Undefined layout type. + undef = dnnl_graph_layout_type_undef, + /// Any means to let the library to decide the layout for a tensor + /// during partition compilation. + any = dnnl_graph_layout_type_any, + /// Strided means that the layout of a tensor is determined by the + /// strides field in the logical tensor. + strided = dnnl_graph_layout_type_strided, + /// Opaque means that the layout of a tensor is the library specific. + /// Usually, an opaque layout is generated by a partition which is + /// compiled with layout type any. + opaque = dnnl_graph_layout_type_opaque, + }; + + /// Tensor property + enum class property_type { + /// Undefined tensor property. + undef = dnnl_graph_tensor_property_undef, + /// Variable means the tensor may be changed during computation or + /// between different iterations. + variable = dnnl_graph_tensor_property_variable, + /// Constant means the tensor will keep unchanged during computation and + /// between different iterations. It's useful for the library to apply + /// optimizations for constant tensors or cache constant tensors inside + /// the library. For example, constant weight tensors in inference + /// scenarios. + constant = dnnl_graph_tensor_property_constant, + }; + + /// default constructor + /// construct an empty object + logical_tensor() = default; + + /// Constructs a logical tensor object + explicit logical_tensor(const dnnl_graph_logical_tensor_t &c_data) + : data(c_data) {} + + /// Copy + logical_tensor(const logical_tensor &other) = default; + + /// Assign + logical_tensor &operator=(const logical_tensor &other) = default; + + /// Constructs a logical tensor object with ID, data type, ndims, layout + /// type, and property type. + /// + /// @param tid Logical tensor ID. + /// @param dtype Elements data type. 
+ /// @param ndims Number of dimensions. -1 means unknown (see + /// #DNNL_GRAPH_UNKNOWN_NDIMS) and 0 means a scalar tensor. + /// @param ltype Layout type. + /// @param ptype Property type. + logical_tensor(size_t tid, data_type dtype, int32_t ndims, + layout_type ltype, property_type ptype = property_type::undef) { + dnnl_graph_logical_tensor_t val; + error::wrap_c_api( + dnnl_graph_logical_tensor_init(&val, tid, convert_to_c(dtype), + ndims, convert_to_c(ltype), convert_to_c(ptype)), + "could not create logical_tensor with property"); + data = val; + } + + /// Delegated constructor. + /// + /// @param tid Logical tensor ID. + /// @param dtype Elements data type. + /// @param ltype Layout type. + logical_tensor( + size_t tid, data_type dtype, layout_type ltype = layout_type::undef) + : logical_tensor(tid, dtype, DNNL_GRAPH_UNKNOWN_NDIMS, ltype) {} + + /// Constructs a logical tensor object with basic information and detailed + /// dims. + /// + /// @param tid Logical tensor ID. + /// @param dtype Elements data type. + /// @param adims Logical tensor dimensions. #DNNL_GRAPH_UNKNOWN_DIM means + /// the size of that dimension is unknown. 0 is used to define + /// zero-dimension tensor. + /// @param ltype Layout type. If it's strided, the strides field in the + /// output logical tensor will be deduced accordingly. + /// @param ptype Property type. 
+ logical_tensor(size_t tid, data_type dtype, const dims &adims, + layout_type ltype, property_type ptype = property_type::undef) { + dnnl_graph_logical_tensor_t val; + // if dimension size equals to 0, it's a scalar + if (adims.size() == 0) + error::wrap_c_api(dnnl_graph_logical_tensor_init(&val, tid, + convert_to_c(dtype), 0, + convert_to_c(ltype), convert_to_c(ptype)), + "could not create logical_tensor with property"); + else + error::wrap_c_api( + dnnl_graph_logical_tensor_init_with_dims(&val, tid, + convert_to_c(dtype), + static_cast(adims.size()), adims.data(), + convert_to_c(ltype), convert_to_c(ptype)), + "could not create logical_tensor with dims and property"); + data = val; + } + + /// Constructs a logical tensor object with detailed dims and strides. The + /// layout_type of the output logical tensor object will always be strided. + /// + /// @param tid Logical tensor ID. + /// @param dtype Elements data type. + /// @param adims Logical tensor dimensions. #DNNL_GRAPH_UNKNOWN_DIM means + /// the size of that dimension is unknown. 0 is used to define + /// zero-dimension tensor. + /// @param strides Logical tensor strides. #DNNL_GRAPH_UNKNOWN_DIM means + /// the stride of the dimension is unknown. The library currently + /// doesn't support other negative stride values. + /// @param ptype Property type. + logical_tensor(size_t tid, data_type dtype, const dims &adims, + const dims &strides, property_type ptype = property_type::undef) { + dnnl_graph_logical_tensor_t val; + // TODO(lvtao): check the size of adims and strides. + // They should be same. + error::wrap_c_api( + dnnl_graph_logical_tensor_init_with_strides(&val, tid, + convert_to_c(dtype), static_cast(adims.size()), + adims.data(), strides.data(), convert_to_c(ptype)), + "could not create logical_tensor with strides and property"); + data = val; + } + + /// Constructs a logical tensor object with detailed dims and an opaque + /// layout ID. 
layout_type of the output logical tensor object will always + /// be opaque. + /// + /// @param tid Logical tensor ID. + /// @param dtype Elements data type. + /// @param adims Logical tensor dimensions. #DNNL_GRAPH_UNKNOWN_DIM means + /// the size of that dimension is unknown. 0 is used to define + /// zero-dimension tensor. + /// @param lid Opaque layout id. + /// @param ptype Property type + logical_tensor(size_t tid, data_type dtype, const dims &adims, size_t lid, + property_type ptype = property_type::undef) { + dnnl_graph_logical_tensor_t val; + + if (adims.size() == 0) { + error::wrap_c_api(dnnl_graph_logical_tensor_init(&val, tid, + convert_to_c(dtype), 0, + convert_to_c(layout_type::opaque), + convert_to_c(ptype)), + "could not create logical_tensor"); + } else { + error::wrap_c_api( + dnnl_graph_logical_tensor_init_with_dims(&val, tid, + convert_to_c(dtype), + static_cast(adims.size()), adims.data(), + convert_to_c(layout_type::opaque), + convert_to_c(ptype)), + "could not create logical_tensor with dims"); + } + + val.layout.layout_id = lid; + data = val; + } + + /// Returns dimensions of a logical tensor. + /// + /// @returns A vector describing the size of each dimension. + dims get_dims() const { + if (data.ndims < 0) { + error::wrap_c_api(dnnl_invalid_arguments, + "cannot return dims when ndims < 0"); + } + + return {data.dims, data.dims + data.ndims}; + } + + /// Returns the unique id of a logical tensor. + /// + /// @returns An integer value describing the ID. + size_t get_id() const { return data.id; } + + /// Returns the data type of a logical tensor. + /// + /// @returns The data type. + data_type get_data_type() const { + return static_cast(data.data_type); + } + + /// Returns the property type of a logical tensor. + /// + /// @returns The property type. + property_type get_property_type() const { + return static_cast(data.property); + } + + /// Returns the layout type of a logical tensor. + /// + /// @returns The layout type. 
+ layout_type get_layout_type() const { + return static_cast(data.layout_type); + } + + /// Returns the layout ID of a logical tensor. The API should be called on a + /// logical tensor with opaque layout type. Otherwise, an exception will be + /// raised. + /// + /// @returns Layout ID. + size_t get_layout_id() const { + if (get_layout_type() != layout_type::opaque) { + error::wrap_c_api( + dnnl_invalid_arguments, "layout type should be opaque"); + } + + return data.layout.layout_id; + } + + /// Returns the strides of a logical tensor. The API should be called on a + /// logical tensor with strided layout type. Otherwise, an exception will be + /// raised. + /// + /// @returns A vector describing the stride size of each dimension. + dims get_strides() const { + if (get_layout_type() != layout_type::strided) { + error::wrap_c_api( + dnnl_invalid_arguments, "layout type should be strided"); + } + + if (data.ndims < 0) { + error::wrap_c_api(dnnl_invalid_arguments, + "cannot return strides when ndims < 0"); + } + + return {data.layout.strides, data.layout.strides + data.ndims}; + } + + /// Returns memory size in bytes required by this logical tensor. + /// + /// @returns The memory size in bytes. + size_t get_mem_size() const { + size_t size = 0; + error::wrap_c_api(dnnl_graph_logical_tensor_get_mem_size(&data, &size), + "could not get memory size from the logical_tensor"); + return size; + } + + /// Compares if two logical tenors are equal. Users can decide accordingly + /// if layout reordering is needed for two logical tensors. The method will + /// return true for below two circumstances: + /// + /// 1. the two logical tensors are equal regarding each field in the struct, + /// eg. id, ndims, dims, layout type, property, etc. + /// 2. If all other fields are equal but the layout types in two logical + /// tensors are different, the method will return true when the underlying + /// memory layout is the same. 
For example, one logical tensor has strided + /// layout type while the other one has opaque layout type, but underneath, + /// both layouts are NHWC, the method will still return true for this case. + /// + /// @param lt The input logical tensor to be compared. + /// @returns @c true if the two logical tensors are equal. @c false otherwise + bool is_equal(const logical_tensor <) const { + uint8_t equal = 0; + error::wrap_c_api( + dnnl_graph_logical_tensor_is_equal(&data, <.data, &equal), + "could not compare between the two logical tensors"); + return equal != 0; + } + +private: + static dnnl_data_type_t convert_to_c(data_type dtype) { + return static_cast(dtype); + } + + static dnnl_graph_layout_type_t convert_to_c(layout_type ltype) { + return static_cast(ltype); + } + + static dnnl_graph_tensor_property_t convert_to_c(property_type ptype) { + return static_cast(ptype); + } +}; + +/// @} dnnl_graph_api_logical_tensor + +/// @addtogroup dnnl_graph_api_tensor Tensor +/// +/// Tensor is an abstraction for multi-dimensional input and output data needed +/// in the execution of a compiled partition. A tensor object encapsulates a +/// handle to a memory buffer allocated on a specific engine and a logical +/// tensor which describes the dimensions, elements data type, and memory +/// layout. +/// +/// @{ + +/// A tensor object +class tensor : public tensor_handle { +public: + /// Default constructor. Constructs an empty object. + tensor() = default; + + /// Constructs a tensor object according to a given logical tensor, an + /// engine, and a memory handle. + /// + /// @param lt The given logical tensor + /// @param aengine Engine to store the data on. + /// @param handle Handle of memory buffer to use as an underlying storage. + /// - A pointer to the user-allocated buffer. In this case the library + /// doesn't own the buffer. + /// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to + /// allocate the buffer for the tensor. 
In this case the library + /// owns the buffer. + /// - DNNL_MEMORY_NONE to create tensor without an underlying buffer. + tensor(const logical_tensor <, const engine &aengine, void *handle) { + dnnl_graph_tensor_t t = nullptr; + error::wrap_c_api( + dnnl_graph_tensor_create(&t, &(lt.data), aengine.get(), handle), + "could not create tensor object with the logical_tensor, " + "engine, and handle"); + reset(t); + } + + /// Constructs a tensor object. + /// The underlying buffer for the memory will be allocated by the library. + /// + /// @param lt The given logical tensor + /// @param aengine Engine to store the data on. + tensor(const logical_tensor <, const engine &aengine) + : tensor(lt, aengine, DNNL_MEMORY_ALLOCATE) {} + + /// Returns the underlying memory buffer. + /// + /// On the CPU engine, or when using USM, this is a pointer to the + /// allocated memory. + void *get_data_handle() const { + void *handle = nullptr; + error::wrap_c_api(dnnl_graph_tensor_get_data_handle(get(), &handle), + "could not get data handle from the tensor"); + return handle; + } + + /// Sets the underlying memory handle. + /// + /// @param handle Memory handle. + void set_data_handle(void *handle) { + error::wrap_c_api(dnnl_graph_tensor_set_data_handle(get(), handle), + "setting data handle to the tensor failed"); + } + + /// Returns the associated engine. + /// + /// @returns An engine object + engine get_engine() const { + dnnl_engine_t c_engine = nullptr; + error::wrap_c_api(dnnl_graph_tensor_get_engine(get(), &c_engine), + "could not get an engine from a tensor object"); + return engine(c_engine, true); + } + + /// Returns the logical tensor of a tensor object. + /// + /// @returns A logical_tensor object. 
+ logical_tensor get_logical_tensor() const { + dnnl_graph_logical_tensor_t lt; + error::wrap_c_api(dnnl_graph_tensor_get_logical_tensor(get(), <), + "could not get logical tensor from a tensor object"); + return logical_tensor(lt); + } +}; + +/// @} dnnl_graph_api_tensor + +/// @addtogroup dnnl_graph_api_compiled_partition Compiled Partition +/// +/// A compiled partition represents the generated kernels specialized for a +/// partition on a target hardware (engine) with input and output information +/// specified by the logical tensors. +/// +/// @{ + +/// A compiled partition object. +class compiled_partition : public compiled_partition_handle { +public: + /// Default constructor. Constructs an empty object. + compiled_partition() = default; + + /// Constructs a compiled partition object + compiled_partition(dnnl_graph_compiled_partition_t compiled_partition) { + reset(compiled_partition, false); + } + + /// Queries an input or output logical tensor according to tensor ID. If the + /// tensor ID doesn't belong to any input or output of the compiled + /// partition, an exception will be raised by the API. + /// + /// @param tid The unique id of required tensor. + /// @returns The logical tensor. + logical_tensor query_logical_tensor(size_t tid) const { + dnnl_graph_logical_tensor_t lt; + error::wrap_c_api(dnnl_graph_compiled_partition_query_logical_tensor( + get(), tid, <), + "query logical tensor from compiled_partition failed"); + return logical_tensor {lt}; + } + + /// Returns the hint of in-place pairs from a compiled partition. It + /// indicates that an input and an output of the partition can share the + /// same memory buffer for computation. In-place computation helps to reduce + /// the memory footprint and improves cache locality. But since the library + /// may not have a global view of user's application, it's possible that the + /// input tensor is used at other places in user's computation graph. 
In + /// this case, the user should take the in-place pair as a hint and pass a + /// different memory buffer for output tensor to avoid overwriting the input + /// memory buffer which will probably cause unexpected incorrect results. + /// + /// @returns A list of pairs of input and output IDs. + std::vector> get_inplace_ports() const { + size_t num = 0; + const dnnl_graph_inplace_pair_t *inplace_pairs; + + error::wrap_c_api(dnnl_graph_compiled_partition_get_inplace_ports( + get(), &num, &inplace_pairs), + "could not get the in-place pairs from a compiled partition"); + if (num == 0) return {}; + + std::vector> inplace_options; + inplace_options.reserve(num); + for (size_t i = 0; i < num; ++i) { + const dnnl_graph_inplace_pair_t *inplace_pair = inplace_pairs + i; + inplace_options.emplace_back( + inplace_pair->input_id, inplace_pair->output_id); + } + return inplace_options; + } + + /// Execute a compiled partition. + /// + /// @param astream Stream object to run over. + /// @param inputs A list of input tensors. + /// @param outputs A list of output tensors. + void execute(stream &astream, const std::vector &inputs, + const std::vector &outputs) const { + std::vector c_inputs; + c_inputs.reserve(inputs.size()); + for (auto &in : inputs) { + c_inputs.push_back(in.get()); + } + std::vector c_outputs; + c_outputs.reserve(outputs.size()); + for (auto &out : outputs) { + c_outputs.push_back(out.get()); + } + + error::wrap_c_api( + dnnl_graph_compiled_partition_execute(get(), astream.get(), + c_inputs.size(), c_inputs.data(), c_outputs.size(), + c_outputs.data()), + "could not execute the compiled_partition"); + } +}; + +/// @} dnnl_graph_api_compiled_partition + +/// @addtogroup dnnl_graph_api_op Op +/// +/// OP is an abstraction of computation logic for deep neural network +/// operations. 
An op object encapsulates an operation kind which describes the +/// computation logic, an unique ID which differentiates operations with the +/// same kind, and logical tensors which describes the input and output of the +/// operation and its connections to other operations in the graph. +/// +/// @{ + +/// An op object. +class op : public op_handle { +public: + /// Kinds of operations + enum class kind { + Abs = dnnl_graph_op_abs, + AbsBackward = dnnl_graph_op_abs_backward, + Add = dnnl_graph_op_add, + AvgPool = dnnl_graph_op_avg_pool, + AvgPoolBackward = dnnl_graph_op_avg_pool_backward, + BatchNormForwardTraining = dnnl_graph_op_batch_norm_forward_training, + BatchNormInference = dnnl_graph_op_batch_norm_inference, + BatchNormTrainingBackward = dnnl_graph_op_batch_norm_backward, + BiasAdd = dnnl_graph_op_bias_add, + BiasAddBackward = dnnl_graph_op_bias_add_backward, + Clamp = dnnl_graph_op_clamp, + ClampBackward = dnnl_graph_op_clamp_backward, + Concat = dnnl_graph_op_concat, + Convolution = dnnl_graph_op_convolution, + ConvolutionBackwardData = dnnl_graph_op_convolution_backward_data, + ConvolutionBackwardWeights = dnnl_graph_op_convolution_backward_weights, + ConvTranspose = dnnl_graph_op_conv_transpose, + ConvTransposeBackwardData = dnnl_graph_op_conv_transpose_backward_data, + ConvTransposeBackwardWeights + = dnnl_graph_op_conv_transpose_backward_weights, + Dequantize = dnnl_graph_op_dequantize, + Divide = dnnl_graph_op_divide, + DynamicDequantize = dnnl_graph_op_dynamic_dequantize, + DynamicQuantize = dnnl_graph_op_dynamic_quantize, + Elu = dnnl_graph_op_elu, + EluBackward = dnnl_graph_op_elu_backward, + End = dnnl_graph_op_end, + Exp = dnnl_graph_op_exp, + GELU = dnnl_graph_op_gelu, + GELUBackward = dnnl_graph_op_gelu_backward, + GroupNorm = dnnl_graph_op_group_norm, + HardSigmoid = dnnl_graph_op_hard_sigmoid, + HardSigmoidBackward = dnnl_graph_op_hard_sigmoid_backward, + HardSwish = dnnl_graph_op_hard_swish, + HardSwishBackward = 
dnnl_graph_op_hard_swish_backward, + Interpolate = dnnl_graph_op_interpolate, + InterpolateBackward = dnnl_graph_op_interpolate_backward, + LayerNorm = dnnl_graph_op_layer_norm, + LayerNormBackward = dnnl_graph_op_layer_norm_backward, + LeakyReLU = dnnl_graph_op_leaky_relu, + Log = dnnl_graph_op_log, + LogSoftmax = dnnl_graph_op_log_softmax, + LogSoftmaxBackward = dnnl_graph_op_log_softmax_backward, + MatMul = dnnl_graph_op_matmul, + Maximum = dnnl_graph_op_maximum, + MaxPool = dnnl_graph_op_max_pool, + MaxPoolBackward = dnnl_graph_op_max_pool_backward, + Minimum = dnnl_graph_op_minimum, + Mish = dnnl_graph_op_mish, + MishBackward = dnnl_graph_op_mish_backward, + Multiply = dnnl_graph_op_multiply, + Pow = dnnl_graph_op_pow, + PReLU = dnnl_graph_op_prelu, + PReLUBackward = dnnl_graph_op_prelu_backward, + Quantize = dnnl_graph_op_quantize, + Reciprocal = dnnl_graph_op_reciprocal, + ReduceL1 = dnnl_graph_op_reduce_l1, + ReduceL2 = dnnl_graph_op_reduce_l2, + ReduceMax = dnnl_graph_op_reduce_max, + ReduceMean = dnnl_graph_op_reduce_mean, + ReduceMin = dnnl_graph_op_reduce_min, + ReduceProd = dnnl_graph_op_reduce_prod, + ReduceSum = dnnl_graph_op_reduce_sum, + ReLU = dnnl_graph_op_relu, + ReLUBackward = dnnl_graph_op_relu_backward, + Reorder = dnnl_graph_op_reorder, + Round = dnnl_graph_op_round, + Select = dnnl_graph_op_select, + Sigmoid = dnnl_graph_op_sigmoid, + SigmoidBackward = dnnl_graph_op_sigmoid_backward, + SoftMax = dnnl_graph_op_softmax, + SoftMaxBackward = dnnl_graph_op_softmax_backward, + SoftPlus = dnnl_graph_op_softplus, + SoftPlusBackward = dnnl_graph_op_softplus_backward, + Sqrt = dnnl_graph_op_sqrt, + SqrtBackward = dnnl_graph_op_sqrt_backward, + Square = dnnl_graph_op_square, + SquaredDifference = dnnl_graph_op_squared_difference, + StaticReshape = dnnl_graph_op_static_reshape, + StaticTranspose = dnnl_graph_op_static_transpose, + Subtract = dnnl_graph_op_subtract, + Tanh = dnnl_graph_op_tanh, + TanhBackward = dnnl_graph_op_tanh_backward, + TypeCast = 
dnnl_graph_op_type_cast, + Wildcard = dnnl_graph_op_wildcard, + GenIndex = dnnl_graph_op_gen_index, + GreaterEqual = dnnl_graph_op_greater_equal, + // Sentinel + LastSymbol = dnnl_graph_op_last_symbol, + }; + + /// Attributes of operations. Different operations support different + /// attributes. Check the document of each operation for what attributes are + /// supported and what are the potential values for them. Missing required + /// attribute or illegal attribute value may lead to failure when adding the + /// operation to a graph. + enum class attr { + /// Undefined op attribute. + undef = dnnl_graph_op_attr_undef, + + // float32 attributes. The value of these attributes can be any single + // float32 number. + + /// Specifies an alpha attribute to an op. + alpha = dnnl_graph_op_attr_alpha, + /// Specifies an beta attribute to an op. + beta = dnnl_graph_op_attr_beta, + /// Specifies an epsilon attribute to an op. + epsilon = dnnl_graph_op_attr_epsilon, + /// Specifies a max attribute to an op. + max = dnnl_graph_op_attr_max, + /// Specifies a min attribute to an op. + min = dnnl_graph_op_attr_min, + /// Specifies a momentum attribute to an op. + momentum = dnnl_graph_op_attr_momentum, + + // float32 vector attributes. The value of these attributes can be a + // vector of float32 numbers. + + /// Specifies a scales attribute to an op. + scales = dnnl_graph_op_attr_scales, + + // int64_t attributes. The value of these attributes can be any single + // int64 number. + + /// Specifies an axis attribute to an op. + axis = dnnl_graph_op_attr_axis, + /// Specifies a begin_norm_axis attribute to an op. + begin_norm_axis = dnnl_graph_op_attr_begin_norm_axis, + /// Specifies a groups attribute to an op. + groups = dnnl_graph_op_attr_groups, + + // int64_t vector attributes. The value of these attributes can be a + // vector of int64 numbers. + + /// Specifies an axes attribute to an op. + axes = dnnl_graph_op_attr_axes, + /// Specifies a dilations attribute to an op. 
+ dilations = dnnl_graph_op_attr_dilations, + /// Specifies an dst_shape attribute to an op. + dst_shape = dnnl_graph_op_attr_dst_shape, + /// Specifies a kernel attribute to an op. + kernel = dnnl_graph_op_attr_kernel, + /// Specifies an order attribute to an op. + order = dnnl_graph_op_attr_order, + /// Specifies an output_padding attribute to an op. + output_padding = dnnl_graph_op_attr_output_padding, + /// Specifies a pads_begin attribute to an op. + pads_begin = dnnl_graph_op_attr_pads_begin, + /// Specifies a pads_end attribute to an op. + pads_end = dnnl_graph_op_attr_pads_end, + /// Specifies a shape attribute to an op. + shape = dnnl_graph_op_attr_shape, + /// Specifies a sizes attribute to an op. + sizes = dnnl_graph_op_attr_sizes, + /// Specifies an src_shape attribute to an op. + src_shape = dnnl_graph_op_attr_src_shape, + /// Specifies a strides attribute to an op. + strides = dnnl_graph_op_attr_strides, + /// Specifies a weight_shape attribute to an op. + weights_shape = dnnl_graph_op_attr_weights_shape, + /// Specifies a zps attribute to an op. + zps = dnnl_graph_op_attr_zps, + /// Specifies the group shape of an op. The size of the vector should + /// match that of the input. For the dimensions where the grouped + /// quantization occurs, the values should correspond to the group + /// size, which indicates the number of elements that will share the + /// same scaling factor. + group_shape = dnnl_graph_op_attr_group_shape, + + // bool attributes. The value of these attributes can be any single bool + // value. + + /// Specifies an exclude_pad attribute to an op. + exclude_pad = dnnl_graph_op_attr_exclude_pad, + /// Specifies a keep_dims attribute to an op. + keep_dims = dnnl_graph_op_attr_keep_dims, + /// Specifies a keep_stats attribute to an op. + keep_stats = dnnl_graph_op_attr_keep_stats, + /// Specifies a per_channel_broadcast attribute to an op. 
+ per_channel_broadcast = dnnl_graph_op_attr_per_channel_broadcast, + /// Specifies a special_zero attribute to an op. + special_zero = dnnl_graph_op_attr_special_zero, + /// Specifies a transpose_a attribute to an op. + transpose_a = dnnl_graph_op_attr_transpose_a, + /// Specifies a transpose_b attribute to an op. + transpose_b = dnnl_graph_op_attr_transpose_b, + /// Specifies an use_affine attribute to an op. + use_affine = dnnl_graph_op_attr_use_affine, + /// Specifies an use_dst attribute to an op. + use_dst = dnnl_graph_op_attr_use_dst, + + // string attributes. The value of these attributes can be a string. + + /// Specifies an auto_broadcast attribute to an op. The value can be + /// "none" or "numpy". + auto_broadcast = dnnl_graph_op_attr_auto_broadcast, + /// Specifies an auto_pad attribute to an op. The value can be "none", + /// "same_upper", "same_lower", or "valid". + auto_pad = dnnl_graph_op_attr_auto_pad, + /// Specifies an coordinate_transformation_mode attribute to an op. The + /// value can be "half_pixel" or "align_corners". The attribute is + /// defined for Interpolate operations. + coordinate_transformation_mode + = dnnl_graph_op_attr_coordinate_transformation_mode, + /// Specifies a data_format of an op. The value can be "NCX" or "NXC". + data_format = dnnl_graph_op_attr_data_format, + /// Specifies a mode attribute of an op. The value can be "nearest", + /// "linear", "bilinear", or "trilinear". The attribute is defined for + /// Interpolate operations. + mode = dnnl_graph_op_attr_mode, + /// Specifies a qtype attribute to an op. The value can be "per_channel" + /// or "per_tensor". The attribute is defined for quantization + /// operations. + qtype = dnnl_graph_op_attr_qtype, + /// Specifies a rounding_type attribute to an op. The value can be + /// "ceil" or "floor". + rounding_type = dnnl_graph_op_attr_rounding_type, + /// Specifies a weights_format of an op. The value can be "OIX", "XIO", + /// "IOX", or "XOI". 
Different operations may support different values. + weights_format = dnnl_graph_op_attr_weights_format, + + /// Specifies the end of all above exteral attributes for check. + end = dnnl_graph_op_attr_end, + }; + + /// Constructs an op object with an unique ID, an operation kind, and a name + /// string. + /// + /// @param id The unique ID of the op. + /// @param akind The op kind specifies which computation is represented by + /// the op, such as Convolution or ReLU. + /// @param verbose_name The string added as the op name. + op(size_t id, kind akind, const std::string &verbose_name = "") { + dnnl_graph_op_t op = nullptr; + error::wrap_c_api(dnnl_graph_op_create(&op, id, convert_to_c(akind), + verbose_name.c_str()), + "could not create op with id and op kind"); + reset(op); + } + + /// Constructs an op object with an unique ID, an operation kind, and + /// input/output logical tensors. + /// + /// @param id The unique ID of this op. + /// @param akind The op kind specifies which computation is represented by + /// this op, such as Convolution or ReLU. + /// @param inputs Input logical tensor to be bound to this op. + /// @param outputs Output logical tensor to be bound to this op. + /// @param verbose_name The string added as the op name. + op(size_t id, kind akind, const std::vector &inputs, + const std::vector &outputs, + const std::string &verbose_name = "") + : op(id, akind, verbose_name) { + for (const auto &input : inputs) { + error::wrap_c_api(dnnl_graph_op_add_input(get(), &(input.data)), + "adding input to the op failed"); + } + for (const auto &output : outputs) { + error::wrap_c_api(dnnl_graph_op_add_output(get(), &(output.data)), + "adding output to the op failed"); + } + } + + /// Adds an input logical tensor to the op. + /// + /// @param t Input logical tensor. 
+ void add_input(const logical_tensor &t) { + error::wrap_c_api(dnnl_graph_op_add_input(get(), &(t.data)), + "adding input to the op failed"); + } + + /// Adds a vector of input logical tensors to the op. + /// + /// @param ts The list of input logical tensors. + void add_inputs(const std::vector &ts) { + for (const auto &t : ts) { + error::wrap_c_api(dnnl_graph_op_add_input(get(), &(t.data)), + "adding input to the op failed"); + } + } + + /// Adds an output logical tensor to the op. + /// + /// @param t Output logical tensor. + void add_output(const logical_tensor &t) { + error::wrap_c_api(dnnl_graph_op_add_output(get(), &(t.data)), + "adding output to the op failed"); + } + + /// Adds a vector of output logical tensors to the op. + /// + /// @param ts The list of output logical tensors. + void add_outputs(const std::vector &ts) { + for (const auto &t : ts) { + error::wrap_c_api(dnnl_graph_op_add_output(get(), &(t.data)), + "adding output to the op failed"); + } + } + + /// Sets the attribute according to the name and type (int64_t). + /// + /// @tparam Type_i Attribute's type. + /// @param name Attribute's name. + /// @param value The attribute's value. + /// @returns The Op self. + template ::value> = true> + op &set_attr(attr name, const Type_i &value) { + dnnl_graph_op_attr_t attr = convert_to_c(name); + error::wrap_c_api(dnnl_graph_op_set_attr_s64(get(), attr, &value, 1), + "could not set attribute to the op"); + return *this; + } + + /// Sets the attribute according to the name and type (float). + /// + /// @tparam Type_f Attribute's type. + /// @param name Attribute's name. + /// @param value The attribute's value. + /// @returns The Op self. 
+ template ::value> = true> + op &set_attr(attr name, const Type_f &value) { + dnnl_graph_op_attr_t attr = convert_to_c(name); + error::wrap_c_api(dnnl_graph_op_set_attr_f32(get(), attr, &value, 1), + "could not set attribute to the op"); + return *this; + } + + /// Sets the attribute according to the name and type (bool). + /// + /// @tparam Type_b Attribute's type. + /// @param name Attribute's name. + /// @param value The attribute's value. + /// @returns The Op self. + template ::value> = true> + op &set_attr(attr name, const Type_b &value) { + dnnl_graph_op_attr_t attr = convert_to_c(name); + const uint8_t val = value; + error::wrap_c_api(dnnl_graph_op_set_attr_bool(get(), attr, &val, 1), + "could not set attribute to the op"); + return *this; + } + + /// Sets the attribute according to the name and type (string). + /// + /// @tparam Type_s Attribute's type. + /// @param name Attribute's name. + /// @param value The attribute's value. + /// @returns The Op self. + template ::value> = true> + op &set_attr(attr name, const Type_s &value) { + dnnl_graph_op_attr_t attr = convert_to_c(name); + error::wrap_c_api(dnnl_graph_op_set_attr_str( + get(), attr, value.c_str(), value.size()), + "could not set attribute to the op"); + return *this; + } + + /// Sets the attribute according to the name and type + /// (std::vector). + /// + /// @tparam Type_is Attribute's type. + /// @param name Attribute's name. + /// @param value The attribute's value. + /// @returns The Op self. + template >::value> = true> + op &set_attr(attr name, const Type_is &value) { + dnnl_graph_op_attr_t attr = convert_to_c(name); + error::wrap_c_api(dnnl_graph_op_set_attr_s64( + get(), attr, value.data(), value.size()), + "could not set attribute to the op"); + return *this; + } + + /// Sets the attribute according to the name and type (std::vector). + /// + /// @tparam Type_fs Attribute's type. + /// @param name Attribute's name. + /// @param value The attribute's value. + /// @returns The Op self. 
+ template >::value> = true> + op &set_attr(attr name, const Type_fs &value) { + dnnl_graph_op_attr_t attr = convert_to_c(name); + error::wrap_c_api(dnnl_graph_op_set_attr_f32( + get(), attr, value.data(), value.size()), + "could not set attribute to the op"); + return *this; + } + +private: + dnnl_graph_op_kind_t convert_to_c(kind akind) { + return static_cast(akind); + } + + dnnl_graph_op_attr_t convert_to_c(attr aattr) { + return static_cast(aattr); + } +}; + +/// @} dnnl_graph_api_op + +/// @addtogroup dnnl_graph_api_partition Partition +/// +/// Partition represents a collection of operations and their input and output +/// logical tensors identified by library as the basic unit for compilation and +/// execution. +/// +/// @{ + +/// A partition object. +class partition : public partition_handle { +public: + /// Policy specifications for partitioning. + enum class policy { + /// Fusion policy returns partitions with typical post-op fusions, eg. + /// Convolution + ReLU or other element-wise operations or a chian of + /// post-ops. + fusion = dnnl_graph_partition_policy_fusion, + /// Debug policy doesn't not apply any fusions. It returns partitions + /// with single operations in each partition. The policy is useful when + /// users notice any bug or correctness issue in fusion policy. + debug = dnnl_graph_partition_policy_debug, + }; + + partition() = default; + + /// Constructs a partition object + /// + /// @param p A raw pointer to the C API handle + partition(dnnl_graph_partition_t p) { reset(p, false); } + + /// Creates a new partition with a given operator and engine kind. The API + /// is used to create a partition from an operation directly without + /// creating the graph and calling `get_partitions()`. The output partition + /// contains only one operation. + /// + /// @param aop An operation used to create the partition. + /// @param ekind Engine kind. 
+ partition(const op &aop, engine::kind ekind) { + dnnl_graph_partition_t p = nullptr; + error::wrap_c_api(dnnl_graph_partition_create_with_op(&p, aop.get(), + static_cast(ekind)), + "could not create a partition with the op and engine kind"); + reset(p); + } + + /// Returns the number of operations contained in the partition. + /// + /// @returns Number of operations. + size_t get_ops_num() const { + size_t num {0}; + error::wrap_c_api(dnnl_graph_partition_get_op_num(get(), &num), + "could not get number of ops from the partition"); + return num; + } + + /// Returns all operation IDs contained in the partition. + /// + /// @returns An unordered set of operation IDs. + std::vector get_ops() const { + auto num = get_ops_num(); + std::vector ops(num); + + error::wrap_c_api(dnnl_graph_partition_get_ops(get(), num, ops.data()), + "could not get op ids from the partition"); + return ops; + } + + /// Returns the unique ID of the partition. Partition ID is generated by the + /// library internally. The ID can be used for debugging purpose or verbose. + /// + /// @returns ID of the partition. + size_t get_id() const { + size_t id {}; + error::wrap_c_api(dnnl_graph_partition_get_id(get(), &id), + "could not get id of the partition"); + return id; + } + + /// Compiles a partition with given input and output logical tensors. The + /// output logical tensors can contain unknown dimensions. For this case, + /// the compilation will deduce the output shapes according to input shapes. + /// The output logical tensors can also have layout type `any`. The + /// compilation will choose the optimal layout for output tensors. The + /// optimal layout will be represented as an opaque layout ID saved in the + /// output logical tensor. + /// + /// @param inputs A list of input logical tensors. + /// @param outputs A list of output logical tensors. + /// @param e The engine used to compile the partition. + /// @returns A compiled partition. 
+ compiled_partition compile(const std::vector &inputs, + const std::vector &outputs, const engine &e) const { + if (!is_supported()) { + error::wrap_c_api(dnnl_invalid_arguments, + "could not compile an unsupported partition"); + } + + return compile_(inputs, outputs, e); + } + + /// Returns the supporting status of a partition. Some operations may not be + /// supported by the library under certain circumstances. During + /// partitioning stage, unsupported partitions will be returned to users + /// with each containing an unsupported operation. Users should check the + /// supporting status of a partition before transforming the computation + /// graph or compiling the partition. + /// + /// @returns @c true if this partition is supported or @c false if this + /// partition isn't supported by the library + bool is_supported() const { + uint8_t supported {0}; + error::wrap_c_api(dnnl_graph_partition_is_supported(get(), &supported), + "could not get supporting status of the partition"); + return supported != 0; + } + + /// Returns a list of input logical tensors from the partition. + /// + /// @returns A list of input logical tensors. + std::vector get_input_ports() const { + size_t num = 0; + error::wrap_c_api(dnnl_graph_partition_get_input_ports_num(get(), &num), + "could not get number of inputs of the partition"); + if (num == 0) return {}; + + std::vector c_inputs(num); + error::wrap_c_api(dnnl_graph_partition_get_input_ports( + get(), num, c_inputs.data()), + "could not get input logical tensors of the partition"); + + std::vector inputs; + inputs.reserve(num); + for (auto &c_lt : c_inputs) + inputs.emplace_back(c_lt); + return inputs; + } + + /// Returns a list of output logical tensors from the partition. + /// + /// @returns A list of output logical tensor. 
+ std::vector get_output_ports() const { + size_t num = 0; + error::wrap_c_api( + dnnl_graph_partition_get_output_ports_num(get(), &num), + "cannot get number of outputs of the partition"); + if (num == 0) return {}; + + std::vector c_outputs(num); + error::wrap_c_api(dnnl_graph_partition_get_output_ports( + get(), num, c_outputs.data()), + "could not get output logical tensors of the partition"); + + std::vector outputs; + outputs.reserve(num); + for (auto &c_lt : c_outputs) + outputs.emplace_back(c_lt); + return outputs; + } + + /// Returns the engine kind of the partition + /// + /// @returns The engine kind + engine::kind get_engine_kind() const { + dnnl_engine_kind_t akind; + error::wrap_c_api(dnnl_graph_partition_get_engine_kind(get(), &akind), + "cannot get the engine kind from the partition"); + + return static_cast(akind); + } + +private: + compiled_partition compile_(const std::vector &inputs, + const std::vector &outputs, const engine &e) const { + std::vector c_inputs; + std::vector c_outputs; + + c_inputs.reserve(inputs.size()); + for (const auto &in : inputs) { + c_inputs.push_back(&(in.data)); + } + + c_outputs.reserve(outputs.size()); + for (const auto &out : outputs) { + c_outputs.push_back(&(out.data)); + } + + dnnl_graph_compiled_partition_t cpartitions = nullptr; + error::wrap_c_api( + dnnl_graph_compiled_partition_create(&cpartitions, get()), + "could not create compiled_partition"); + error::wrap_c_api(dnnl_graph_partition_compile(get(), cpartitions, + c_inputs.size(), c_inputs.data(), + c_outputs.size(), c_outputs.data(), e.get()), + "partition compile failed"); + + return compiled_partition(cpartitions); + } +}; + +/// @} dnnl_graph_api_partition + +/// @addtogroup dnnl_graph_api_graph Graph +/// +/// Graph represents a computational DAG with a set of operations. +/// #dnnl::graph::graph::add_op() adds an operation and its input and output +/// logical tensors into a graph. 
The library accumulates the operations and +/// logical tensors and constructs and validates the graph as an internal state. +/// A graph object is associated to a specific engine kind. The partitions +/// returned from the graph will inherit the engine kind of the graph. +/// +/// @{ + +/// A graph object. +class graph : public graph_handle { +public: + /// Constructs a graph with an engine kind. + /// + /// @param engine_kind Engine kind. + graph(engine::kind engine_kind) { + dnnl_graph_graph_t g = nullptr; + error::wrap_c_api( + dnnl_graph_graph_create(&g, convert_to_c(engine_kind)), + "could not create graph with engine kind"); + reset(g); + } + + /// Creates a new empty graph with an engine kind and a floating-point math + /// mode. All partitions returned from the graph will inherit the engine + /// kind and floating-point math mode. + /// + /// Setting the floating-point math mode enables automatic down-conversion + /// of inputs for the given graph, promoting speedup by using + /// lower-precision data types when available. + /// + /// @param engine_kind Engine kind. + /// @param mode Floating-point math mode. + graph(engine::kind engine_kind, fpmath_mode mode) { + dnnl_graph_graph_t g = nullptr; + error::wrap_c_api( + dnnl_graph_graph_create_with_fpmath_mode( + &g, convert_to_c(engine_kind), convert_to_c(mode)), + "could not create graph with engine kind and math mode"); + reset(g); + } + + /// Set the floating point math mode for a graph. Users can enforce the + /// graph to comply with the mode by specifying a boolean flag with the + /// setter function. + /// + /// @param mode The floating-point math mode. + /// @param apply_to_int The flag that controls whether to use + /// floating-point arithmetic for integral operations. 
+ void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false) { + error::wrap_c_api(dnnl_graph_graph_set_fpmath_mode( + get(), convert_to_c(mode), apply_to_int), + "could not set fpmath mode graph attribute"); + } + + /// Get the floating point math mode and the boolean flag that specifies + /// whether the graph will be enforced to comply the mode. + /// + /// @param mode The floating-point math mode. + /// @param apply_to_int The flag that controls whether to use + /// floating-point arithmetic for integral operations. + void get_fpmath_mode(fpmath_mode &mode, bool &apply_to_int) const { + dnnl_fpmath_mode_t c_mode; + int c_apply_to_int; + + error::wrap_c_api(dnnl_graph_graph_get_fpmath_mode( + get(), &c_mode, &c_apply_to_int), + "could not get fpmath mode graph attribute"); + + mode = fpmath_mode(c_mode); + apply_to_int = static_cast(c_apply_to_int); + } + + /// Adds an op into the graph to construct a computational DAG. The API will + /// return failure if the operator has already been added to the graph or + /// the operation cannot pass the schema check in the library (eg. input and + /// output numbers and data types, the attributes of the operation, etc.). + /// + /// @param op An operation to be added. + /// @param allow_exception A flag indicating whether the method is allowed + /// to throw an exception if it fails to add the op to the graph. + /// @returns #status::success or a status describing the error otherwise. + status add_op(const op &op, bool allow_exception = true) { + dnnl_status_t ret = dnnl_graph_add_op(get(), op.get()); + + if (allow_exception) { + error::wrap_c_api(ret, "could not add op to the graph"); + } + + return static_cast(ret); + } + + /// Finalizes a graph. It means users have finished adding operations into + /// the graph and the graph is ready for partitioning. Adding a new + /// operation into a finalized graph will return failures. Similarly, + /// partitioning on a un-finalized graph will also return failures. 
+ void finalize() { + error::wrap_c_api(dnnl_graph_graph_finalize(get()), + "could not finalize the graph"); + } + + /// Checks if a graph is finalized. + /// + /// @return True if the graph is finalized or false if the graph is not + /// finalized. + bool is_finalized() const { + uint8_t ret = 0; + error::wrap_c_api(dnnl_graph_graph_is_finalized(get(), &ret), + "could not get the finalization status of the graph"); + + return ret != 0; + } + + /// Gets filtered partitions from a graph. Partitions will be claimed + /// internally according to the capability of the library, the engine kind + /// of the graph, and the policy. + /// + /// @param policy Partition policy, defaults to policy + /// #dnnl::graph::partition::policy::fusion. + /// @return A vector storing the partitions. + std::vector get_partitions( + partition::policy policy = partition::policy::fusion) { + if (!is_finalized()) { + error::wrap_c_api( + dnnl_invalid_graph, "the graph is not finalized yet"); + } + + error::wrap_c_api( + dnnl_graph_graph_filter(get(), + static_cast(policy)), + "could not filter the graph"); + + size_t num = 0; + error::wrap_c_api(dnnl_graph_graph_get_partition_num(get(), &num), + "could not get number of partitions from the graph"); + + // return early if there is no partitions in the graph. 
+ if (num == 0) return {}; + + std::vector out_list; + out_list.reserve(num); + + std::vector partitions(num); + error::wrap_c_api( + dnnl_graph_graph_get_partitions(get(), num, partitions.data()), + "could not get partitions from the graph"); + + for (auto p : partitions) { + out_list.emplace_back(p); + } + + return out_list; + } + +private: + static dnnl_fpmath_mode_t convert_to_c(fpmath_mode mode) { + return static_cast(mode); + } + + static dnnl_engine_kind_t convert_to_c(engine::kind akind) { + return static_cast(akind); + } +}; + +/// @} dnnl_graph_api_graph + +/// @addtogroup dnnl_graph_api_compiled_partition_cache Compiled Partition Cache +/// +/// A set of functions that provide compiled partition cache control. +/// +/// @{ + +/// Returns the number of compiled partition that can be held in the compiled +/// partition cache at the same time. +inline int get_compiled_partition_cache_capacity() { + int result = 0; + error::wrap_c_api(dnnl_graph_get_compiled_partition_cache_capacity(&result), + "could not get compiled partition cache capacity"); + return result; +} + +/// @copydoc dnnl_graph_set_compiled_partition_cache_capacity(int capacity) +inline void set_compiled_partition_cache_capacity(int capacity) { + error::wrap_c_api( + dnnl_graph_set_compiled_partition_cache_capacity(capacity), + "could not set compiled partition cache capacity"); +} + +/// @} dnnl_graph_api_compiled_partition_cache + +/// @addtogroup dnnl_graph_api_constant_tensor_cache Constant Tensor Cache +/// +/// A set of functions that provide constant tensor cache control +/// +/// @{ + +/// Control the enabling or disabling of constant tensor cache. This API must be +/// called once before compilation stage. By default, constant tensor cache is +/// disabled in the library. +/// @note This API is deprecated and will be removed in future release, please +/// use the set_constant_tensor_cache_capacity API to disable +/// constant tensor cache by setting it's capacity to zero. 
+/// +/// @param flag Set to positive value to enable the cache and set to 0 to +/// disable the cache. Negative values are invalid. +inline void set_constant_tensor_cache(int flag) { + error::wrap_c_api(dnnl_graph_set_constant_tensor_cache(flag), + "fail to set constant tensor cache"); +} + +/// Return the enabling status of constant tensor cache. +/// @note This API is deprecated and will be removed in future release, please +/// use the get_constant_tensor_cache_capacity API to check the +/// enabling status by checking it's capacity. +inline int get_constant_tensor_cache() { + int result = 0; + error::wrap_c_api(dnnl_graph_get_constant_tensor_cache(&result), + "fail to get constant tensor cache"); + return result; +} + +/// Control the capacity for the constant tensor cache that used for specific +/// engine kind. This API is thread safe and can be called multiple times at +/// runtime. The capacity is set to zero by default which means the cache is +/// disabled. When calling this API, the corresponding cache will be flushed. +/// Setting capacity to 0 means to clear all cached tensors and disable cache. +/// Once the capacity limit is reached, no new tensors will be cached. If there +/// are multiple devices for an engine kind, the capacity set here is for each +/// device. +/// +/// @param kind The engine kind that the constant tensor cache used for. +/// @param size The constant tensor cache capacity size to set. +inline void set_constant_tensor_cache_capacity(engine::kind kind, size_t size) { + error::wrap_c_api(dnnl_graph_set_constant_tensor_cache_capacity( + static_cast(kind), size), + "fail to set constant tensor cache capacity"); +} + +/// Return the current capacity of constant tensor cache. +/// +/// @param kind The engine kind that the constant tensor cache used for. 
+inline size_t get_constant_tensor_cache_capacity(engine::kind kind) { + size_t size = 0; + error::wrap_c_api(dnnl_graph_get_constant_tensor_cache_capacity( + static_cast(kind), &size), + "fail to get constant tensor cache capacity"); + return size; +} + +/// @} dnnl_graph_api_constant_tensor_cache + +} // namespace graph + +/// @} dnnl_graph_api + +} // namespace dnnl + +/// @cond DO_NOT_DOCUMENT_THIS + +/// oneAPI namespace +// Contains the oneapi::dnnl namespace as an alias to the ::dnnl namespace. +namespace oneapi { +// Note: without this guard, doxygen warns of potentially recursive namespace +#ifndef DOXYGEN_SHOULD_SKIP_THIS +/// oneDNN alias namespace +namespace dnnl = ::dnnl; +#endif +} // namespace oneapi + +/// @endcond + +/// @} dnnl_api + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_ocl.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_ocl.h new file mode 100644 index 0000000000000000000000000000000000000000..81530fc7a37197ed2778d4422549420bb708fa37 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_ocl.h @@ -0,0 +1,149 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_GRAPH_OCL_H +#define ONEAPI_DNNL_DNNL_GRAPH_OCL_H + +#include "oneapi/dnnl/dnnl_graph.h" + +/// @cond DO_NOT_DOCUMENT_THIS +// Set target version for OpenCL explicitly to suppress a compiler warning. +#ifndef CL_TARGET_OPENCL_VERSION +#define CL_TARGET_OPENCL_VERSION 120 +#endif + +#include +/// @endcond + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_graph_api +/// @{ + +/// @addtogroup dnnl_graph_api_interop +/// @{ + +/// @addtogroup dnnl_graph_api_ocl_interop +/// @{ + +/// Allocation call-back function interface for OpenCL. OpenCL allocator should +/// be used for OpenCL GPU runtime. The call-back should return a USM device +/// memory pointer. +/// +/// @param size Memory size in bytes for requested allocation +/// @param alignment The minimum alignment in bytes for the requested allocation +/// @param device A valid OpenCL device used to allocate +/// @param context A valid OpenCL context used to allocate +/// @returns The memory address of the requested USM allocation. +typedef void *(*dnnl_graph_ocl_allocate_f)( + size_t size, size_t alignment, cl_device_id device, cl_context context); + +/// Deallocation call-back function interface for OpenCL. OpenCL allocator +/// should be used for OpenCL runtime. The call-back should deallocate a USM +/// device memory returned by #dnnl_graph_ocl_allocate_f. The event should be +/// completed before deallocate the USM. 
+/// +/// @param buf The USM allocation to be released +/// @param device A valid OpenCL device the USM associated with +/// @param context A valid OpenCL context used to free the USM allocation +/// @param event A event which the USM deallocation depends on +typedef void (*dnnl_graph_ocl_deallocate_f)( + void *buf, cl_device_id device, cl_context context, cl_event event); + +/// Creates an allocator with the given allocation and deallocation call-back +/// function pointers. +/// +/// @param allocator Output allocator +/// @param ocl_malloc A pointer to OpenCL malloc function +/// @param ocl_free A pointer to OpenCL free function +/// @returns #dnnl_success on success and a status describing the +/// error otherwise. +dnnl_status_t DNNL_API dnnl_graph_ocl_interop_allocator_create( + dnnl_graph_allocator_t *allocator, dnnl_graph_ocl_allocate_f ocl_malloc, + dnnl_graph_ocl_deallocate_f ocl_free); + +/// This API is a supplement for existing oneDNN engine API: +/// dnnl_status_t DNNL_API dnnl_ocl_interop_engine_create( +/// dnnl_engine_t *engine, cl_device_id device, cl_context context); +/// +/// @param engine Output engine. +/// @param device Underlying OpenCL device to use for the engine. +/// @param context Underlying OpenCL context to use for the engine. +/// @param alloc Underlying allocator to use for the engine. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_graph_ocl_interop_make_engine_with_allocator( + dnnl_engine_t *engine, cl_device_id device, cl_context context, + const_dnnl_graph_allocator_t alloc); + +/// This API is a supplement for existing oneDNN engine API: +/// dnnl_status_t DNNL_API dnnl_ocl_interop_engine_create_from_cache_blob( +/// dnnl_engine_t *engine, cl_device_id device, cl_context context, +/// size_t size, const uint8_t *cache_blob); +/// +/// @param engine Output engine. +/// @param device The OpenCL device that this engine will encapsulate. 
+/// @param context The OpenCL context (containing the device) that this +/// engine will use for all operations. +/// @param alloc Underlying allocator to use for the engine. +/// @param size Size of the cache blob in bytes. +/// @param cache_blob Cache blob of size @p size. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API +dnnl_graph_ocl_interop_make_engine_from_cache_blob_with_allocator( + dnnl_engine_t *engine, cl_device_id device, cl_context context, + const_dnnl_graph_allocator_t alloc, size_t size, + const uint8_t *cache_blob); + +/// Execute a compiled partition with OpenCL runtime. +/// +/// @param compiled_partition The handle of target compiled_partition. +/// @param stream The stream used for execution +/// @param num_inputs The number of input tensors +/// @param inputs A list of input tensors +/// @param num_outputs The number of output tensors +/// @param outputs A non-empty list of output tensors +/// @param deps Optional handle of list with `cl_event` dependencies. +/// @param ndeps Number of dependencies. +/// @param return_event The handle of cl_event. +/// @returns #dnnl_success on success and a status describing the +/// error otherwise. 
+dnnl_status_t DNNL_API dnnl_graph_ocl_interop_compiled_partition_execute( + const_dnnl_graph_compiled_partition_t compiled_partition, + dnnl_stream_t stream, size_t num_inputs, + const_dnnl_graph_tensor_t *inputs, size_t num_outputs, + const_dnnl_graph_tensor_t *outputs, const cl_event *deps, int ndeps, + cl_event *return_event); + +/// @} dnnl_graph_api_ocl_interop + +/// @} dnnl_graph_api_interop + +/// @} dnnl_graph_api + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_ocl.hpp b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_ocl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..18ff36bd686925ec7d09257632b9f18f56247e0d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_ocl.hpp @@ -0,0 +1,156 @@ +/******************************************************************************* +* Copyright 2024-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/// @file +/// Graph OpenCL interop API + +#ifndef ONEAPI_DNNL_DNNL_GRAPH_OCL_HPP +#define ONEAPI_DNNL_DNNL_GRAPH_OCL_HPP + +/// @cond DO_NOT_DOCUMENT_THIS +#include + +#include + +#include "oneapi/dnnl/dnnl_graph.hpp" +#include "oneapi/dnnl/dnnl_graph_ocl.h" +#include "oneapi/dnnl/dnnl_ocl.hpp" +/// @endcond + +/// @addtogroup dnnl_api +/// @{ + +namespace dnnl { + +/// @addtogroup dnnl_graph_api +/// @{ + +namespace graph { + +/// @addtogroup dnnl_graph_api_interop Runtime interoperability API +/// API extensions to interact with the underlying run-time. +/// @{ + +/// @addtogroup dnnl_graph_api_ocl_interop OpenCL interoperability API +/// API extensions to interact with the underlying OpenCL run-time. +/// @{ + +/// OpenCL interoperability namespace +namespace ocl_interop { + +/// Constructs an allocator from OpenCL malloc and free function pointer. OpenCL +/// allocator should be used for OpenCL GPU runtime. Currently, only device USM +/// allocator is supported. +/// +/// @param ocl_malloc The pointer to OpenCL malloc function +/// @param ocl_free The pointer to OpenCL free function +/// @returns Created allocator +inline allocator make_allocator(dnnl_graph_ocl_allocate_f ocl_malloc, + dnnl_graph_ocl_deallocate_f ocl_free) { + dnnl_graph_allocator_t c_allocator = nullptr; + error::wrap_c_api(dnnl_graph_ocl_interop_allocator_create( + &c_allocator, ocl_malloc, ocl_free), + "could not create allocator for opencl device"); + return allocator(c_allocator); +} + +/// Constructs an engine from an OpenCL device, an OpenCL context, and an +/// allocator. 
+/// +/// @param device A valid OpenCL device to construct the engine +/// @param context A valid OpenCL context to construct the engine +/// @param alloc An allocator to associate with the engine +/// @returns Created engine +inline engine make_engine_with_allocator( + cl_device_id device, cl_context context, const allocator &alloc) { + dnnl_engine_t c_engine; + error::wrap_c_api(dnnl_graph_ocl_interop_make_engine_with_allocator( + &c_engine, device, context, alloc.get()), + "could not make an engine with allocator"); + return engine(c_engine); +} + +/// Constructs an engine from an OpenCL device, an OpenCL context, an +/// allocator, and a serialized engine cache blob. +/// +/// @param device A valid OpenCL device to construct the engine +/// @param context A valid OpenCL context to construct the engine +/// @param alloc An allocator to associate with the engine +/// @param cache_blob Cache blob serialized beforehand +/// @returns Created engine +inline engine make_engine_with_allocator(cl_device_id device, + cl_context context, const allocator &alloc, + const std::vector &cache_blob) { + dnnl_engine_t c_engine; + error::wrap_c_api( + dnnl_graph_ocl_interop_make_engine_from_cache_blob_with_allocator( + &c_engine, device, context, alloc.get(), cache_blob.size(), + cache_blob.data()), + "could not make an engine with allocator from cache blob"); + return engine(c_engine); +} + +/// Executes a compiled partition in a specified stream and returns a OpenCL +/// event. +/// +/// @param c_partition Compiled partition to execute. +/// @param astream Stream object to run over +/// @param inputs Arguments map. +/// @param outputs Arguments map. +/// @param deps Optional vector with `cl_event` dependencies. +/// @returns Output event. 
+inline cl_event execute(compiled_partition &c_partition, stream &astream, + const std::vector &inputs, std::vector &outputs, + const std::vector &deps = {}) { + std::vector c_inputs; + c_inputs.reserve(inputs.size()); + for (auto &in : inputs) { + c_inputs.push_back(in.get()); + } + std::vector c_outputs; + c_outputs.reserve(outputs.size()); + for (auto &out : outputs) { + c_outputs.push_back(out.get()); + } + + const cl_event *c_deps = deps.empty() ? nullptr : deps.data(); + + cl_event ocl_event; + error::wrap_c_api( + dnnl_graph_ocl_interop_compiled_partition_execute(c_partition.get(), + astream.get(), c_inputs.size(), c_inputs.data(), + c_outputs.size(), c_outputs.data(), c_deps, + (int)deps.size(), &ocl_event), + "could not execute the compiled_partition on a specified opencl " + "stream"); + return ocl_event; +} + +} // namespace ocl_interop + +/// @} dnnl_graph_api_ocl_interop + +/// @} dnnl_graph_api_interop + +} // namespace graph + +/// @} dnnl_graph_api + +} // namespace dnnl + +/// @} dnnl_api + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.h new file mode 100644 index 0000000000000000000000000000000000000000..518b94e4739c5ba97807ef34bdbd792315746382 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.h @@ -0,0 +1,99 @@ +/******************************************************************************* +* Copyright 2020-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_GRAPH_SYCL_H +#define ONEAPI_DNNL_DNNL_GRAPH_SYCL_H + +#include "oneapi/dnnl/dnnl_graph.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_graph_api +/// @{ + +/// @addtogroup dnnl_graph_api_interop +/// @{ + +/// @addtogroup dnnl_graph_api_sycl_interop +/// @{ + +/// Allocation call-back function interface for SYCL. SYCL allocator should be +/// used for SYCL runtime and host allocator should be used for non-SYCL. The +/// call-back should return a USM device memory pointer. +typedef void *(*dnnl_graph_sycl_allocate_f)( + size_t size, size_t alignment, const void *dev, const void *context); + +/// Deallocation call-back function interface for SYCL. SYCL allocator should be +/// used for SYCL runtime and host allocator should be used for non-SYCL. The +/// call-back should deallocate a USM device memory returned by +/// #dnnl_graph_sycl_allocate_f. +typedef void (*dnnl_graph_sycl_deallocate_f)( + void *buf, const void *dev, const void *context, void *event); + +/// Creates an allocator with the given allocation and deallocation call-back +/// function pointers. +/// +/// @param allocator Output allocator +/// @param sycl_malloc A pointer to SYCL malloc function +/// @param sycl_free A pointer to SYCL free function +/// @returns #dnnl_success on success and a status describing the +/// error otherwise. 
+dnnl_status_t DNNL_API dnnl_graph_sycl_interop_allocator_create( + dnnl_graph_allocator_t *allocator, + dnnl_graph_sycl_allocate_f sycl_malloc, + dnnl_graph_sycl_deallocate_f sycl_free); + +/// This API is a supplement for existing onednn engine API. +dnnl_status_t DNNL_API dnnl_graph_sycl_interop_make_engine_with_allocator( + dnnl_engine_t *engine, const void *device, const void *context, + const_dnnl_graph_allocator_t alloc); + +/// Execute a compiled partition with sycl runtime. +/// +/// @param compiled_partition The handle of target compiled_partition. +/// @param stream The stream used for execution +/// @param num_inputs The number of input tensors +/// @param inputs A list of input tensors +/// @param num_outputs The number of output tensors +/// @param outputs A non-empty list of output tensors +/// @param deps Optional handle of list with `sycl::event` dependencies. +/// @param sycl_event The handle of sycl event. +/// @returns #dnnl_success on success and a status describing the +/// error otherwise. 
+dnnl_status_t DNNL_API dnnl_graph_sycl_interop_compiled_partition_execute( + const_dnnl_graph_compiled_partition_t compiled_partition, + dnnl_stream_t stream, size_t num_inputs, + const_dnnl_graph_tensor_t *inputs, size_t num_outputs, + const_dnnl_graph_tensor_t *outputs, const void *deps, void *sycl_event); + +/// @} dnnl_graph_api_sycl_interop + +/// @} dnnl_graph_api_interop + +/// @} dnnl_graph_api + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.hpp b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2507842cb38b57211cc249a205afd2a6fa2fc8bb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.hpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2020-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/// @file +/// Graph SYCL interop API + +#ifndef ONEAPI_DNNL_DNNL_GRAPH_SYCL_HPP +#define ONEAPI_DNNL_DNNL_GRAPH_SYCL_HPP + +/// @cond DO_NOT_DOCUMENT_THIS +#include + +#if __has_include() +#include +#else +#error "Unsupported compiler" +#endif + +#include "oneapi/dnnl/dnnl_graph.hpp" +#include "oneapi/dnnl/dnnl_graph_sycl.h" +/// @endcond + +/// @addtogroup dnnl_api +/// @{ + +namespace dnnl { + +/// @addtogroup dnnl_graph_api +/// @{ + +namespace graph { + +/// @addtogroup dnnl_graph_api_interop Runtime interoperability API +/// API extensions to interact with the underlying run-time. +/// @{ + +/// @addtogroup dnnl_graph_api_sycl_interop SYCL interoperability API +/// API extensions to interact with the underlying SYCL run-time. +/// @{ + +/// SYCL interoperability namespace +namespace sycl_interop { + +/// Constructs an allocator from SYCL malloc and free function pointer. SYCL +/// allocator should be used for SYCL runtime and host allocator should be used +/// for non-SYCL. Currently, only device USM allocator is supported. 
+/// +/// @param sycl_malloc The pointer to SYCL malloc function +/// @param sycl_free The pointer to SYCL free function +/// @returns Created allocator +inline allocator make_allocator(dnnl_graph_sycl_allocate_f sycl_malloc, + dnnl_graph_sycl_deallocate_f sycl_free) { + dnnl_graph_allocator_t c_allocator = nullptr; + error::wrap_c_api(dnnl_graph_sycl_interop_allocator_create( + &c_allocator, sycl_malloc, sycl_free), + "could not create allocator for sycl device"); + return allocator(c_allocator); +} + +inline engine make_engine_with_allocator(const sycl::device &adevice, + const sycl::context &acontext, const allocator &alloc) { + dnnl_engine_t c_engine; + error::wrap_c_api( + dnnl_graph_sycl_interop_make_engine_with_allocator(&c_engine, + static_cast(&adevice), + static_cast(&acontext), alloc.get()), + "could not make an engine with allocator"); + return engine(c_engine); +} + +/// Executes a compiled partition in a specified stream and returns a SYCL +/// event. +/// +/// @param c_partition Compiled partition to execute. +/// @param astream Stream object to run over +/// @param inputs Arguments map. +/// @param outputs Arguments map. +/// @param deps Optional vector with `sycl::event` dependencies. +/// @returns Output event. 
+inline sycl::event execute(compiled_partition &c_partition, stream &astream, + const std::vector &inputs, std::vector &outputs, + const std::vector &deps = {}) { + std::vector c_inputs; + c_inputs.reserve(inputs.size()); + for (auto &in : inputs) { + c_inputs.push_back(in.get()); + } + std::vector c_outputs; + c_outputs.reserve(outputs.size()); + for (auto &out : outputs) { + c_outputs.push_back(out.get()); + } + + sycl::event sycl_event; + error::wrap_c_api(dnnl_graph_sycl_interop_compiled_partition_execute( + c_partition.get(), astream.get(), c_inputs.size(), + c_inputs.data(), c_outputs.size(), + c_outputs.data(), &deps, &sycl_event), + "could not execute the compiled_partition on a specified sycl " + "stream"); + return sycl_event; +} + +} // namespace sycl_interop + +/// @} dnnl_graph_api_sycl_interop + +/// @} dnnl_graph_api_interop + +} // namespace graph + +/// @} dnnl_graph_api + +} // namespace dnnl + +/// @} dnnl_api + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_types.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_types.h new file mode 100644 index 0000000000000000000000000000000000000000..4aeb4d6bd87f888f066f58b1281c4e3f0219c957 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_graph_types.h @@ -0,0 +1,475 @@ +/******************************************************************************* + * Copyright 2020-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +/// @file +/// C API definitions + +#ifndef ONEAPI_DNNL_DNNL_GRAPH_TYPES_H +#define ONEAPI_DNNL_DNNL_GRAPH_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +/// @cond DO_NOT_DOCUMENT_THIS +#include +#include + +#include "oneapi/dnnl/dnnl_common_types.h" +/// @endcond + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_graph_api +/// @{ + +/// @addtogroup dnnl_graph_api_logical_tensor +/// @{ + +/// A wildcard value for number of dimensions which is unknown at a tensor or +/// operation creation time. +#define DNNL_GRAPH_UNKNOWN_NDIMS -1 + +/// A wildcard value for dimensions that are unknown at a tensor or operation +/// creation time. +#define DNNL_GRAPH_UNKNOWN_DIM INT64_MIN + +/// Layout type specification +typedef enum { + /// Undefined layout type + dnnl_graph_layout_type_undef = 0, + /// Any means to let the library to decide the layout for a tensor during + /// partition compilation. + dnnl_graph_layout_type_any = 1, + /// Strided means that the layout of a tensor is determined by the strides + /// field in the logical tensor. + dnnl_graph_layout_type_strided = 2, + /// Opaque means that the layout of a tensor is the library specific. + /// Usually, an opaque layout is generated by a partition which is compiled + /// with layout type any. + dnnl_graph_layout_type_opaque = 3, +} dnnl_graph_layout_type_t; + +/// Logical tensor property +typedef enum { + /// Undefined tensor property + dnnl_graph_tensor_property_undef = 0, + /// Variable means the tensor may be changed during computation or between + /// different iterations. + dnnl_graph_tensor_property_variable = 1, + /// Constant means the tensor will keep unchanged during computation and + /// between different iterations. 
It's useful for the library to apply + /// optimizations for constant tensors or cache constant tensors inside the + /// library. For example, constant weight tensors in inference scenarios. + dnnl_graph_tensor_property_constant = 2, +} dnnl_graph_tensor_property_t; + +/// Logical tensor. It is based on an ID, a number of dimensions, dimensions +/// themselves, element data type, tensor property and tensor memory layout. +typedef struct { + /// Unique id of each logical tensor. The library uses logical tensor IDs to + /// build up the connections between operations if the output of one + /// operation has the same ID as the input of another operation. + size_t id; + + /// Number of dimensions. -1 means unknown (DNNL_GRAPH_UNKNOWN_NDIMS). 0 is + /// used to define scalar tensor. + int ndims; + + /// Size of each dimension. #DNNL_GRAPH_UNKNOWN_DIM means the size of that + /// dimension is unknown. 0 is used to define zero-dimension tensor. The + /// library supports to deduce output shapes according to input shapes + /// during compilation. Unlike memory descriptor in oneDNN primitive API, + /// the order of dimensions is not defined in logical tensor. It is defined + /// by the operations which respect the order through the attributes + /// #dnnl_graph_op_attr_data_format or #dnnl_graph_op_attr_weights_format. + /// For example, for a Convolution with `data_format=NXC`, it means the + /// first element of dims of activation tensor is mini-batch size, the last + /// effective element of dims is channel size, and other elements between + /// them are spatial dimensions. + dnnl_dims_t dims; + + /// Data type of the tensor elements. + dnnl_data_type_t data_type; + + /// Property type of the tensor. + dnnl_graph_tensor_property_t property; + + /// Layout type of the tensor. + dnnl_graph_layout_type_t layout_type; + union { + /// The field is valid when `layout_type` is + /// #dnnl_graph_layout_type_strided. 
#DNNL_GRAPH_UNKNOWN_DIM means the + /// stride of the dimension is unknown. The library currently doesn't + /// support other negative stride values. + dnnl_dims_t strides; + + /// The field is valid when `layout_type` is + /// #dnnl_graph_layout_type_opaque. An opaque layout ID is usually + /// generated by a partition which is compiled with layout type any. + size_t layout_id; + } layout; +} dnnl_graph_logical_tensor_t; + +/// @} dnnl_graph_api_logical_tensor + +/// @addtogroup dnnl_graph_api_partition +/// @{ + +/// Policy specifications for partitioning +typedef enum { + /// Fusion policy returns partitions with typical post-op fusions, eg. + /// Convolution + ReLU or other element-wise operations or a chian of + /// post-ops. + dnnl_graph_partition_policy_fusion = 1, + /// Debug policy doesn't not apply any fusions. It returns partitions with + /// single operation in each partition. The policy is useful when users + /// notice any bug or correctness issue in fusion policy. + dnnl_graph_partition_policy_debug = 2, +} dnnl_graph_partition_policy_t; + +/// An opaque structure to describe a partition. +struct dnnl_graph_partition; + +/// A partition handle. +typedef struct dnnl_graph_partition *dnnl_graph_partition_t; + +/// A constant partition handle. +typedef const struct dnnl_graph_partition *const_dnnl_graph_partition_t; + +/// @} dnnl_graph_api_partition + +/// @addtogroup dnnl_graph_api_graph +/// @{ + +/// An opaque structure to describe a graph. +struct dnnl_graph_graph; + +/// A graph handle. +typedef struct dnnl_graph_graph *dnnl_graph_graph_t; + +/// A constant graph handle. 
+typedef const struct dnnl_graph_graph *const_dnnl_graph_graph_t; + +/// @} dnnl_graph_api_graph + +/// @addtogroup dnnl_graph_api_op +/// @{ + +/// Kinds of operations +typedef enum { + dnnl_graph_op_abs, + dnnl_graph_op_abs_backward, + dnnl_graph_op_add, + dnnl_graph_op_avg_pool, + dnnl_graph_op_avg_pool_backward, + dnnl_graph_op_batch_norm_backward, + dnnl_graph_op_batch_norm_forward_training, + dnnl_graph_op_batch_norm_inference, + dnnl_graph_op_bias_add, + dnnl_graph_op_bias_add_backward, + dnnl_graph_op_clamp, + dnnl_graph_op_clamp_backward, + dnnl_graph_op_concat, + dnnl_graph_op_convolution, + dnnl_graph_op_convolution_backward_data, + dnnl_graph_op_convolution_backward_weights, + dnnl_graph_op_conv_transpose, + dnnl_graph_op_conv_transpose_backward_data, + dnnl_graph_op_conv_transpose_backward_weights, + dnnl_graph_op_dequantize, + dnnl_graph_op_divide, + dnnl_graph_op_dynamic_dequantize, + dnnl_graph_op_dynamic_quantize, + dnnl_graph_op_elu, + dnnl_graph_op_elu_backward, + dnnl_graph_op_end, + dnnl_graph_op_exp, + dnnl_graph_op_gelu, + dnnl_graph_op_gelu_backward, + dnnl_graph_op_hard_swish, + dnnl_graph_op_hard_swish_backward, + dnnl_graph_op_interpolate, + dnnl_graph_op_interpolate_backward, + dnnl_graph_op_layer_norm, + dnnl_graph_op_layer_norm_backward, + dnnl_graph_op_leaky_relu, + dnnl_graph_op_log, + dnnl_graph_op_log_softmax, + dnnl_graph_op_log_softmax_backward, + dnnl_graph_op_matmul, + dnnl_graph_op_maximum, + dnnl_graph_op_max_pool, + dnnl_graph_op_max_pool_backward, + dnnl_graph_op_minimum, + dnnl_graph_op_mish, + dnnl_graph_op_mish_backward, + dnnl_graph_op_multiply, + dnnl_graph_op_prelu, + dnnl_graph_op_prelu_backward, + dnnl_graph_op_quantize, + dnnl_graph_op_reciprocal, + dnnl_graph_op_reduce_l1, + dnnl_graph_op_reduce_l2, + dnnl_graph_op_reduce_max, + dnnl_graph_op_reduce_mean, + dnnl_graph_op_reduce_min, + dnnl_graph_op_reduce_prod, + dnnl_graph_op_reduce_sum, + dnnl_graph_op_relu, + dnnl_graph_op_relu_backward, + 
dnnl_graph_op_reorder, + dnnl_graph_op_round, + dnnl_graph_op_sigmoid, + dnnl_graph_op_sigmoid_backward, + dnnl_graph_op_softmax, + dnnl_graph_op_softmax_backward, + dnnl_graph_op_softplus, + dnnl_graph_op_softplus_backward, + dnnl_graph_op_sqrt, + dnnl_graph_op_sqrt_backward, + dnnl_graph_op_square, + dnnl_graph_op_squared_difference, + dnnl_graph_op_static_reshape, + dnnl_graph_op_static_transpose, + dnnl_graph_op_subtract, + dnnl_graph_op_tanh, + dnnl_graph_op_tanh_backward, + dnnl_graph_op_type_cast, + dnnl_graph_op_wildcard, + dnnl_graph_op_hard_sigmoid, + dnnl_graph_op_hard_sigmoid_backward, + dnnl_graph_op_select, + dnnl_graph_op_pow, + dnnl_graph_op_group_norm, + dnnl_graph_op_gen_index, + dnnl_graph_op_greater_equal, + dnnl_graph_op_last_symbol, +} dnnl_graph_op_kind_t; + +/// Attributes of operations +typedef enum { + /// Undefined op attribute. + dnnl_graph_op_attr_undef = 0, + + // float32 attributes. The value of these attributes can be any single + // float32 number. + + /// Specifies an alpha attribute to an op. + dnnl_graph_op_attr_alpha = 0x1, + /// Specifies an beta attribute to an op. + dnnl_graph_op_attr_beta, + /// Specifies an epsilon attribute to an op. + dnnl_graph_op_attr_epsilon, + /// Specifies a max attribute to an op. + dnnl_graph_op_attr_max, + ///Specifies a min attribute to an op. + dnnl_graph_op_attr_min, + /// Specifies a momentum attribute to an op. + dnnl_graph_op_attr_momentum, + + // float32 vector attributes. The value of these attributes can be a vector + // of float32 numbers. + + /// Specifies a scales attribute to an op. + dnnl_graph_op_attr_scales = 0x20, + + // int64_t attributes. The value of these attributes can be any single int64 + // number. + + /// Specifies an axis attribute to an op. + dnnl_graph_op_attr_axis = 0x30, + /// Specifies a begin_norm_axis attribute to an op. + dnnl_graph_op_attr_begin_norm_axis, + /// Specifies a groups attribute to an op. + dnnl_graph_op_attr_groups, + + // int64_t vector attributes. 
The value of these attributes can be a vector + // of int64 numbers. + + /// Specifies an axes attribute to an op. + dnnl_graph_op_attr_axes = 0x40, + /// Specifies a dilations attribute to an op. + dnnl_graph_op_attr_dilations, + /// Specifies an dst_shape attribute to an op. + dnnl_graph_op_attr_dst_shape, + /// Specifies a kernel attribute to an op. + dnnl_graph_op_attr_kernel, + /// Specifies an order attribute to an op. + dnnl_graph_op_attr_order, + /// Specifies an output_padding attribute to an op. + dnnl_graph_op_attr_output_padding, + /// Specifies a pads_begin attribute to an op. + dnnl_graph_op_attr_pads_begin, + /// Specifies a pads_end attribute to an op. + dnnl_graph_op_attr_pads_end, + /// Specifies a shape attribute to an op. + dnnl_graph_op_attr_shape, + /// Specifies a sizes attribute to an op. + dnnl_graph_op_attr_sizes, + /// Specifies a input_shape attribute to an op. + dnnl_graph_op_attr_src_shape, + /// Specifies a strides attribute to an op. + dnnl_graph_op_attr_strides, + /// Specifies a weight_shape attribute to an op. + dnnl_graph_op_attr_weights_shape, + /// Specifies a zps attribute to an op. + dnnl_graph_op_attr_zps, + /// Specifies a group shape attribute to an op. + dnnl_graph_op_attr_group_shape, + + // bool attributes. The value of these attributes can be any single bool + // value. + + /// Specifies an exclude_pad attribute to an op. + dnnl_graph_op_attr_exclude_pad = 0x60, + /// Specifies a keep_dims attribute to an op. + dnnl_graph_op_attr_keep_dims, + /// Specifies a keep_stats attribute to an op. + dnnl_graph_op_attr_keep_stats, + /// Specifies a per_channel_broadcast attribute to an op. + dnnl_graph_op_attr_per_channel_broadcast, + /// Specifies a special_zero attribute to an op. + dnnl_graph_op_attr_special_zero, + /// Specifies a transpose_a attribute to an op. + dnnl_graph_op_attr_transpose_a, + /// Specifies a transpose_b attribute to an op. 
+ dnnl_graph_op_attr_transpose_b, + /// Specifies an use_affine attribute to an op. + dnnl_graph_op_attr_use_affine, + /// Specifies an use_dst attribute to an op. + dnnl_graph_op_attr_use_dst, + + // string attributes. The value of these attributes can be a string. + + /// Specifies an auto_broadcast attribute to an op. The value can be "none" + /// or "numpy". + dnnl_graph_op_attr_auto_broadcast = 0x80, + /// Specifies an auto_pad attribute to an op. The value can be "none", + /// "same_upper", "same_lower", or "valid". + dnnl_graph_op_attr_auto_pad, + /// Specifies an coordinate_transformation_mode attribute to an op. The + /// value can be "half_pixel" or "align_corners". The attribute is defined + /// for Interpolate operations. + dnnl_graph_op_attr_coordinate_transformation_mode, + /// Specifies a data_format of an op. The value can be "NCX" or "NXC". + dnnl_graph_op_attr_data_format, + /// Specifies a mode attribute of an op. The value can be "nearest", + /// "linear", "bilinear", or "trilinear". The attribute is defined for + /// Interpolate operations. + dnnl_graph_op_attr_mode, + /// Specifies a qtype attribute to an op. The value can be "per_channel" or + /// "per_tensor". The attribute is defined for quantization operations. + dnnl_graph_op_attr_qtype, + /// Specifies a rounding_type attribute to an op. The value can be "ceil" or + /// "floor". + dnnl_graph_op_attr_rounding_type, + /// Specifies a weights_format of an op. The value can be "OIX", "XIO", + /// "IOX", or "XOI". Different operations may support different values. + dnnl_graph_op_attr_weights_format, + + /// Specifies the end of all above exteral attributes for check. + dnnl_graph_op_attr_end = 0xFF, +} dnnl_graph_op_attr_t; + +/// An opaque structure to describe an operation. +struct dnnl_graph_op; + +/// An operation handle. +typedef struct dnnl_graph_op *dnnl_graph_op_t; + +/// A constant operation handle. 
+typedef const struct dnnl_graph_op *const_dnnl_graph_op_t; + +/// @} dnnl_graph_api_op + +/// @addtogroup dnnl_graph_api_allocator +/// @{ + +/// Allocation call-back function interface for host. For SYCL allocator, see +/// #dnnl_graph_sycl_allocate_f. +typedef void *(*dnnl_graph_host_allocate_f)(size_t size, size_t alignment); + +/// Deallocation call-back function interface for host. For SYCL allocator, see +/// #dnnl_graph_sycl_deallocate_f. +typedef void (*dnnl_graph_host_deallocate_f)(void *); + +/// An opaque structure to describe an allocator. +struct dnnl_graph_allocator; + +/// An allocator handle. +typedef struct dnnl_graph_allocator *dnnl_graph_allocator_t; + +/// A constant allocator handle. +typedef const struct dnnl_graph_allocator *const_dnnl_graph_allocator_t; + +/// @} dnnl_graph_api_allocator + +/// @addtogroup dnnl_graph_api_compiled_partition +/// @{ + +/// In-place pair definition. It can queried from a compiled partition +/// indicating that an input and an output of the partition can share the same +/// memory buffer for computation. In-place computation helps to reduce the +/// memory footprint and improves cache locality. But since the library may not +/// have a global view of user's application, it's possible that the tensor with +/// `input_id` is used at other places in user's computation graph. In this +/// case, the user should take the in-place pair as a hint and pass a different +/// memory buffer for output tensor to avoid overwriting the input memory buffer +/// which will probably cause unexpected incorrect results. +typedef struct { + /// The id of input tensor + size_t input_id; + + /// The id of output tensor + size_t output_id; +} dnnl_graph_inplace_pair_t; + +/// An opaque structure to describe a compiled partition. +struct dnnl_graph_compiled_partition; + +/// A compiled partition handle. +typedef struct dnnl_graph_compiled_partition *dnnl_graph_compiled_partition_t; + +/// A constant compiled partition handle. 
+typedef const struct dnnl_graph_compiled_partition + *const_dnnl_graph_compiled_partition_t; + +/// @} dnnl_graph_api_compiled_partition + +/// @addtogroup dnnl_graph_api_tensor +/// @{ + +/// An opaque structure to describe a tensor. +struct dnnl_graph_tensor; + +/// A tensor handle. +typedef struct dnnl_graph_tensor *dnnl_graph_tensor_t; + +/// A constant tensor handle. +typedef const struct dnnl_graph_tensor *const_dnnl_graph_tensor_t; + +/// @} dnnl_graph_api_tensor + +/// @} dnnl_graph_api + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.h new file mode 100644 index 0000000000000000000000000000000000000000..70d0c5460a0573ee6494c5db32af9e32d72479cb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.h @@ -0,0 +1,276 @@ +/******************************************************************************* +* Copyright 2020-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_OCL_H +#define ONEAPI_DNNL_DNNL_OCL_H + +#include "oneapi/dnnl/dnnl.h" + +#include "oneapi/dnnl/dnnl_ocl_types.h" + +/// @cond DO_NOT_DOCUMENT_THIS +// Set target version for OpenCL explicitly to suppress a compiler warning. 
+#ifndef CL_TARGET_OPENCL_VERSION +#define CL_TARGET_OPENCL_VERSION 120 +#endif + +#include +/// @endcond + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_api_interop +/// @{ + +/// @addtogroup dnnl_api_ocl_interop +/// @{ + +/// Creates a memory object. +/// +/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the +/// constructed memory object will have the underlying buffer set. In this +/// case, the buffer will be initialized as if: +/// - dnnl_memory_set_data_handle() has been called, if @p memory_kind is equal +/// to dnnl_ocl_interop_usm, or +/// - dnnl_ocl_interop_memory_set_mem_object() has been called, if @p memory_kind +/// is equal to dnnl_ocl_interop_buffer. +/// +/// @param memory Output memory object. +/// @param memory_desc Memory descriptor. +/// @param engine Engine to use. +/// @param memory_kind Memory allocation kind to specify the type of handle. +/// @param handle Handle of the memory buffer to use as an underlying storage. +/// - A USM pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. Requires @p memory_kind to be equal to +/// dnnl_ocl_interop_usm. +/// - An OpenCL buffer. In this case the library doesn't own the buffer. +/// Requires @p memory_kind be equal to be equal to dnnl_ocl_interop_buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer that corresponds to the memory allocation kind +/// @p memory_kind for the memory object. In this case the library +/// owns the buffer. +/// - The DNNL_MEMORY_NONE specific value. Instructs the library to +/// create memory object without an underlying buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_ocl_interop_memory_create(dnnl_memory_t *memory, + const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine, + dnnl_ocl_interop_memory_kind_t memory_kind, void *handle); + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Creates a memory object with multiple handles. +/// +/// @param memory Output memory object. +/// @param memory_desc Memory descriptor. +/// @param engine Engine to use. +/// @param memory_kind Memory allocation kind to specify the type of handles. +/// @param nhandles Number of handles. +/// @param handles Handles of the memory buffers to use as underlying storages. +/// For each element of the @p handles array the following applies: +/// - A USM pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. Requires @p memory_kind to be equal to +/// dnnl_ocl_interop_usm. +/// - An OpenCL buffer. In this case the library doesn't own the buffer. +/// Requires @p memory_kind be equal to be equal to dnnl_ocl_interop_buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer that corresponds to the memory allocation kind +/// @p memory_kind for the memory object. In this case the library +/// owns the buffer. +/// - The DNNL_MEMORY_NONE specific value. Instructs the library to +/// create memory object without an underlying buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_memory_create_v2(dnnl_memory_t *memory, + const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine, + dnnl_ocl_interop_memory_kind_t memory_kind, int nhandles, + void **handles); +#endif + +/// Returns the memory allocation kind associated with a memory object. +/// +/// @param memory Memory to query. +/// @param memory_kind Output underlying memory allocation kind of the memory +/// object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_ocl_interop_memory_get_memory_kind( + const_dnnl_memory_t memory, + dnnl_ocl_interop_memory_kind_t *memory_kind); + +/// Returns an OpenCL memory object associated with a memory object. +/// +/// @param memory Memory object. +/// @param mem_object Output OpenCL memory object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_memory_get_mem_object( + const_dnnl_memory_t memory, cl_mem *mem_object); + +/// Sets OpenCL memory object associated with a memory object. +/// +/// For behavioral details, see dnnl_memory_set_data_handle(). +/// +/// @param memory Memory object. +/// @param mem_object OpenCL memory object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_memory_set_mem_object( + dnnl_memory_t memory, cl_mem mem_object); + +/// Retrieves a cache blob ID for the OpenCL device. +/// +/// @warning +/// This API is intended to be used with +/// #dnnl_ocl_interop_engine_get_cache_blob() and +/// #dnnl_ocl_interop_engine_create_from_cache_blob(). The returned cache +/// blob ID can only be used as an ID of the cache blob returned by +/// #dnnl_ocl_interop_engine_get_cache_blob(). +/// +/// @note The cache blob ID can be empty (@p size will be 0 and +/// @p cache_blob_id will be nullptr) if oneDNN doesn't have anything to +/// put in the cache blob. (#dnnl_ocl_interop_engine_get_cache_blob will +/// return an empty cache blob). +/// +/// @param device An OpenCL device. +/// @param size Size of the cache blob ID in bytes. +/// @param cache_blob_id Cache blob id of size @p size. If +/// the @p cache_blob_id is nullptr then the size of the cache blob ID is +/// returned in @p size. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_ocl_interop_engine_get_cache_blob_id( + cl_device_id device, size_t *size, uint8_t *cache_blob_id); + +/// Retrieves a cache blob associated with the given engine. +/// +/// @note The cache blob can be empty (@p size will be 0 and @p cache_blob +/// will be nullptr) if oneDNN doesn't have anything to put in the cache +/// blob. It's the user's responsibility to check whether it's empty +/// prior to passing it to +/// #dnnl_ocl_interop_engine_create_from_cache_blob(). +/// +/// @param engine Engine to query for the cache blob. +/// @param size Size of the cache blob in bytes. +/// @param cache_blob Cache blob of size @p size. If the @p cache_blob is +/// nullptr then the size of the cache blob is returned in @p size. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_engine_get_cache_blob( + dnnl_engine_t engine, size_t *size, uint8_t *cache_blob); + +/// Creates an engine from the given cache blob. +/// +/// @param engine Output engine. +/// @param device The OpenCL device that this engine will encapsulate. +/// @param context The OpenCL context (containing the device) that this +/// engine will use for all operations. +/// @param size Size of the cache blob in bytes. +/// @param cache_blob Cache blob of size @p size. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_engine_create_from_cache_blob( + dnnl_engine_t *engine, cl_device_id device, cl_context context, + size_t size, const uint8_t *cache_blob); + +/// Creates an engine associated with an OpenCL device and an OpenCL context. +/// +/// @param engine Output engine. +/// @param device Underlying OpenCL device to use for the engine. +/// @param context Underlying OpenCL context to use for the engine. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_ocl_interop_engine_create( + dnnl_engine_t *engine, cl_device_id device, cl_context context); + +/// Returns the OpenCL context associated with an engine. +/// +/// @param engine Engine to query. +/// @param context Output underlying OpenCL context of the engine. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_engine_get_context( + dnnl_engine_t engine, cl_context *context); + +/// Returns the OpenCL device associated with an engine. +/// +/// @param engine Engine to query. +/// @param device Output underlying OpenCL device of the engine. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_get_device( + dnnl_engine_t engine, cl_device_id *device); + +/// Creates an execution stream for a given engine associated with +/// an OpenCL command queue. +/// +/// @param stream Output execution stream. +/// @param engine Engine to create the execution stream on. +/// @param queue OpenCL command queue to use. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_stream_create( + dnnl_stream_t *stream, dnnl_engine_t engine, cl_command_queue queue); + +/// Returns the OpenCL command queue associated with an execution stream. +/// +/// @param stream Execution stream to query. +/// @param queue Output OpenCL command queue. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_stream_get_command_queue( + dnnl_stream_t stream, cl_command_queue *queue); + +/// Executes computations specified by the primitive in a specified stream and +/// returns an OpenCL event. +/// +/// @param primitive Primitive to execute. +/// @param stream Stream to use. +/// @param nargs Number of arguments. +/// @param args Array of arguments. 
Each argument is an +/// pair. The index is one of the `DNNL_ARG_*` +/// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see +/// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory +/// descriptor as that returned by +/// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index). +/// @param deps A pointer to a vector of size @p ndeps that contains +/// dependencies. +/// @param ndeps Number of dependencies. +/// @param return_event Output event. It's the user's responsibility to +/// manage lifetime of the event. Can be NULL. When @p stream is in-order +/// NULL will be returned. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ocl_interop_primitive_execute( + const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs, + const dnnl_exec_arg_t *args, const cl_event *deps, int ndeps, + cl_event *return_event); + +/// @} dnnl_api_ocl_interop + +/// @} dnnl_api_interop + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.hpp b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..de3b4150b8ae3d4f41d998a73d219f69928d2063 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.hpp @@ -0,0 +1,445 @@ +/******************************************************************************* +* Copyright 2020-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef ONEAPI_DNNL_DNNL_OCL_HPP
+#define ONEAPI_DNNL_DNNL_OCL_HPP
+
+#include "oneapi/dnnl/dnnl.hpp"
+
+/// @cond DO_NOT_DOCUMENT_THIS
+#include <algorithm>
+#include <cstdlib>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl_ocl.h"
+
+#include <CL/cl.h>
+/// @endcond
+
+/// @addtogroup dnnl_api
+/// @{
+
+namespace dnnl {
+
+/// @addtogroup dnnl_api_interop Runtime interoperability API
+/// API extensions to interact with the underlying run-time.
+/// @{
+
+/// @addtogroup dnnl_api_ocl_interop OpenCL interoperability API
+/// API extensions to interact with the underlying OpenCL run-time.
+///
+/// @sa @ref dev_guide_opencl_interoperability in developer guide
+/// @{
+
+/// OpenCL interoperability namespace
+namespace ocl_interop {
+
+/// Memory allocation kind.
+enum class memory_kind {
+    /// USM (device, shared, host, or unknown) memory allocation kind.
+    usm = dnnl_ocl_interop_usm,
+    /// Buffer memory allocation kind - default.
+    buffer = dnnl_ocl_interop_buffer,
+};
+
+/// Converts a memory allocation kind enum value from C++ API to C API type.
+///
+/// @param akind C++ API memory allocation kind enum value.
+/// @returns Corresponding C API memory allocation kind enum value.
+inline dnnl_ocl_interop_memory_kind_t convert_to_c(memory_kind akind) {
+    return static_cast<dnnl_ocl_interop_memory_kind_t>(akind);
+}
+
+/// Returns the cache blob ID of the OpenCL device.
+/// +/// @warning +/// This API is intended to be used with +/// #dnnl::ocl_interop::get_engine_cache_blob() and +/// #dnnl::ocl_interop::make_engine(cl_device_id, cl_context, const std::vector &). +/// The returned cache blob ID can only be used as an ID of the cache blob +/// returned by #dnnl::ocl_interop::get_engine_cache_blob(). +/// +/// @note The cache blob ID can be empty (@p size will be 0 and +/// @p cache_blob_id will be nullptr) if oneDNN doesn't have anything to +/// put in the cache blob. (#dnnl_ocl_interop_engine_get_cache_blob will +/// return an empty cache blob). +/// +/// @param device An OpenCL device. +/// @returns A vector containing the cache blob ID. +inline std::vector get_engine_cache_blob_id(cl_device_id device) { + size_t size = 0; + error::wrap_c_api( + dnnl_ocl_interop_engine_get_cache_blob_id(device, &size, nullptr), + "could not get an engine cache blob id size"); + + std::vector cache_blob_id(size); + error::wrap_c_api(dnnl_ocl_interop_engine_get_cache_blob_id( + device, &size, cache_blob_id.data()), + "could not get an engine cache blob id"); + return cache_blob_id; +} + +/// Returns a cache blob for the engine. +/// +/// @note The cache blob vector can be empty if oneDNN doesn't have anything +/// to put in the cache blob. It's the user's responsibility to check +/// whether it's empty prior to passing it to +/// #dnnl::ocl_interop::make_engine(cl_device_id, cl_context, const std::vector &) +/// +/// @param aengine Engine to query for the cache blob. +/// @returns Vector containing the cache blob. 
+inline std::vector get_engine_cache_blob(const engine &aengine) { + size_t size = 0; + error::wrap_c_api(dnnl_ocl_interop_engine_get_cache_blob( + aengine.get(), &size, nullptr), + "could not get an engine cache blob size"); + + std::vector cache_blob(size); + error::wrap_c_api(dnnl_ocl_interop_engine_get_cache_blob( + aengine.get(), &size, cache_blob.data()), + "could not get an engine cache blob"); + return cache_blob; +} + +/// Constructs an engine from the given cache blob. +/// +/// @param device The OpenCL device that this engine will encapsulate. +/// @param context The OpenCL context (containing the device) that this +/// engine will use for all operations. +/// @param cache_blob Cache blob. +/// @returns An engine. +inline engine make_engine(cl_device_id device, cl_context context, + const std::vector &cache_blob) { + dnnl_engine_t c_engine; + error::wrap_c_api( + dnnl_ocl_interop_engine_create_from_cache_blob(&c_engine, device, + context, cache_blob.size(), cache_blob.data()), + "could not create an engine from cache blob"); + return engine(c_engine); +} + +/// Constructs an engine from OpenCL device and context objects. +/// +/// @param device The OpenCL device that this engine will encapsulate. +/// @param context The OpenCL context (containing the device) that this +/// engine will use for all operations. +/// @returns An engine. +inline engine make_engine(cl_device_id device, cl_context context) { + dnnl_engine_t c_engine; + error::wrap_c_api( + dnnl_ocl_interop_engine_create(&c_engine, device, context), + "could not create an engine"); + return engine(c_engine); +} + +/// Returns OpenCL context associated with the engine. +/// +/// @param aengine An engine. +/// @returns Underlying OpenCL context. 
+inline cl_context get_context(const engine &aengine) { + cl_context context = nullptr; + error::wrap_c_api( + dnnl_ocl_interop_engine_get_context(aengine.get(), &context), + "could not get an OpenCL context from an engine"); + return context; +} + +/// Returns OpenCL device associated with the engine. +/// +/// @param aengine An engine. +/// @returns Underlying OpenCL device. +inline cl_device_id get_device(const engine &aengine) { + cl_device_id device = nullptr; + error::wrap_c_api(dnnl_ocl_interop_get_device(aengine.get(), &device), + "could not get an OpenCL device from an engine"); + return device; +} + +/// Constructs an execution stream for the specified engine and OpenCL queue. +/// +/// @param aengine Engine to create the stream on. +/// @param queue OpenCL queue to use for the stream. +/// @returns An execution stream. +inline stream make_stream(const engine &aengine, cl_command_queue queue) { + dnnl_stream_t c_stream; + error::wrap_c_api( + dnnl_ocl_interop_stream_create(&c_stream, aengine.get(), queue), + "could not create a stream"); + return stream(c_stream); +} + +/// Returns OpenCL queue object associated with the execution stream. +/// +/// @param astream An execution stream. +/// @returns Underlying OpenCL queue. +inline cl_command_queue get_command_queue(const stream &astream) { + cl_command_queue queue = nullptr; + error::wrap_c_api( + dnnl_ocl_interop_stream_get_command_queue(astream.get(), &queue), + "could not get an OpenCL command queue from a stream"); + return queue; +} + +/// Returns the OpenCL memory object associated with the memory object. +/// +/// @param amemory A memory object. +/// @returns Underlying OpenCL memory object. 
+inline cl_mem get_mem_object(const memory &amemory) { + cl_mem mem_object; + error::wrap_c_api( + dnnl_ocl_interop_memory_get_mem_object(amemory.get(), &mem_object), + "could not get OpenCL buffer object from a memory object"); + return mem_object; +} + +/// Sets the OpenCL memory object associated with the memory object. +/// +/// For behavioral details see memory::set_data_handle(). +/// +/// @param amemory A memory object. +/// @param mem_object OpenCL cl_mem object to use as the underlying +/// storage. It must have at least get_desc().get_size() bytes +/// allocated. +inline void set_mem_object(memory &amemory, cl_mem mem_object) { + error::wrap_c_api( + dnnl_ocl_interop_memory_set_mem_object(amemory.get(), mem_object), + "could not set OpenCL buffer object from a memory object"); +} + +/// Returns the memory allocation kind associated with a memory object. +/// +/// @param amemory A memory object. +/// +/// @returns The underlying memory allocation kind of the memory object. +inline memory_kind get_memory_kind(const memory &amemory) { + dnnl_ocl_interop_memory_kind_t ckind; + error::wrap_c_api( + dnnl_ocl_interop_memory_get_memory_kind(amemory.get(), &ckind), + "could not get memory kind"); + return static_cast(ckind); +} + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Creates a memory object with multiple handles. +/// +/// @param memory_desc Memory descriptor. +/// @param aengine Engine to use. +/// @param kind Memory allocation kind to specify the type of handles. +/// @param handles Handles of the memory buffers to use as underlying storages. +/// For each element of the @p handles array the following applies: +/// - A USM pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. Requires @p memory_kind to be equal to +/// dnnl_ocl_interop_usm. +/// - An OpenCL buffer. In this case the library doesn't own the buffer. +/// Requires @p memory_kind be equal to be equal to dnnl_ocl_interop_buffer. 
+/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer that corresponds to the memory allocation kind +/// @p memory_kind for the memory object. In this case the library +/// owns the buffer. +/// - The DNNL_MEMORY_NONE specific value. Instructs the library to +/// create memory object without an underlying buffer. +/// +/// If the @p handles vector is not provided the library will allocate all +/// buffers as if all handles have the special value DNNL_MEMORY_ALLOCATE. +/// +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +inline memory make_memory(const memory::desc &memory_desc, + const engine &aengine, memory_kind kind, + std::vector handles = {}) { + if (handles.empty()) { + const int nhandles = memory_desc.get_num_handles(); + handles.resize(nhandles, DNNL_MEMORY_ALLOCATE); + } + + dnnl_memory_t c_memory; + error::wrap_c_api( + dnnl_ocl_interop_memory_create_v2(&c_memory, memory_desc.get(), + aengine.get(), convert_to_c(kind), (int)handles.size(), + handles.data()), + "could not create a memory"); + return memory(c_memory); +} + +/// Constructs a memory object with multiple OpenCL buffers. +/// +/// @param memory_desc Memory descriptor. +/// @param aengine Engine to use. +/// @param mem_objects A vector of OpenCL buffers to use. +/// +/// @returns Created memory object. +inline memory make_memory(const memory::desc &memory_desc, + const engine &aengine, std::vector mem_objects) { + const int nhandles = memory_desc.get_num_handles(); + std::vector handles(nhandles, DNNL_MEMORY_NONE); + memory amemory(memory_desc, aengine, handles); + for (int i = 0; i < nhandles; i++) + amemory.set_data_handle(mem_objects[i], i); + return amemory; +} + +/// Creates a memory object. +/// +/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the +/// constructed memory object will have the underlying buffer set. 
In this +/// case, the buffer will be initialized as if: +/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is +/// equal to dnnl::ocl_interop::memory_kind::usm, or +/// - dnnl::ocl_interop::set_mem_object() has been called, if @p memory_kind is +/// equal to dnnl::ocl_interop::memory_kind::buffer. +/// +/// @param memory_desc Memory descriptor. +/// @param aengine Engine to use. +/// @param kind Memory allocation kind to specify the type of handle. +/// @param handle Handle of the memory buffer to use as an underlying storage. +/// - A USM pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. Requires @p memory_kind to be equal to +/// dnnl::ocl_interop::memory_kind::usm. +/// - An OpenCL buffer. In this case the library doesn't own the buffer. +/// Requires @p memory_kind be equal to be equal to +/// dnnl::ocl_interop::memory_kind::buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer that corresponds to the memory allocation kind +/// @p memory_kind for the memory object. In this case the library +/// owns the buffer. +/// - The DNNL_MEMORY_NONE specific value. Instructs the library to +/// create memory object without an underlying buffer. +/// +/// @returns Created memory object. +inline memory make_memory(const memory::desc &memory_desc, + const engine &aengine, memory_kind kind, void *handle) { + return make_memory( + memory_desc, aengine, kind, std::vector {handle}); +} + +/// Constructs a memory object from an OpenCL buffer. +/// +/// @param memory_desc Memory descriptor. +/// @param aengine Engine to use. +/// @param mem_object An OpenCL buffer to use. +/// +/// @returns Created memory object. +inline memory make_memory(const memory::desc &memory_desc, + const engine &aengine, cl_mem mem_object) { + return make_memory(memory_desc, aengine, std::vector {mem_object}); +} +#else + +/// Creates a memory object. 
+///
+/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
+/// constructed memory object will have the underlying buffer set. In this
+/// case, the buffer will be initialized as if:
+/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is
+///   equal to dnnl::ocl_interop::memory_kind::usm, or
+/// - dnnl::ocl_interop::set_mem_object() has been called, if @p memory_kind is
+///   equal to dnnl::ocl_interop::memory_kind::buffer.
+///
+/// @param memory_desc Memory descriptor.
+/// @param aengine Engine to use.
+/// @param kind Memory allocation kind to specify the type of handle.
+/// @param handle Handle of the memory buffer to use as an underlying storage.
+///     - A USM pointer to the user-allocated buffer. In this case the library
+///       doesn't own the buffer. Requires @p memory_kind to be equal to
+///       dnnl::ocl_interop::memory_kind::usm.
+///     - An OpenCL buffer. In this case the library doesn't own the buffer.
+///       Requires @p memory_kind to be equal to
+///       dnnl::ocl_interop::memory_kind::buffer.
+///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
+///       allocate the buffer that corresponds to the memory allocation kind
+///       @p memory_kind for the memory object. In this case the library
+///       owns the buffer.
+///     - The DNNL_MEMORY_NONE specific value. Instructs the library to
+///       create memory object without an underlying buffer.
+///
+/// @returns Created memory object.
+inline memory make_memory(const memory::desc &memory_desc,
+        const engine &aengine, memory_kind kind,
+        void *handle = DNNL_MEMORY_ALLOCATE) {
+    dnnl_memory_t c_memory;
+    error::wrap_c_api(
+            dnnl_ocl_interop_memory_create(&c_memory, memory_desc.get(),
+                    aengine.get(), convert_to_c(kind), handle),
+            "could not create a memory");
+    return memory(c_memory);
+}
+
+/// Constructs a memory object from an OpenCL buffer.
+///
+/// @param memory_desc Memory descriptor.
+/// @param aengine Engine to use.
+/// @param mem_object An OpenCL buffer to use.
+///
+/// @returns Created memory object.
+inline memory make_memory(const memory::desc &memory_desc,
+        const engine &aengine, cl_mem mem_object) {
+    memory amemory(memory_desc, aengine, DNNL_MEMORY_NONE);
+    set_mem_object(amemory, mem_object);
+    return amemory;
+}
+#endif
+
+/// Executes computations specified by the primitive in a specified stream and
+/// returns an OpenCL event.
+///
+/// Arguments are passed via an arguments map containing
+/// <index, memory object> pairs. The index must be one of the `DNNL_ARG_*`
+/// values such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
+/// matching the one returned by
+/// #dnnl::primitive_desc::query_md(#query::exec_arg_md, index) unless using
+/// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
+///
+/// @param aprimitive Primitive to execute.
+/// @param astream Stream object. The stream must belong to the same engine
+///     as the primitive.
+/// @param args Arguments map.
+/// @param deps Optional vector with `cl_event` dependencies.
+///
+/// @returns Output event. It's the user's responsibility to manage lifetime
+///     of the event.
+inline cl_event execute(const dnnl::primitive &aprimitive,
+        const stream &astream, const std::unordered_map<int, memory> &args,
+        const std::vector<cl_event> &deps = {}) {
+    std::vector<dnnl_exec_arg_t> c_args;
+    c_args.reserve(args.size());
+    for (const auto &a : args)
+        c_args.push_back({a.first, a.second.get()});
+
+    const cl_event *c_deps = deps.empty() ?
nullptr : deps.data(); + + cl_event return_event; + error::wrap_c_api(dnnl_ocl_interop_primitive_execute(aprimitive.get(), + astream.get(), (int)c_args.size(), c_args.data(), + c_deps, (int)deps.size(), &return_event), + "could not execute a primitive"); + return return_event; +} + +} // namespace ocl_interop + +/// @} dnnl_api_ocl_interop + +/// @} dnnl_api_interop + +} // namespace dnnl + +/// @} dnnl_api + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ocl_types.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ocl_types.h new file mode 100644 index 0000000000000000000000000000000000000000..39a216650a6894152a5b3d7af84d157ca283a202 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ocl_types.h @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_OCL_TYPES_H +#define ONEAPI_DNNL_DNNL_OCL_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_api_interop +/// @{ + +/// @addtogroup dnnl_api_ocl_interop +/// @{ + +/// Memory allocation kind. +typedef enum { + /// USM (device, shared, host, or unknown) memory allocation kind. 
+ dnnl_ocl_interop_usm, + /// Buffer memory allocation kind - default. + dnnl_ocl_interop_buffer, +} dnnl_ocl_interop_memory_kind_t; + +/// @} dnnl_api_ocl_interop + +/// @} dnnl_api_interop + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.h new file mode 100644 index 0000000000000000000000000000000000000000..a4abe8518360fbe0473b6485d3c2a8180a276b97 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.h @@ -0,0 +1,199 @@ +/******************************************************************************* +* Copyright 2020-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_SYCL_H +#define ONEAPI_DNNL_DNNL_SYCL_H + +#include "oneapi/dnnl/dnnl.h" + +#include "oneapi/dnnl/dnnl_sycl_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_api_interop +/// @{ + +/// @addtogroup dnnl_api_sycl_interop +/// @{ + +/// Creates an engine associated with a SYCL device and a SYCL context. +/// +/// @param engine Output engine. +/// @param device Pointer to the SYCL device to use for the engine. +/// @param context Pointer to the SYCL context to use for the engine. 
+/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sycl_interop_engine_create( + dnnl_engine_t *engine, const void *device, const void *context); + +/// Returns the SYCL context associated with an engine. +/// +/// @param engine Engine to query. +/// @param context Pointer to the underlying SYCL context of the engine. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sycl_interop_engine_get_context( + dnnl_engine_t engine, void **context); + +/// Returns the SYCL device associated with an engine. +/// +/// @param engine Engine to query. +/// @param device Pointer to the underlying SYCL device of the engine. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sycl_interop_engine_get_device( + dnnl_engine_t engine, void **device); + +/// Creates a memory object. +/// +/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the +/// constructed memory object will have the underlying buffer set. In this +/// case, the buffer will be initialized as if: +/// - dnnl_memory_set_data_handle() had been called, if @p memory_kind is equal +/// to dnnl_sycl_interop_usm, or +/// - dnnl_sycl_interop_memory_set_buffer() has been called, if @p memory_kind +/// is equal to dnnl_sycl_interop_buffer. +/// +/// @param memory Output memory object. +/// @param memory_desc Memory descriptor. +/// @param engine Engine to use. +/// @param memory_kind Memory allocation kind to specify the type of handle. +/// @param handle Handle of the memory buffer to use as an underlying storage. +/// - A USM pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. Requires @p memory_kind to be equal to +/// dnnl_sycl_interop_usm. +/// - A pointer to SYCL buffer. In this case the library doesn't own the +/// buffer. 
Requires @p memory_kind be equal to be equal to +/// dnnl_sycl_interop_buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer that corresponds to the memory allocation kind +/// @p memory_kind for the memory object. In this case the library +/// owns the buffer. +/// - The DNNL_MEMORY_NONE specific value. Instructs the library to +/// create memory object without an underlying buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sycl_interop_memory_create(dnnl_memory_t *memory, + const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine, + dnnl_sycl_interop_memory_kind_t memory_kind, void *handle); + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Creates a memory object with multiple handles. +/// +/// @param memory Output memory object. +/// @param memory_desc Memory descriptor. +/// @param engine Engine to use. +/// @param memory_kind Memory allocation kind to specify the type of handles. +/// @param nhandles Number of handles. +/// @param handles Handles of the memory buffers to use as underlying storages. +/// For each element of the @p handles array the following applies: +/// - A USM pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. Requires @p memory_kind to be equal to +/// dnnl_sycl_interop_usm. +/// - A pointer to SYCL buffer. In this case the library doesn't own the +/// buffer. Requires @p memory_kind be equal to be equal to +/// dnnl_sycl_interop_buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer that corresponds to the memory allocation kind +/// @p memory_kind for the memory object. In this case the library +/// owns the buffer. +/// - The DNNL_MEMORY_NONE specific value. Instructs the library to +/// create memory object without an underlying buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_sycl_interop_memory_create_v2(dnnl_memory_t *memory, + const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine, + dnnl_sycl_interop_memory_kind_t memory_kind, int nhandles, + void **handles); +#endif + +/// Returns the memory allocation kind associated with a memory object. +/// +/// @param memory Memory to query. +/// @param memory_kind Output underlying memory allocation kind of the memory +/// object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sycl_interop_memory_get_memory_kind( + const_dnnl_memory_t memory, + dnnl_sycl_interop_memory_kind_t *memory_kind); + +/// Sets a SYCL buffer for a memory object. +/// +/// @param memory Memory object. +/// @param buffer SYCL buffer to be set in the memory object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sycl_interop_memory_set_buffer( + dnnl_memory_t memory, void *buffer); + +/// Creates an execution stream for a given engine associated with a SYCL +/// queue. +/// +/// @param stream Output execution stream. +/// @param engine Engine to create the execution stream on. +/// @param queue SYCL queue to use. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sycl_interop_stream_create( + dnnl_stream_t *stream, dnnl_engine_t engine, void *queue); + +/// Returns the SYCL queue associated with an execution stream. +/// +/// @param stream Execution stream to query. +/// @param queue Output SYCL command queue. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sycl_interop_stream_get_queue( + dnnl_stream_t stream, void **queue); + +/// Executes computations specified by the primitive in a specified stream and +/// returns a SYCL event. +/// +/// @param primitive Primitive to execute. +/// @param stream Stream to use. 
+/// @param nargs Number of arguments. +/// @param args Array of arguments. Each argument is an +/// pair. The index is one of the `DNNL_ARG_*` +/// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see +/// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory +/// descriptor as that returned by +/// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index). +/// @param deps A pointer to std::vector that contains +/// dependencies. +/// @param return_event Output event. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_sycl_interop_primitive_execute( + const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs, + const dnnl_exec_arg_t *args, const void *deps, void *return_event); + +/// @} dnnl_api_sycl_interop + +/// @} dnnl_api_interop + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.hpp b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1f7d8f559c122162176e7c43d6c4d52fea980486 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.hpp @@ -0,0 +1,384 @@ +/******************************************************************************* +* Copyright 2020-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_SYCL_HPP +#define ONEAPI_DNNL_DNNL_SYCL_HPP + +/// @cond DO_NOT_DOCUMENT_THIS +#include +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#error "Unsupported compiler" +#endif + +#include "oneapi/dnnl/dnnl.hpp" +#include "oneapi/dnnl/dnnl_sycl.h" + +/// @endcond + +/// @addtogroup dnnl_api +/// @{ + +namespace dnnl { + +/// @addtogroup dnnl_api_interop +/// @{ + +/// @addtogroup dnnl_api_sycl_interop SYCL interoperability API +/// API extensions to interact with the underlying SYCL run-time. +/// +/// @sa @ref dev_guide_dpcpp_interoperability in developer guide +/// @{ + +/// SYCL interoperability namespace +namespace sycl_interop { + +/// Memory allocation kind. +enum class memory_kind { + /// USM (device, shared, host, or unknown) memory allocation kind - default. + usm = dnnl_sycl_interop_usm, + /// Buffer memory allocation kind. + buffer = dnnl_sycl_interop_buffer, +}; + +/// Converts a memory allocation kind enum value from C++ API to C API type. +/// +/// @param akind C++ API memory allocation kind enum value. +/// @returns Corresponding C API memory allocation kind enum value. +inline dnnl_sycl_interop_memory_kind_t convert_to_c(memory_kind akind) { + return static_cast(akind); +} + +/// Constructs an engine from SYCL device and context objects. +/// +/// @param adevice SYCL device. +/// @param acontext SYCL context. +/// +/// @returns Created engine. +inline engine make_engine( + const sycl::device &adevice, const sycl::context &acontext) { + dnnl_engine_t aengine; + error::wrap_c_api(dnnl_sycl_interop_engine_create(&aengine, + static_cast(&adevice), + static_cast(&acontext)), + "could not create an engine"); + return engine(aengine); +} + +/// Returns the SYCL context associated with an engine. +/// +/// @param aengine Engine to query. +/// +/// @returns The underlying SYCL device of the engine. 
+inline sycl::context get_context(const engine &aengine) { + void *ctx_ptr; + error::wrap_c_api( + dnnl_sycl_interop_engine_get_context(aengine.get(), &ctx_ptr), + "could not get a context handle"); + auto ctx = *static_cast(ctx_ptr); + return ctx; +} + +/// Returns the SYCL device associated with an engine. +/// +/// @param aengine Engine to query. +/// +/// @returns The underlying SYCL context of the engine. +inline sycl::device get_device(const engine &aengine) { + void *dev_ptr; + error::wrap_c_api( + dnnl_sycl_interop_engine_get_device(aengine.get(), &dev_ptr), + "could not get a device handle"); + auto dev = *static_cast(dev_ptr); + return dev; +} + +/// Creates an execution stream for a given engine associated with a SYCL +/// queue. +/// +/// @param aengine Engine object to use for the stream. +/// @param aqueue SYCL queue to use for the stream. +/// +/// @returns An execution stream. +inline stream make_stream(const engine &aengine, sycl::queue &aqueue) { + dnnl_stream_t astream; + error::wrap_c_api( + dnnl_sycl_interop_stream_create(&astream, aengine.get(), &aqueue), + "could not create a stream"); + return stream(astream); +} + +/// Returns the SYCL queue associated with an execution stream. +/// +/// @param astream Execution stream to query. +/// +/// @returns SYCL queue object. +inline sycl::queue get_queue(const stream &astream) { + void *queue_ptr; + error::wrap_c_api( + dnnl_sycl_interop_stream_get_queue(astream.get(), &queue_ptr), + "could not get a stream handle"); + auto queue = *static_cast(queue_ptr); + return queue; +} + +/// Returns the SYCL buffer associated with a memory object. +/// +/// Throws an exception if the memory allocation kind associated with the +/// memory object is not equal to dnnl::sycl_interop::memory_kind::buffer. +/// +/// @tparam T Type of the requested buffer. +/// @tparam ndims Number of dimensions of the requested buffer. +/// @param amemory Memory object. 
+/// +/// @returns SYCL buffer associated with the memory object. +template +sycl::buffer get_buffer(const memory &amemory) { + static_assert(ndims == 1, "only 1D buffers supported"); + + // XXX: workaround: when CPU runtime is not SYCL and amemory was created + // for CPU engine `get_buffer` should return an error. Use interop API to + // implement the check. + dnnl_sycl_interop_memory_kind_t ckind; + error::wrap_c_api( + dnnl_sycl_interop_memory_get_memory_kind(amemory.get(), &ckind), + "could not get SYCL buffer object"); + + void *handle_ptr; + error::wrap_c_api(dnnl_memory_get_data_handle(amemory.get(), &handle_ptr), + "could not get SYCL buffer object"); + + // XXX: workaround: zero-range buffer cannot be constructed. + if (!handle_ptr) return sycl::buffer(sycl::range<1>(1)); + + auto &buf_u8 = *static_cast *>(handle_ptr); + + auto range = sycl::range<1>(buf_u8.byte_size() / sizeof(T)); + return buf_u8.reinterpret(range); +} + +/// Sets SYCL buffer associated with a memory object. +/// +/// @tparam T Type of the buffer. +/// @tparam ndims Number of dimensions of the buffer. +/// @param amemory Memory object to change. +/// @param abuffer SYCL buffer. +template +void set_buffer(memory &amemory, sycl::buffer &abuffer) { + auto range = sycl::range<1>(abuffer.byte_size()); + auto buf_u8 = abuffer.template reinterpret(range); + error::wrap_c_api(dnnl_sycl_interop_memory_set_buffer( + amemory.get(), static_cast(&buf_u8)), + "could not set SYCL buffer object"); +} + +/// Returns the memory allocation kind associated with a memory object. +/// +/// @param amemory A memory object. +/// +/// @returns The underlying memory allocation kind of the memory object. 
+inline memory_kind get_memory_kind(const memory &amemory) { + dnnl_sycl_interop_memory_kind_t ckind; + error::wrap_c_api( + dnnl_sycl_interop_memory_get_memory_kind(amemory.get(), &ckind), + "could not get memory kind"); + return static_cast(ckind); +} + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Creates a memory object with multiple handles. +/// +/// @param memory_desc Memory descriptor. +/// @param aengine Engine to use. +/// @param kind Memory allocation kind to specify the type of handles. +/// @param handles Handles of the memory buffers to use as underlying storages. +/// For each element of the @p handles array the following applies: +/// - A USM pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. Requires @p memory_kind to be equal to +/// dnnl::sycl_interop::memory_kind::usm. +/// - A pointer to SYCL buffer. In this case the library doesn't own the +/// buffer. Requires @p memory_kind be equal to be equal to +/// dnnl::sycl_interop::memory_kind::buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer that corresponds to the memory allocation kind +/// @p memory_kind for the memory object. In this case the library +/// owns the buffer. +/// - The DNNL_MEMORY_NONE specific value. Instructs the library to +/// create memory object without an underlying buffer. +/// +/// If the @p handles vector is not provided the library will allocate all +/// buffers as if all handles have the special value DNNL_MEMORY_ALLOCATE. +/// +/// @returns Created memory object. 
+inline memory make_memory(const memory::desc &memory_desc, + const engine &aengine, memory_kind kind, + std::vector handles = {}) { + if (handles.empty()) { + const int nhandles = memory_desc.get_num_handles(); + handles.resize(nhandles, DNNL_MEMORY_ALLOCATE); + } + + dnnl_memory_t c_memory; + error::wrap_c_api( + dnnl_sycl_interop_memory_create_v2(&c_memory, memory_desc.get(), + aengine.get(), convert_to_c(kind), (int)handles.size(), + handles.data()), + "could not create a memory"); + return memory(c_memory); +} + +/// Creates a memory object. +/// +/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the +/// constructed memory object will have the underlying buffer set. In this +/// case, the buffer will be initialized as if: +/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is +/// equal to dnnl::sycl_interop::memory_kind::usm, or +/// - dnnl::sycl_interop::set_buffer() has been called, if @p memory_kind is +/// equal to dnnl::sycl_interop::memory_kind::buffer. +/// +/// @param memory_desc Memory descriptor. +/// @param aengine Engine to use. +/// @param kind Memory allocation kind to specify the type of handle. +/// @param handle Handle of the memory buffer to use as an underlying storage. +/// - A USM pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. Requires @p memory_kind to be equal to +/// dnnl::sycl_interop::memory_kind::usm. +/// - A pointer to SYCL buffer. In this case the library doesn't own the +/// buffer. Requires @p memory_kind be equal to be equal to +/// dnnl::sycl_interop::memory_kind::buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer that corresponds to the memory allocation kind +/// @p memory_kind for the memory object. In this case the library +/// owns the buffer. +/// - The DNNL_MEMORY_NONE specific value. Instructs the library to +/// create memory object without an underlying buffer. 
+/// +/// @returns Created memory object. +inline memory make_memory(const memory::desc &memory_desc, + const engine &aengine, memory_kind kind, void *handle) { + return make_memory( + memory_desc, aengine, kind, std::vector {handle}); +} +#else + +/// Creates a memory object. +/// +/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the +/// constructed memory object will have the underlying buffer set. In this +/// case, the buffer will be initialized as if: +/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is +/// equal to dnnl::sycl_interop::memory_kind::usm, or +/// - dnnl::sycl_interop::set_buffer() has been called, if @p memory_kind is +/// equal to dnnl::sycl_interop::memory_kind::buffer. +/// +/// @param memory_desc Memory descriptor. +/// @param aengine Engine to use. +/// @param kind Memory allocation kind to specify the type of handle. +/// @param handle Handle of the memory buffer to use as an underlying storage. +/// - A USM pointer to the user-allocated buffer. In this case the library +/// doesn't own the buffer. Requires @p memory_kind to be equal to +/// dnnl::sycl_interop::memory_kind::usm. +/// - A pointer to SYCL buffer. In this case the library doesn't own the +/// buffer. Requires @p memory_kind be equal to be equal to +/// dnnl::sycl_interop::memory_kind::buffer. +/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to +/// allocate the buffer that corresponds to the memory allocation kind +/// @p memory_kind for the memory object. In this case the library +/// owns the buffer. +/// - The DNNL_MEMORY_NONE specific value. Instructs the library to +/// create memory object without an underlying buffer. +/// +/// @returns Created memory object. 
+inline memory make_memory(const memory::desc &memory_desc, + const engine &aengine, memory_kind kind, + void *handle = DNNL_MEMORY_ALLOCATE) { + dnnl_memory_t c_memory; + error::wrap_c_api( + dnnl_sycl_interop_memory_create(&c_memory, memory_desc.get(), + aengine.get(), convert_to_c(kind), handle), + "could not create a memory"); + return memory(c_memory); +} +#endif + +/// Constructs a memory object from a SYCL buffer. +/// +/// @param memory_desc Memory descriptor. +/// @param aengine Engine to use. +/// @param abuffer A SYCL buffer to use. +/// +/// @returns Created memory object. +template +memory make_memory(const memory::desc &memory_desc, const engine &aengine, + sycl::buffer &abuffer) { + memory amemory(memory_desc, aengine, DNNL_MEMORY_NONE); + set_buffer(amemory, abuffer); + return amemory; +} + +/// Executes computations specified by the primitive in a specified stream and +/// returns a SYCL event. +/// +/// Arguments are passed via an arguments map containing +/// pairs. The index must be one of the `DNNL_ARG_*` +/// values such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor +/// matching the one returned by +/// #dnnl::primitive_desc::query_md(#query::exec_arg_md, index) unless using +/// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL). +/// +/// @param aprimitive Primitive to execute. +/// @param astream Stream object. The stream must belong to the same engine +/// as the primitive. +/// @param args Arguments map. +/// @param deps Optional vector with `sycl::event` dependencies. +/// +/// @returns Output event. 
+inline sycl::event execute(const dnnl::primitive &aprimitive, + const stream &astream, const std::unordered_map &args, + const std::vector &deps = {}) { + std::vector c_args; + c_args.reserve(args.size()); + for (const auto &a : args) + c_args.push_back({a.first, a.second.get()}); + + sycl::event return_event; + error::wrap_c_api( + dnnl_sycl_interop_primitive_execute(aprimitive.get(), astream.get(), + (int)c_args.size(), c_args.data(), &deps, &return_event), + "could not execute a primitive"); + return return_event; +} + +} // namespace sycl_interop + +/// @} dnnl_api_sycl_interop + +/// @} dnnl_api_interop + +} // namespace dnnl + +/// @} dnnl_api + +#endif // DNNL_SYCL_HPP diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_sycl_types.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_sycl_types.h new file mode 100644 index 0000000000000000000000000000000000000000..8314075ef0007b1ea17d7826202232530085e9cb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_sycl_types.h @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_SYCL_TYPES_H +#define ONEAPI_DNNL_DNNL_SYCL_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_api_interop +/// @{ + +/// @addtogroup dnnl_api_sycl_interop +/// @{ + +/// Memory allocation kind. +typedef enum { + /// USM (device, shared, host, or unknown) memory allocation kind - default. + dnnl_sycl_interop_usm, + /// Buffer memory allocation kind. + dnnl_sycl_interop_buffer, +} dnnl_sycl_interop_memory_kind_t; + +/// @} dnnl_api_sycl_interop + +/// @} dnnl_api_interop + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.h new file mode 100644 index 0000000000000000000000000000000000000000..d6110930cfeaf51834e6a44a2444253de1df3874 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.h @@ -0,0 +1,118 @@ +/******************************************************************************* +* Copyright 2020-2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_THREADPOOL_H +#define ONEAPI_DNNL_DNNL_THREADPOOL_H + +#include "oneapi/dnnl/dnnl_config.h" +#include "oneapi/dnnl/dnnl_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_api_interop +/// @{ + +/// @addtogroup dnnl_api_threadpool_interop +/// @{ + +/// Creates an execution stream with specified threadpool. +/// +/// @sa @ref dev_guide_threadpool +/// +/// @param stream Output execution stream. +/// @param engine Engine to create the execution stream on. +/// @param threadpool Pointer to an instance of a C++ class that implements +/// dnnl::threapdool_iface interface. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_threadpool_interop_stream_create( + dnnl_stream_t *stream, dnnl_engine_t engine, void *threadpool); + +/// Returns a threadpool to be used by the execution stream. +/// +/// @sa @ref dev_guide_threadpool +/// +/// @param astream Execution stream. +/// @param threadpool Output pointer to an instance of a C++ class that +/// implements dnnl::threapdool_iface interface. Set to NULL if the +/// stream was created without threadpool. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_threadpool_interop_stream_get_threadpool( + dnnl_stream_t astream, void **threadpool); + +/// Sets the maximum concurrency assumed by oneDNN when outside a +/// parallel call. +/// +/// @param max_concurrency The maximum concurrency assumed by oneDNN +/// when outside a parallel call. This is a threadlocal setting. +/// @returns #dnnl_success on success and a status describing the +/// error otherwise. 
+dnnl_status_t DNNL_API dnnl_threadpool_interop_set_max_concurrency( + int max_concurrency); + +/// Gets the maximum concurrency assumed by oneDNN when outside a +/// parallel call. +/// +/// @param max_concurrency The maximum concurrency assumed by oneDNN +/// when outside a parallel call. This is a threadlocal setting. +/// @returns #dnnl_success on success and a status describing the +/// error otherwise. +dnnl_status_t DNNL_API dnnl_threadpool_interop_get_max_concurrency( + int *max_concurrency); + +/// @copydoc dnnl_sgemm() +/// @param threadpool A pointer to a threadpool interface (only when built with +/// the THREADPOOL CPU runtime). +dnnl_status_t DNNL_API dnnl_threadpool_interop_sgemm(char transa, char transb, + dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, + dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, + dnnl_dim_t ldc, void *threadpool); + +/// @copydoc dnnl_gemm_u8s8s32() +/// @param threadpool A pointer to a threadpool interface (only when built with +/// the THREADPOOL CPU runtime). +dnnl_status_t DNNL_API dnnl_threadpool_interop_gemm_u8s8s32(char transa, + char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, + float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao, + const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, + dnnl_dim_t ldc, const int32_t *co, void *threadpool); + +/// @copydoc dnnl_gemm_s8s8s32() +/// @param threadpool A pointer to a threadpool interface (only when built with +/// the THREADPOOL CPU runtime). 
+dnnl_status_t DNNL_API dnnl_threadpool_interop_gemm_s8s8s32(char transa, + char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, + float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao, + const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, + dnnl_dim_t ldc, const int32_t *co, void *threadpool); + +/// @} dnnl_api_threadpool_interop + +/// @} dnnl_api_interop + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.hpp b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.hpp new file mode 100644 index 0000000000000000000000000000000000000000..849465a540ca32c0356af80d514d24a743bc9fe2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.hpp @@ -0,0 +1,113 @@ +/******************************************************************************* +* Copyright 2020-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_THREADPOOL_HPP +#define ONEAPI_DNNL_DNNL_THREADPOOL_HPP + +#include "oneapi/dnnl/dnnl.hpp" +#include "oneapi/dnnl/dnnl_threadpool.h" + +#include "oneapi/dnnl/dnnl_threadpool_iface.hpp" + +/// @addtogroup dnnl_api +/// @{ + +namespace dnnl { + +/// @addtogroup dnnl_api_interop +/// @{ + +/// @addtogroup dnnl_api_threadpool_interop Threadpool interoperability API +/// API extensions to interact with the underlying Threadpool run-time. +/// @{ + +/// Threadpool interoperability namespace +namespace threadpool_interop { + +/// Constructs an execution stream for the specified engine and threadpool. +/// +/// @sa @ref dev_guide_threadpool +/// +/// @param aengine Engine to create the stream on. +/// @param threadpool Pointer to an instance of a C++ class that implements +/// dnnl::threapdool_iface interface. +/// @returns An execution stream. +inline dnnl::stream make_stream( + const dnnl::engine &aengine, threadpool_iface *threadpool) { + dnnl_stream_t c_stream; + dnnl::error::wrap_c_api(dnnl_threadpool_interop_stream_create( + &c_stream, aengine.get(), threadpool), + "could not create stream"); + return dnnl::stream(c_stream); +} + +/// Returns the pointer to a threadpool that is used by an execution stream. +/// +/// @sa @ref dev_guide_threadpool +/// +/// @param astream An execution stream. +/// @returns Output pointer to an instance of a C++ class that implements +/// dnnl::threapdool_iface interface or NULL if the stream was created +/// without threadpool. 
+inline threadpool_iface *get_threadpool(const dnnl::stream &astream) { + void *tp; + dnnl::error::wrap_c_api( + dnnl_threadpool_interop_stream_get_threadpool(astream.get(), &tp), + "could not get stream threadpool"); + return static_cast(tp); +} + +/// @copydoc dnnl_threadpool_interop_sgemm() +inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, + dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, + const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc, + threadpool_iface *threadpool) { + return static_cast(dnnl_threadpool_interop_sgemm(transa, transb, M, + N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool)); +} +/// @copydoc dnnl_threadpool_interop_gemm_u8s8s32() +inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, + dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, + dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, + float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co, + threadpool_iface *threadpool) { + return static_cast(dnnl_threadpool_interop_gemm_u8s8s32(transa, + transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, + ldc, co, threadpool)); +} + +/// @copydoc dnnl_threadpool_interop_gemm_s8s8s32() +inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, + dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, + dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, + float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co, + threadpool_iface *threadpool) { + return static_cast(dnnl_threadpool_interop_gemm_s8s8s32(transa, + transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, + ldc, co, threadpool)); +} + +} // namespace threadpool_interop + +/// @} dnnl_api_threadpool_interop + +/// @} dnnl_api_interop + +} // namespace dnnl + +/// @} dnnl_api + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool_iface.hpp 
b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool_iface.hpp new file mode 100644 index 0000000000000000000000000000000000000000..271d4db7f2251d360c92223a26509aa9fc76bff0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool_iface.hpp @@ -0,0 +1,73 @@ +/******************************************************************************* +* Copyright 2020-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_THREADPOOL_IFACE_HPP +#define ONEAPI_DNNL_DNNL_THREADPOOL_IFACE_HPP + +#include +#include + +/// @addtogroup dnnl_api +/// @{ + +namespace dnnl { + +/// @addtogroup dnnl_api_interop +/// @{ + +/// @addtogroup dnnl_api_threadpool_interop +/// @{ + +namespace threadpool_interop { + +/// Abstract threadpool interface. The users are expected to subclass this +/// interface and pass an object to the library during CPU stream creation or +/// directly in case of BLAS functions. +struct threadpool_iface { + /// Returns the number of worker threads. + virtual int get_num_threads() const = 0; + + /// Returns true if the calling thread belongs to this threadpool. 
+ virtual bool get_in_parallel() const = 0; + + /// Submits n instances of a closure for execution in parallel: + /// + /// for (int i = 0; i < n; i++) fn(i, n); + /// + virtual void parallel_for(int n, const std::function &fn) + = 0; + + /// Returns threadpool behavior flags bit mask (see below). + virtual uint64_t get_flags() const = 0; + + /// If set, parallel_for() returns immediately and oneDNN needs implement + /// waiting for the submitted closures to finish execution on its own. + static constexpr uint64_t ASYNCHRONOUS = 1; + + virtual ~threadpool_iface() {} +}; + +} // namespace threadpool_interop + +/// @} dnnl_api_threadpool_interop + +/// @} dnnl_api_interop + +} // namespace dnnl + +/// @} dnnl_api + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_types.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_types.h new file mode 100644 index 0000000000000000000000000000000000000000..4c8ac547f93e4208fcf42703d02c133b237ddbaa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_types.h @@ -0,0 +1,2936 @@ +/******************************************************************************* +* Copyright 2016-2025 Intel Corporation +* Copyright 2024 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/// @file +/// C API types definitions + +#ifndef ONEAPI_DNNL_DNNL_TYPES_H +#define ONEAPI_DNNL_DNNL_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +/// @cond DO_NOT_DOCUMENT_THIS +#include +#include +/// @endcond + +#include "oneapi/dnnl/dnnl_config.h" + +#include "oneapi/dnnl/dnnl_common_types.h" + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_api_memory +/// @{ + +/// Memory format kind +typedef enum { + /// Undefined memory format kind, used for empty memory descriptors. + dnnl_format_kind_undef = 0, + /// A special format kind that indicates that the actual format will be + /// selected by a primitive automatically. + dnnl_format_kind_any, + /// A tensor in a generic format described by the stride and blocking + /// values in each dimension. + dnnl_blocked, + /// A special format kind that indicates that tensor format is opaque. + dnnl_format_kind_opaque, +#ifdef DNNL_EXPERIMENTAL_SPARSE + /// Format kind for sparse tensors. + dnnl_format_kind_sparse, +#endif + /// Parameter to allow internal only format kinds without undefined + /// behavior. This parameter is chosen to be valid for so long as + /// sizeof(int) >= 2. + dnnl_format_kind_max = 0x7fff, +} dnnl_format_kind_t; + +#ifdef DNNL_EXPERIMENTAL_SPARSE +/// Sparse encodings. +typedef enum { + /// Undefined sparse encoding kind, used for empty memory descriptors. + dnnl_sparse_encoding_undef = 0, + /// Compressed Sparse Row (CSR) encoding. + dnnl_csr, + /// An encoding that is used for an opaque storage schema for + /// tensors with unstructured sparsity. A memory descriptor with the + /// packed encoding cannot be used to create a memory object. It can + /// only be used to create a primitive descriptor to query the + /// actual memory descriptor (similar to the format tag `any`). + dnnl_packed, + /// Coordinate Sparse Encoding (COO). 
+ dnnl_coo, +} dnnl_sparse_encoding_t; +#endif + +#ifdef DNNL_EXPERIMENTAL_PROFILING +/// Profiling data kind. +typedef enum { + /// Undefined profiling data kind. + dnnl_profiling_data_kind_undef = 0, + /// Data kind to query an execution time in nanoseconds. + dnnl_profiling_data_kind_time, +} dnnl_profiling_data_kind_t; + +#endif + +/// Memory format tag specification. +/// +/// oneDNN formats describe physical data layout. The physical layout +/// is described as a sequence of the dimensions as they are laid out in the +/// memory (from the outer-most to the inner-most). Note that this order +/// doesn't affect the logical order of the dimensions that is kept in the +/// `dims` field of the dnnl_memory_desc_t structure. The logical order of the +/// dimensions is specified by the primitive that uses the tensor. +/// +/// For example, CNN 5D tensor always has its logical dimensions in the order +/// `(batch, channels, depth, height, width)`, while the physical layout might be +/// `NCDHW` (corresponds to #dnnl_ncdhw format tag) or +/// `NDHWC` (corresponds to #dnnl_ndhwc format tag). +/// +/// ~~~cpp +/// int batch = 2, channels = 16, depth = 13, height = 13, width = 13; +/// +/// int ndims = 5; // 5D tensor +/// dnnl_dims_t dims = {batch, channels, depth, height, width}; +/// dnnl_memory_desc_t data_in_ncdhw; +/// dnnl_memory_desc_create_with_tag( +/// &data_in_ncdhw, 5, dims, dnnl_f32, dnnl_ncdhw); +/// +/// // note that in both cases dims passed are the same +/// dnnl_memory_desc_t data_in_ndhwc; +/// dnnl_memory_desc_create_with_tag( +/// &data_in_ndhwc, 5, dims, dnnl_f32, dnnl_ndhwc); +/// +/// dnnl_memory_desc_destroy(data_in_ncdhw); +/// dnnl_memory_desc_destroy(data_in_ndhwc); +/// ~~~ +/// +/// Memory format tags can be further divided into two categories: +/// - Domain-agnostic names, i.e. names the do not depend on the tensor usage +/// in the specific primitive. 
These names use letters from `a` to `l` to +/// denote logical dimension from 1 to 12, and form the order in which the +/// dimensions are laid in memory. For instance, #dnnl_ab is used to denote +/// 2D tensor where the second logical dimension (aka `b`) is the innermost, +/// i.e. has stride = 1, and the first logical dimension (`a`) laid out in +/// memory with stride equal to the size of second dimension. On the other +/// hand, #dnnl_ba is just transposed version of the same tensor: the +/// first dimension (`a`) becomes the innermost one. +/// - Domain-specific names, i.e. names that make sense only in the context of +/// a certain domain, such as CNN. This names are just aliases to the +/// corresponding domain-agnostic tags and used mostly for the convenience. +/// For example, #dnnl_nc is used to denote 2D CNN activations tensor +/// memory format, where channels are the innermost dimension and batch is an +/// outermost one. Moreover, #dnnl_nc is just an alias to #dnnl_ab, +/// since for oneDNN CNN primitives the logical dimensions of +/// activations tensors come in order: batch, channels, spatial. +/// In other words, batch corresponds to the first logical dimension (`a`), +/// channels correspond to the second one (`b`). +/// +/// The following domain-specific notation applies to memory format tags: +/// - @c 'n' denotes the mini-batch dimension +/// - @c 'c' denotes a channels dimension +/// - When there are multiple channel dimensions (for example, in convolution +/// weights tensor), @c 'i' and @c 'o' denote dimensions of input and output +/// channels +/// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width +/// respectively +/// +/// Upper-case letters indicate that the data is laid out in blocks for a +/// particular dimension. In such cases, the format name contains both upper- +/// and lower-case letters for that dimension with a lower-case letter preceded +/// by the block size. 
For example: #dnnl_nChw8c describes a format where the +/// outermost dimension is mini-batch, followed by the channel block number, +/// followed by the spatial height and width, and finally followed by 8-element +/// channel blocks. +/// +/// @sa @ref dev_guide_understanding_memory_formats +typedef enum { + /// Undefined memory format tag + dnnl_format_tag_undef = 0, + /// Undefined memory format tag. + /// The primitive selects a format automatically. + dnnl_format_tag_any, + + // Semantic agnostic section + // The physical order of dimensions is defined by the permutation of the + // characters, assuming that ab..z defines the natural order. + + // Plain formats + + dnnl_a, ///< plain 1D tensor + dnnl_ab, ///< plain 2D tensor + dnnl_abc, ///< plain 3D tensor + dnnl_abcd, ///< plain 4D tensor + dnnl_abcde, ///< plain 5D tensor + dnnl_abcdef, ///< plain 6D tensor + dnnl_abcdefg, ///< plain 7D tensor + dnnl_abcdefgh, ///< plain 8D tensor + dnnl_abcdefghi, ///< plain 9D tensor + dnnl_abcdefghij, ///< plain 10D tensor + dnnl_abcdefghijk, ///< plain 11D tensor + dnnl_abcdefghijkl, ///< plain 12D tensor + + // Permuted plain formats + + dnnl_ba, ///< permuted 2D tensor + dnnl_acb, ///< permuted 3D tensor + dnnl_bac, ///< permuted 3D tensor + dnnl_bca, ///< permuted 3D tensor + dnnl_cab, ///< permuted 3D tensor + dnnl_cba, ///< permuted 3D tensor + dnnl_abdc, ///< permuted 4D tensor + dnnl_acbd, ///< permuted 4D tensor + dnnl_acdb, ///< permuted 4D tensor + dnnl_adbc, ///< permuted 4D tensor + dnnl_adcb, ///< permuted 4D tensor + dnnl_bacd, ///< permuted 4D tensor + dnnl_bcda, ///< permuted 4D tensor + dnnl_cdab, ///< permuted 4D tensor + dnnl_cdba, ///< permuted 4D tensor + dnnl_dcab, ///< permuted 4D tensor + dnnl_abced, ///< permuted 5D tensor + dnnl_abdec, ///< permuted 5D tensor + dnnl_acbde, ///< permuted 5D tensor + dnnl_acdeb, ///< permuted 5D tensor + dnnl_adecb, ///< permuted 5D tensor + dnnl_bacde, ///< permuted 5D tensor + dnnl_bcdea, ///< permuted 5D 
tensor + dnnl_cdeab, ///< permuted 5D tensor + dnnl_cdeba, ///< permuted 5D tensor + dnnl_decab, ///< permuted 5D tensor + dnnl_abcdfe, ///< permuted 6D tensor + dnnl_abdefc, ///< permuted 6D tensor + dnnl_abdfce, ///< permuted 6D tensor + dnnl_acbdef, ///< permuted 6D tensor + dnnl_adefcb, ///< permuted 6D tensor + dnnl_defcab, ///< permuted 6D tensor + dnnl_abcdegf, ///< permuted 7D tensor + dnnl_abcdefhg, ///< permuted 8D tensor + dnnl_abcdefgih, ///< permuted 9D tensor + dnnl_abcdefghji, ///< permuted 10D tensor + dnnl_abcdefghikj, ///< permuted 11D tensor + dnnl_abcdefghijlk, ///< permuted 12D tensor + + // Opaque blocked formats + + dnnl_Abc16a, + dnnl_ABc16a16b, + dnnl_ABc32a32b, + dnnl_ABc4a4b, + /// 3D tensor blocked by 2nd dimension with block size 16 + dnnl_aBc16b, + dnnl_ABc16b16a, + dnnl_Abc4a, + /// 3D tensor blocked by 2nd dimension with block size 32 + dnnl_aBc32b, + /// 3D tensor blocked by 2nd dimension with block size 4 + dnnl_aBc4b, + dnnl_ABc4b16a4b, + dnnl_ABc2b8a4b, + dnnl_ABc16b16a4b, + dnnl_ABc16b16a2b, + dnnl_ABc4b4a, + dnnl_ABc8a16b2a, + dnnl_ABc8a8b, + dnnl_ABc8a4b, + /// 3D tensor blocked by 2nd dimension with block size 8 + dnnl_aBc8b, + dnnl_ABc8b16a2b, + dnnl_BAc8a16b2a, + dnnl_ABc8b8a, + dnnl_Abcd16a, + dnnl_Abcd8a, + dnnl_ABcd16a16b, + dnnl_Abcd32a, + dnnl_ABcd32a32b, + /// 4D tensor blocked by 2nd dimension with block size 16 + dnnl_aBcd16b, + dnnl_ABcd16b16a, + dnnl_aBCd16b16c, + dnnl_aBCd16c16b, + dnnl_Abcd4a, + /// 4D tensor blocked by 2nd dimension with block size 32 + dnnl_aBcd32b, + /// 4D tensor blocked by 2nd dimension with block size 4 + dnnl_aBcd4b, + dnnl_ABcd4b16a4b, + dnnl_ABcd16b16a4b, + dnnl_ABcd16b16a2b, + dnnl_ABcd4b4a, + dnnl_ABcd4a4b, + dnnl_aBCd2c4b2c, + dnnl_aBCd4b8c2b, + dnnl_aBCd4c16b4c, + dnnl_aBCd2c8b4c, + dnnl_aBCd16c16b4c, + dnnl_aBCd16c16b2c, + dnnl_aBCd4c4b, + dnnl_aBCd4b4c, + dnnl_ABcd8a16b2a, + dnnl_ABcd2b8a4b, + dnnl_ABcd8a8b, + dnnl_ABcd8a4b, + /// 4D tensor blocked by 2nd dimension with block size 
8 + dnnl_aBcd8b, + dnnl_aBCd4c8b2c, + dnnl_ABcd8b16a2b, + dnnl_aBCd8b16c2b, + dnnl_BAcd8a16b2a, + /// 4D tensor blocked by 1st and 2nd dimension with block size 8 + dnnl_ABcd8b8a, + dnnl_aBCd8b8c, + dnnl_aBCd8b4c, + dnnl_aBCd8c16b2c, + dnnl_ABcde8a16b2a, + dnnl_aCBd8b16c2b, + dnnl_aBCd8c8b, + dnnl_Abcde16a, + dnnl_Abcde32a, + dnnl_ABcde16a16b, + dnnl_BAcde8a16b2a, + /// 4D tensor blocked by 3rd dimension with block size 4 + dnnl_aBCd2b4c2b, + /// 5D tensor blocked by 1st dimension with block size 16 + dnnl_ABcde4b16a4b, + /// 5D tensor blocked by 1st dimension with block size 8 + dnnl_ABcde2b8a4b, + /// 5D tensor blocked by 2nd dimension with block size 16 + dnnl_aBcde16b, + dnnl_ABcde16b16a, + dnnl_aBCde16b16c, + dnnl_aBCde16c16b, + dnnl_aBCde2c8b4c, + dnnl_Abcde4a, + /// 5D tensor blocked by 2nd dimension with block size 32 + dnnl_aBcde32b, + /// 5D tensor blocked by 2nd dimension with block size 4 + dnnl_aBcde4b, + dnnl_ABcde4b4a, + dnnl_ABcde4a4b, + dnnl_aBCde4b4c, + dnnl_aBCde2c4b2c, + dnnl_aBCde4b8c2b, + dnnl_aBCde4c16b4c, + dnnl_aBCde16c16b4c, + dnnl_aBCde16c16b2c, + dnnl_aBCde4c4b, + dnnl_Abcde8a, + dnnl_ABcde8a8b, + dnnl_ABcde8a4b, + dnnl_BAcde16b16a, + /// 5D tensor blocked by 2nd dimension with block size 8 + dnnl_aBcde8b, + dnnl_ABcde8b16a2b, + dnnl_aBCde8b16c2b, + dnnl_aBCde4c8b2c, + dnnl_aCBde8b16c2b, + dnnl_ABcde8b8a, + dnnl_ABcde32a32b, + dnnl_aBCde8b8c, + dnnl_aBCde8b4c, + dnnl_ABc4a8b8a4b, + dnnl_ABcd4a8b8a4b, + dnnl_ABcde4a8b8a4b, + dnnl_BAc4b8a8b4a, + dnnl_BAcd4b8a8b4a, + dnnl_BAcde4b8a8b4a, + dnnl_ABcd2a8b8a2b, + dnnl_aBCd4b8c8b4c, + dnnl_aBCde4b8c8b4c, + dnnl_aBCde2b8c8b2c, + dnnl_aBCde8c16b2c, + dnnl_aBCde8c8b, + /// 5D tensor blocked by 3rd dimension with block size 4 + dnnl_aBCde2b4c2b, + /// 6D tensor blocked by 2nd dimension with block size 16 + dnnl_aBcdef16b, + dnnl_aBCdef16b16c, + dnnl_aBCdef16c16b, + dnnl_aBCdef4c16b4c, + /// 6D tensor blocked by 2nd dimension with block size 8 + dnnl_aBCdef2c8b4c, + dnnl_aBCdef4c8b2c, + /// 6D tensor 
blocked by 3rd dimension with block size 4 + dnnl_aBCdef2b4c2b, + /// 6D tensor blocked by 2nd dimension with block size 4 + dnnl_aBcdef4b, + dnnl_aBCdef4c4b, + dnnl_aBCdef4b4c, + dnnl_aBCdef2c4b2c, + dnnl_aBCdef4b8c2b, + dnnl_aBCdef8b8c, + dnnl_aBCdef8b4c, + dnnl_aBCdef8c16b2c, + dnnl_aBCdef4b8c8b4c, + dnnl_aBCdef8b16c2b, + dnnl_aCBdef8b16c2b, + dnnl_aBCdef8c8b, + dnnl_aBdc16b, + dnnl_aBdC16b2c, + dnnl_aBdC16b4c, + dnnl_aBdc4b, + dnnl_aBdc8b, + dnnl_aBdec16b, + dnnl_aBdeC16b2c, + dnnl_aBdeC16b4c, + dnnl_aBdec32b, + dnnl_aBdec4b, + dnnl_aBdec8b, + dnnl_aBdefc16b, + dnnl_aBdefC16b2c, + dnnl_aCBdef16c16b, + dnnl_aBdefc4b, + dnnl_aBdefc8b, + dnnl_Abcdef16a, + dnnl_Abcdef32a, + dnnl_aBedc16b, + dnnl_Acb16a, + dnnl_AcB16a2b, + dnnl_AcB16a4b, + dnnl_Acb4a, + dnnl_Acb8a, + dnnl_aCBd16b16c, + dnnl_aCBd16c16b, + dnnl_aCBde16b16c, + dnnl_aCBde16c16b, + dnnl_Acdb16a, + dnnl_AcdB16a2b, + dnnl_AcdB16a4b, + dnnl_Acdb32a, + dnnl_Acdb4a, + dnnl_Acdb8a, + dnnl_Acdeb16a, + dnnl_AcdeB16a2b, + dnnl_Acdeb4a, + dnnl_Acdeb8a, + dnnl_Adcb16a, + dnnl_BAc16a16b, + dnnl_BAc16b16a, + dnnl_BAcd16a16b, + dnnl_BAcd16b16a, + dnnl_aCBd4c8b8c4b, + dnnl_aCBde4c8b8c4b, + dnnl_aCBdef4c8b8c4b, + dnnl_BAcde16a16b, + dnnl_aCBdef16b16c, + dnnl_ABc16b32a, + dnnl_ABc16b64a, + dnnl_ABc4b32a4b, + dnnl_ABc4b64a4b, + dnnl_ABc8b32a2b, + dnnl_ABc8b64a2b, + dnnl_AB16b16a, + dnnl_AB16b32a, + dnnl_AB16b64a, + dnnl_AB8b16a2b, + dnnl_AB8b32a2b, + dnnl_AB8b64a2b, + dnnl_AB4b16a4b, + dnnl_AB4b32a4b, + dnnl_AB4b64a4b, + dnnl_AB16b16a4b, + dnnl_ABcd16b32a, + dnnl_ABcd16b64a, + dnnl_ABcd4b32a4b, + dnnl_ABcd4b64a4b, + dnnl_ABcd8b32a2b, + dnnl_ABcd8b64a2b, + dnnl_ABcde4b32a4b, + dnnl_ABcde4b64a4b, + dnnl_ABcde16b16a4b, + dnnl_ABcde16b16a2b, + dnnl_ABcde16b32a, + dnnl_ABcde16b64a, + dnnl_ABcde8b32a2b, + dnnl_ABcde8b64a2b, + dnnl_aBCdef16c16b4c, + dnnl_aBCdef16c16b2c, + dnnl_AB32a32b8a4b, + dnnl_AB8a4b, + dnnl_AB32a32b8a2b, + dnnl_AB8a2b, + dnnl_abDc32d, + dnnl_abDC32d4c, + dnnl_abdEc32e, + dnnl_abdEC32e2c, + dnnl_abdEC32e4c, 
+ dnnl_aBdefC16b4c, + dnnl_AcdeB16a4b, + dnnl_ABcd16a16b2a, + dnnl_ABc16a16b2a, + dnnl_aBCd16b16c2b, + dnnl_aBCde16b16c2b, + dnnl_Acb32a, + dnnl_AcB32a2b, + dnnl_AcB32a4b, + dnnl_Acb48a, + dnnl_AcB48a2b, + dnnl_AcB48a4b, + dnnl_Acb64a, + dnnl_AcB64a2b, + dnnl_AcB64a4b, + dnnl_cBa2b, + dnnl_cBa4b, + dnnl_aBdc32b, + dnnl_aBdC32b2c, + dnnl_aBdC32b4c, + dnnl_aBdc48b, + dnnl_aBdC48b2c, + dnnl_aBdC48b4c, + dnnl_aBdc64b, + dnnl_aBdC64b2c, + dnnl_aBdC64b4c, + dnnl_adCb2c, + dnnl_adCb4c, + dnnl_AcdB32a2b, + dnnl_AcdB32a4b, + dnnl_Acdb48a, + dnnl_AcdB48a2b, + dnnl_AcdB48a4b, + dnnl_Acdb64a, + dnnl_AcdB64a2b, + dnnl_AcdB64a4b, + dnnl_cdBa2b, + dnnl_cdBa4b, + dnnl_aBdeC32b2c, + dnnl_aBdeC32b4c, + dnnl_aBdec48b, + dnnl_aBdeC48b2c, + dnnl_aBdeC48b4c, + dnnl_aBdec64b, + dnnl_aBdeC64b2c, + dnnl_aBdeC64b4c, + dnnl_adeCb2c, + dnnl_adeCb4c, + dnnl_Acdeb32a, + dnnl_AcdeB32a2b, + dnnl_AcdeB32a4b, + dnnl_Acdeb48a, + dnnl_AcdeB48a2b, + dnnl_AcdeB48a4b, + dnnl_Acdeb64a, + dnnl_AcdeB64a2b, + dnnl_AcdeB64a4b, + dnnl_cdeBa2b, + dnnl_cdeBa4b, + dnnl_aBdefc32b, + dnnl_aBdefC32b2c, + dnnl_aBdefC32b4c, + dnnl_aBdefc48b, + dnnl_aBdefC48b2c, + dnnl_aBdefC48b4c, + dnnl_aBdefc64b, + dnnl_aBdefC64b2c, + dnnl_aBdefC64b4c, + dnnl_adefCb2c, + dnnl_adefCb4c, + dnnl_AB16b32a4b, + dnnl_AB16b48a4b, + dnnl_AB16b64a4b, + dnnl_AB16b16a2b, + dnnl_AB16b32a2b, + dnnl_AB16b48a2b, + dnnl_AB16b64a2b, + dnnl_ABc16b32a4b, + dnnl_ABc16b48a4b, + dnnl_ABc16b64a4b, + dnnl_ABc16b32a2b, + dnnl_ABc16b48a2b, + dnnl_ABc16b64a2b, + dnnl_ABcd16b32a4b, + dnnl_ABcd16b48a4b, + dnnl_ABcd16b64a4b, + dnnl_ABcd16b32a2b, + dnnl_ABcd16b48a2b, + dnnl_ABcd16b64a2b, + dnnl_ABcde16b32a4b, + dnnl_ABcde16b48a4b, + dnnl_ABcde16b64a4b, + dnnl_ABcde16b32a2b, + dnnl_ABcde16b48a2b, + dnnl_ABcde16b64a2b, + dnnl_ABc32a16b, + dnnl_ABcd32a16b, + dnnl_ABcde32a16b, + dnnl_AB48a16b, + dnnl_AB48a32b, + dnnl_ABc40a16b, + dnnl_ABc40a32b, + dnnl_aBC48b16c, + dnnl_aBC48b32c, + dnnl_ABcd40a16b, + dnnl_ABcd40a32b, + dnnl_abCd32c, + dnnl_abdCe32c, + 
dnnl_abdCE32c2e, + dnnl_BA16a16b2a, + dnnl_BA16a32b2a, + dnnl_BA16a48b2a, + dnnl_BA16a64b2a, + dnnl_BA16a16b4a, + dnnl_BA16a32b4a, + dnnl_BA16a48b4a, + dnnl_BA16a64b4a, + dnnl_ABcd8a2b, + dnnl_aBdeC16c16b2c, + dnnl_aBdeC16c16b4c, + dnnl_aBdefC16c16b2c, + dnnl_AcB16b16a2b, + dnnl_AcB16b16a4b, + dnnl_AcdB16b16a2b, + dnnl_AcdB16b16a4b, + dnnl_AcdeB16b16a2b, + dnnl_aBdefC16c16b4c, + dnnl_AcdeB16b16a4b, + dnnl_AcB16b32a2b, + dnnl_AcB16b32a4b, + dnnl_AcB16b48a2b, + dnnl_AcB16b48a4b, + dnnl_AcB16b64a2b, + dnnl_AcB16b64a4b, + dnnl_aBdC16c16b2c, + dnnl_aBdC16c16b4c, + dnnl_aBdC16c32b2c, + dnnl_aBdC16c32b4c, + dnnl_aBdC16c48b2c, + dnnl_aBdC16c48b4c, + dnnl_aBdC16c64b2c, + dnnl_aBdC16c64b4c, + dnnl_AcdB16b32a2b, + dnnl_AcdB16b32a4b, + dnnl_AcdB16b48a2b, + dnnl_AcdB16b48a4b, + dnnl_AcdB16b64a2b, + dnnl_AcdB16b64a4b, + dnnl_aBdeC16c32b2c, + dnnl_aBdeC16c32b4c, + dnnl_aBdeC16c48b2c, + dnnl_aBdeC16c48b4c, + dnnl_aBdeC16c64b2c, + dnnl_aBdeC16c64b4c, + dnnl_AcdeB16b32a2b, + dnnl_AcdeB16b32a4b, + dnnl_AcdeB16b48a2b, + dnnl_AcdeB16b48a4b, + dnnl_AcdeB16b64a2b, + dnnl_AcdeB16b64a4b, + dnnl_aBdefC16c32b2c, + dnnl_aBdefC16c32b4c, + dnnl_aBdefC16c48b2c, + dnnl_aBdefC16c48b4c, + dnnl_aBdefC16c64b2c, + dnnl_aBdefC16c64b4c, + dnnl_decbA16a, + dnnl_ABc4a2b, + dnnl_ABc8a2b, + dnnl_aBCd8b2c, + dnnl_ABcde4a2b, + dnnl_ABcde8a2b, + dnnl_ABcde40a16b, + dnnl_ABcde40a32b, + dnnl_aBCde8b2c, + dnnl_ABcde4a8b8a2b, + dnnl_ABcd4a8b8a2b, + dnnl_ABc4a8b8a2b, + dnnl_aBCdef4b8c8b2c, + dnnl_aBCde4b8c8b2c, + dnnl_aBCd4b8c8b2c, + dnnl_BAcde4b8a8b2a, + dnnl_BAcd4b8a8b2a, + dnnl_BAc4b8a8b2a, + dnnl_aCBdef4c8b8c2b, + dnnl_aCBde4c8b8c2b, + dnnl_aCBd4c8b8c2b, + dnnl_aBCdef8b2c, + dnnl_AB32a16b, + dnnl_AB32a32b, + dnnl_BA4b8a8b2a, + dnnl_BA4b8a8b4a, + dnnl_aBC32b16c, + dnnl_aBC32b32c, + dnnl_aCB4c8b8c2b, + dnnl_aCB4c8b8c4b, + dnnl_ABcd4a2b, + dnnl_ABc2b8a16b4a, + dnnl_ABcd2b8a16b4a, + dnnl_ABcde2b8a16b4a, + dnnl_ABc2a8b16a4b, + dnnl_ABc2a8b16a2b, + dnnl_ABc2b32a8b, + dnnl_ABcd2a8b16a4b, + dnnl_ABcd2a8b16a2b, + 
dnnl_aCBd2c8b16c2b, + dnnl_ABcd2b32a8b, + dnnl_aBCd2c8b16c2b, + dnnl_ABcde2a8b16a4b, + dnnl_ABcde2a8b16a2b, + dnnl_aCBde2c8b16c2b, + dnnl_ABcde2b32a8b, + dnnl_aBC2b8c16b2c, + dnnl_aBCd2b8c16b2c, + dnnl_aBCde2b8c16b2c, + dnnl_aBCdef2b8c16b2c, + dnnl_BAcde2b8a16b4a, + dnnl_BAcd2b8a16b4a, + dnnl_BAc2b8a16b4a, + dnnl_BAcde2b8a16b2a, + dnnl_BAcd2b8a16b2a, + dnnl_BAc2b8a16b2a, + dnnl_aBCde2c8b16c2b, + dnnl_aBCdef2c8b16c2b, + dnnl_aCBdef2c8b16c2b, + dnnl_aBCd2b8c16b4c, + dnnl_aBCde2b8c16b4c, + dnnl_BA4b8a16b2a, + dnnl_BA4b8a16b4a, + dnnl_aCB4c8b16c2b, + dnnl_aCB4c8b16c4b, + dnnl_BA16a16b, + dnnl_BA16a32b, + dnnl_BA16a48b, + dnnl_BA16a64b, + dnnl_aCB16c2b, + dnnl_aCB16c4b, + dnnl_BA16b2a, + dnnl_BA16b4a, + dnnl_aBC16b16c, + dnnl_aBC16b32c, + dnnl_AB16a16b, + dnnl_AB16a32b, + dnnl_ABcde16a16b2a, + dnnl_aBCdef16b16c2b, + dnnl_Acedb16a, + dnnl_aBdfec16b, + dnnl_abdEC64e2c, + dnnl_abdEC64e4c, + dnnl_aCB16b16c, + dnnl_aCB16b32c, + dnnl_aCB16b48c, + dnnl_aCB16b64c, + dnnl_aCB16b16c2b, + dnnl_aCB16b32c2b, + dnnl_aCB16b48c2b, + dnnl_aCB16b64c2b, + dnnl_aCB16b16c4b, + dnnl_aCB16b32c4b, + dnnl_aCB16b48c4b, + dnnl_aCB16b64c4b, + dnnl_abCd4c, + dnnl_abCde4c, + dnnl_abCdef4c, + dnnl_abCde32c, + dnnl_abCdef32c, + dnnl_ABcd16a32b, + dnnl_decbA8a, + dnnl_aCdefB16b32c2b, + dnnl_aCdefB16b32c4b, + dnnl_aCdefB16b48c2b, + dnnl_aCdefB16b48c4b, + dnnl_aCdefB16b64c2b, + dnnl_aCdefB16b64c4b, + dnnl_BcdeA16a32b2a, + dnnl_BcdeA16a32b4a, + dnnl_BcdeA16a48b2a, + dnnl_BcdeA16a48b4a, + dnnl_BcdeA16a64b2a, + dnnl_BcdeA16a64b4a, + dnnl_aCdefb32c, + dnnl_aCdefB32c2b, + dnnl_aCdefB32c4b, + dnnl_aCdefb48c, + dnnl_aCdefB48c2b, + dnnl_aCdefB48c4b, + dnnl_aCdefb64c, + dnnl_aCdefB64c2b, + dnnl_aCdefB64c4b, + dnnl_Bcdea32b, + dnnl_BcdeA32b2a, + dnnl_BcdeA32b4a, + dnnl_Bcdea48b, + dnnl_BcdeA48b2a, + dnnl_BcdeA48b4a, + dnnl_Bcdea64b, + dnnl_BcdeA64b2a, + dnnl_BcdeA64b4a, + dnnl_Bca32b, + dnnl_BcA32b2a, + dnnl_BcA32b4a, + dnnl_Bca48b, + dnnl_BcA48b2a, + dnnl_BcA48b4a, + dnnl_Bca64b, + dnnl_BcA64b2a, + dnnl_BcA64b4a, 
+ dnnl_aCdb32c, + dnnl_aCdB32c2b, + dnnl_aCdB32c4b, + dnnl_aCdb48c, + dnnl_aCdB48c2b, + dnnl_aCdB48c4b, + dnnl_aCdb64c, + dnnl_aCdB64c2b, + dnnl_aCdB64c4b, + dnnl_BcA16a16b2a, + dnnl_BcA16a16b4a, + dnnl_BcdA16a16b2a, + dnnl_BcdA16a16b4a, + dnnl_BcdeA16a16b2a, + dnnl_BcdeA16a16b4a, + dnnl_aCdB16b16c2b, + dnnl_aCdB16b16c4b, + dnnl_aCdeB16b16c2b, + dnnl_aCdeB16b16c4b, + dnnl_aCdefB16b16c2b, + dnnl_aCdefB16b16c4b, + dnnl_BcA16a32b2a, + dnnl_BcA16a32b4a, + dnnl_BcA16a48b2a, + dnnl_BcA16a48b4a, + dnnl_BcA16a64b2a, + dnnl_BcA16a64b4a, + dnnl_aCdB16b32c2b, + dnnl_aCdB16b32c4b, + dnnl_aCdB16b48c2b, + dnnl_aCdB16b48c4b, + dnnl_aCdB16b64c2b, + dnnl_aCdB16b64c4b, + dnnl_BcdA16a32b2a, + dnnl_BcdA16a32b4a, + dnnl_BcdA16a48b2a, + dnnl_BcdA16a48b4a, + dnnl_BcdA16a64b2a, + dnnl_BcdA16a64b4a, + dnnl_aCdeB16b32c2b, + dnnl_aCdeB16b32c4b, + dnnl_aCdeB16b48c2b, + dnnl_aCdeB16b48c4b, + dnnl_aCdeB16b64c2b, + dnnl_aCdeB16b64c4b, + dnnl_Bca16b, + dnnl_BcA16b2a, + dnnl_BcA16b4a, + dnnl_Bcda16b, + dnnl_BcdA16b2a, + dnnl_BcdA16b4a, + dnnl_Bcdea16b, + dnnl_BcdeA16b2a, + dnnl_BcdeA16b4a, + dnnl_aCdb16c, + dnnl_aCdB16c2b, + dnnl_aCdB16c4b, + dnnl_aCdeb16c, + dnnl_aCdeB16c2b, + dnnl_aCdeB16c4b, + dnnl_aCdefb16c, + dnnl_aCdefB16c2b, + dnnl_aCdefB16c4b, + dnnl_Bcda32b, + dnnl_BcdA32b2a, + dnnl_BcdA32b4a, + dnnl_Bcda48b, + dnnl_BcdA48b2a, + dnnl_BcdA48b4a, + dnnl_Bcda64b, + dnnl_BcdA64b2a, + dnnl_BcdA64b4a, + dnnl_aCdeb32c, + dnnl_aCdeB32c2b, + dnnl_aCdeB32c4b, + dnnl_aCdeb48c, + dnnl_aCdeB48c2b, + dnnl_aCdeB48c4b, + dnnl_aCdeb64c, + dnnl_aCdeB64c2b, + dnnl_aCdeB64c4b, + dnnl_Acb24a, + dnnl_Acdb24a, + dnnl_Acdeb24a, + dnnl_aBdc24b, + dnnl_aBdec24b, + dnnl_aBdefc24b, + dnnl_abDc16d, + dnnl_abdEc16e, + dnnl_abdCe16c, + dnnl_AcB24a2b, + dnnl_AcdB24a2b, + dnnl_AcdeB24a2b, + dnnl_aBdC24b2c, + dnnl_aBdeC24b2c, + dnnl_aBdefC24b2c, + dnnl_AcB8a2b, + dnnl_AcdB8a2b, + dnnl_AcdeB8a2b, + dnnl_aBdC8b2c, + dnnl_aBdeC8b2c, + dnnl_aBdefC8b2c, + dnnl_AB8b32a, + dnnl_ABc8b32a, + dnnl_ABcd8b32a, + dnnl_ABcde8b32a, + 
dnnl_AB8b24a, + dnnl_ABc8b24a, + dnnl_ABcd8b24a, + dnnl_ABcde8b24a, + dnnl_AB8b16a, + dnnl_ABc8b16a, + dnnl_ABcd8b16a, + dnnl_ABcde8b16a, + dnnl_AB8b8a, + dnnl_AB4b8a4b, + dnnl_AB4b24a4b, + dnnl_ABc4b8a4b, + dnnl_ABc4b24a4b, + dnnl_ABcd4b8a4b, + dnnl_ABcd4b24a4b, + dnnl_ABcde4b8a4b, + dnnl_ABcde4b24a4b, + dnnl_AB8b24a2b, + dnnl_ABc8b24a2b, + dnnl_ABcd8b24a2b, + dnnl_ABcde8b24a2b, + dnnl_AB8b8a2b, + dnnl_ABc8b8a2b, + dnnl_ABcd8b8a2b, + dnnl_ABcde8b8a2b, + dnnl_AcB24a4b, + dnnl_AcdB24a4b, + dnnl_AcdeB24a4b, + dnnl_aBdC24b4c, + dnnl_aBdeC24b4c, + dnnl_aBdefC24b4c, + dnnl_AcB8a4b, + dnnl_AcdB8a4b, + dnnl_AcdeB8a4b, + dnnl_aBdC8b4c, + dnnl_aBdeC8b4c, + dnnl_aBdefC8b4c, + dnnl_Bca8b, + dnnl_BcA8b2a, + dnnl_Bcda8b, + dnnl_BcdA8b2a, + dnnl_Bcdea8b, + dnnl_BcdeA8b2a, + dnnl_aCdb8c, + dnnl_aCdB8c2b, + dnnl_aCdeb8c, + dnnl_aCdeB8c2b, + dnnl_aCdefb8c, + dnnl_aCdefB8c2b, + dnnl_Bca24b, + dnnl_BcA24b2a, + dnnl_Bcda24b, + dnnl_BcdA24b2a, + dnnl_Bcdea24b, + dnnl_BcdeA24b2a, + dnnl_aCdb24c, + dnnl_aCdB24c2b, + dnnl_aCdeb24c, + dnnl_aCdeB24c2b, + dnnl_aCdefb24c, + dnnl_aCdefB24c2b, + dnnl_BcA8b4a, + dnnl_BcdA8b4a, + dnnl_BcdeA8b4a, + dnnl_aCdB8c4b, + dnnl_aCdeB8c4b, + dnnl_aCdefB8c4b, + dnnl_BcA24b4a, + dnnl_BcdA24b4a, + dnnl_BcdeA24b4a, + dnnl_aCdB24c4b, + dnnl_aCdeB24c4b, + dnnl_aCdefB24c4b, + dnnl_AB16b48a, + dnnl_ABc16b48a, + dnnl_ABcd16b48a, + dnnl_ABcde16b48a, + dnnl_ABc16a4b, + dnnl_ABcd16a4b, + dnnl_ABcde16a4b, + dnnl_defcbA16a, + dnnl_defcbA8a, + dnnl_AcB16b64a, + dnnl_AcdB16b64a, + dnnl_AcdeB16b64a, + dnnl_AcB16b48a, + dnnl_AcdB16b48a, + dnnl_AcdeB16b48a, + dnnl_AcB16b32a, + dnnl_AcdB16b32a, + dnnl_AcdeB16b32a, + dnnl_AcB16b16a, + dnnl_AcdB16b16a, + dnnl_AcdeB16b16a, + dnnl_AcB8b32a, + dnnl_AcdB8b32a, + dnnl_AcdeB8b32a, + dnnl_AcB8b24a, + dnnl_AcdB8b24a, + dnnl_AcdeB8b24a, + dnnl_AcB8b16a, + dnnl_AcdB8b16a, + dnnl_AcdeB8b16a, + dnnl_AcB8b8a, + dnnl_AcdB8b8a, + dnnl_AcdeB8b8a, + dnnl_AcB8b64a2b, + dnnl_AcdB8b64a2b, + dnnl_AcdeB8b64a2b, + dnnl_AcB8b32a2b, + dnnl_AcdB8b32a2b, 
+ dnnl_AcdeB8b32a2b, + dnnl_AcB8b24a2b, + dnnl_AcdB8b24a2b, + dnnl_AcdeB8b24a2b, + dnnl_AcB8b16a2b, + dnnl_AcdB8b16a2b, + dnnl_AcdeB8b16a2b, + dnnl_AcB8b8a2b, + dnnl_AcdB8b8a2b, + dnnl_AcdeB8b8a2b, + dnnl_AcB4b64a4b, + dnnl_AcdB4b64a4b, + dnnl_AcdeB4b64a4b, + dnnl_AcB4b32a4b, + dnnl_AcdB4b32a4b, + dnnl_AcdeB4b32a4b, + dnnl_AcB4b24a4b, + dnnl_AcdB4b24a4b, + dnnl_AcdeB4b24a4b, + dnnl_AcB4b16a4b, + dnnl_AcdB4b16a4b, + dnnl_AcdeB4b16a4b, + dnnl_AcB4b8a4b, + dnnl_AcdB4b8a4b, + dnnl_AcdeB4b8a4b, + dnnl_Ab4a, + dnnl_Ab8a, + dnnl_BA4b4a, + dnnl_BA8b4a, + dnnl_BA2a24b, + dnnl_aCB2b24c, + dnnl_BA2a8b, + dnnl_aCB2b8c, + dnnl_BA8a24b, + dnnl_aCB8b24c, + dnnl_BA8a16b, + dnnl_aCB8b16c, + dnnl_BA8a8b, + dnnl_aCB8b8c, + dnnl_bcad, + dnnl_cabd, + dnnl_dabc, + dnnl_Ab32a, + dnnl_aCBd8b8c, + dnnl_aCBde8b8c, + dnnl_BAc8a8b, + dnnl_BAcd8a8b, + dnnl_BAcde8a8b, + dnnl_aCBdef8b8c, + dnnl_abdEC16e4c, + dnnl_abDC16d4c, + + /// Just a sentinel, not real memory format tag. Must be changed after new + /// format tag is added. 
+ dnnl_format_tag_last, + + // Aliases + + /// 1D tensor, an alias to #dnnl_a + dnnl_x = dnnl_a, + /// 2D CNN activations tensor, an alias to #dnnl_ab + dnnl_nc = dnnl_ab, + /// 2D CNN activations tensor, an alias to #dnnl_ba + dnnl_cn = dnnl_ba, + /// 2D RNN statistics tensor, an alias to #dnnl_ab + dnnl_tn = dnnl_ab, + /// 2D RNN statistics tensor, an alias to #dnnl_ba + dnnl_nt = dnnl_ba, + /// 3D CNN activations tensor, an alias to #dnnl_abc + dnnl_ncw = dnnl_abc, + /// 3D CNN activations tensor, an alias to #dnnl_acb + dnnl_nwc = dnnl_acb, + /// 4D CNN activations tensor, an alias to #dnnl_abcd + dnnl_nchw = dnnl_abcd, + /// 4D CNN activations tensor, an alias to #dnnl_acdb + dnnl_nhwc = dnnl_acdb, + /// 4D CNN activations tensor, an alias to #dnnl_bcda + dnnl_chwn = dnnl_bcda, + /// 5D CNN activations tensor, an alias to #dnnl_abcde + dnnl_ncdhw = dnnl_abcde, + /// 5D CNN activations tensor, an alias to #dnnl_acdeb + dnnl_ndhwc = dnnl_acdeb, + + /// 2D CNN weights tensor, an alias to #dnnl_ab + dnnl_oi = dnnl_ab, + /// 2D CNN weights tensor, an alias to #dnnl_ba + dnnl_io = dnnl_ba, + /// 3D CNN weights tensor, an alias to #dnnl_abc + dnnl_oiw = dnnl_abc, + /// 3D CNN weights tensor, an alias to #dnnl_acb + dnnl_owi = dnnl_acb, + /// 3D CNN weights tensor, an alias to #dnnl_cba + dnnl_wio = dnnl_cba, + /// 3D CNN weights tensor, an alias to #dnnl_cab + dnnl_woi = dnnl_cab, + /// 3D CNN weights tensor, an alias to #dnnl_bca + dnnl_iwo = dnnl_bca, + /// 4D CNN weights tensor, an alias to #dnnl_abcd + dnnl_oihw = dnnl_abcd, + /// 4D CNN weights tensor, an alias to #dnnl_cdba + dnnl_hwio = dnnl_cdba, + /// 4D CNN weights tensor, an alias to #dnnl_cdab + dnnl_hwoi = dnnl_cdab, + /// 4D CNN weights tensor, an alias to #dnnl_acdb + dnnl_ohwi = dnnl_acdb, + /// 4D CNN weights tensor, an alias to #dnnl_bcda + dnnl_ihwo = dnnl_bcda, + /// 4D CNN weights tensor, an alias to #dnnl_bacd + dnnl_iohw = dnnl_bacd, + /// 5D CNN weights tensor, an alias to #dnnl_abcde + 
dnnl_oidhw = dnnl_abcde, + /// 5D CNN weights tensor, an alias to #dnnl_bacde + dnnl_iodhw = dnnl_bacde, + /// 5D CNN weights tensor, an alias to #dnnl_cdeba + dnnl_dhwio = dnnl_cdeba, + /// 5D CNN weights tensor, an alias to #dnnl_cdeab + dnnl_dhwoi = dnnl_cdeab, + /// 5D CNN weights tensor, an alias to #dnnl_acdeb + dnnl_odhwi = dnnl_acdeb, + /// 5D CNN weights tensor, an alias to #dnnl_bcdea + dnnl_idhwo = dnnl_bcdea, + + /// 4D CNN weights tensor (incl. groups), an alias to #dnnl_abcd + dnnl_goiw = dnnl_abcd, + /// 4D CNN weights tensor (incl. groups), an alias to #dnnl_abdc + dnnl_gowi = dnnl_abdc, + /// 4D CNN weights tensor (incl. groups), an alias to #dnnl_dcab + dnnl_wigo = dnnl_dcab, + /// 5D CNN weights tensor (incl. groups), an alias to #dnnl_abcde + dnnl_goihw = dnnl_abcde, + /// 5D CNN weights tensor (incl. groups), an alias to #dnnl_abdec + dnnl_gohwi = dnnl_abdec, + /// 5D CNN weights tensor (incl. groups), an alias to #dnnl_decab + dnnl_hwigo = dnnl_decab, + /// 5D CNN weights tensor (incl. groups), an alias to #dnnl_acbde + dnnl_giohw = dnnl_acbde, + /// 6D CNN weights tensor (incl. groups), an alias to #dnnl_abcdef + dnnl_goidhw = dnnl_abcdef, + /// 6D CNN weights tensor (incl. groups), an alias to #dnnl_abdefc + dnnl_godhwi = dnnl_abdefc, + /// 6D CNN weights tensor (incl. groups), an alias to #dnnl_acbdef + dnnl_giodhw = dnnl_acbdef, + /// 6D CNN weights tensor (incl. groups), an alias to #dnnl_defcab + dnnl_dhwigo = dnnl_defcab, + + /// 3D RNN data tensor in the format (seq_length, batch, input channels), + /// an alias to #dnnl_abc. + dnnl_tnc = dnnl_abc, + /// 3D RNN data tensor in the format (batch, seq_length, input channels), + /// an alias to #dnnl_bac. + dnnl_ntc = dnnl_bac, + /// 4D RNN states tensor in the format (num_layers, num_directions, + /// batch, state channels), an alias to #dnnl_abcd. 
+ dnnl_ldnc = dnnl_abcd, + /// 5D RNN weights tensor in the format (num_layers, num_directions, + /// input_channels, num_gates, output_channels), an alias to #dnnl_abcde. + /// + /// - For LSTM cells, the gates order is input, forget, candidate + /// and output gate. + /// - For GRU cells, the gates order is update, reset and output gate. + dnnl_ldigo = dnnl_abcde, + /// 5D RNN weights tensor in the format (num_layers, num_directions, + /// num_gates, output_channels, input_channels), an alias to #dnnl_abdec. + /// + /// - For LSTM cells, the gates order is input, forget, candidate + /// and output gate. + /// - For GRU cells, the gates order is update, reset and output gate. + dnnl_ldgoi = dnnl_abdec, + /// 4D LSTM projection tensor in the format (num_layers, num_directions, + /// num_channels_in_hidden_state, num_channels_in_recurrent_projection), + /// an alias to #dnnl_abcd. + dnnl_ldio = dnnl_abcd, + /// 4D LSTM projection tensor in the format (num_layers, num_directions, + /// num_channels_in_recurrent_projection, num_channels_in_hidden_state), + /// an alias to #dnnl_abdc. + dnnl_ldoi = dnnl_abdc, + /// 4D RNN bias tensor in the format (num_layers, num_directions, + /// num_gates, output_channels), an alias to #dnnl_abcd. + /// + /// - For LSTM cells, the gates order is input, forget, candidate + /// and output gate. + /// - For GRU cells, the gates order is update, reset and output gate. 
+ dnnl_ldgo = dnnl_abcd, + /// 5D LSTM projection tensor + dnnl_ldOi16o = dnnl_abDc16d, + dnnl_ldOi32o = dnnl_abDc32d, + dnnl_ldOI16o4i = dnnl_abDC16d4c, + dnnl_ldOI32o4i = dnnl_abDC32d4c, + dnnl_ldIo32i = dnnl_abCd32c, + /// 6D RNN weights tensor + dnnl_ldgOi16o = dnnl_abdEc16e, + dnnl_ldgOI16o4i = dnnl_abdEC16e4c, + dnnl_ldgOi32o = dnnl_abdEc32e, + dnnl_ldgOI32o2i = dnnl_abdEC32e2c, + dnnl_ldgOI32o4i = dnnl_abdEC32e4c, + dnnl_ldgOI64o2i = dnnl_abdEC64e2c, + dnnl_ldgOI64o4i = dnnl_abdEC64e4c, + dnnl_ldgIo16i = dnnl_abdCe16c, + dnnl_ldgIo32i = dnnl_abdCe32c, + dnnl_ldgIO32i2o = dnnl_abdCE32c2e, + + // Opaque data types, are not to be used explicitly + + // data + /// 5D CNN activations tensor blocked by channels with block size 32, + /// an alias to #dnnl_aBcde32b + dnnl_nCdhw32c = dnnl_aBcde32b, + /// 5D CNN activations tensor blocked by channels with block size 16, + /// an alias to #dnnl_aBcde16b + dnnl_nCdhw16c = dnnl_aBcde16b, + /// 5D CNN activations tensor blocked by channels with block size 4, + /// an alias to #dnnl_aBcde4b + dnnl_nCdhw4c = dnnl_aBcde4b, + /// 5D CNN activations tensor blocked by channels with block size 8, + /// an alias to #dnnl_aBcde8b + dnnl_nCdhw8c = dnnl_aBcde8b, + /// 4D CNN activations tensor blocked by channels with block size 32, + /// an alias to #dnnl_aBcd32b + dnnl_nChw32c = dnnl_aBcd32b, + /// 4D CNN activations tensor blocked by channels with block size 16, + /// an alias to #dnnl_aBcd16b + dnnl_nChw16c = dnnl_aBcd16b, + /// 4D CNN activations tensor blocked by channels with block size 4, + /// an alias to #dnnl_aBcd4b + dnnl_nChw4c = dnnl_aBcd4b, + /// 4D CNN activations tensor blocked by channels with block size 8, + /// an alias to #dnnl_aBcd8b + dnnl_nChw8c = dnnl_aBcd8b, + /// 3D CNN activations tensor blocked by channels with block size 32, + /// an alias to #dnnl_aBc32b + dnnl_nCw32c = dnnl_aBc32b, + /// 3D CNN activations tensor blocked by channels with block size 16, + /// an alias to #dnnl_aBc16b + dnnl_nCw16c = 
dnnl_aBc16b, + /// 3D CNN activations tensor blocked by channels with block size 4, + /// an alias to #dnnl_aBc4b + dnnl_nCw4c = dnnl_aBc4b, + /// 3D CNN activations tensor blocked by channels with block size 8, + /// an alias to #dnnl_aBc8b + dnnl_nCw8c = dnnl_aBc8b, + dnnl_NCw16n16c = dnnl_ABc16a16b, + dnnl_NCdhw16n16c = dnnl_ABcde16a16b, + dnnl_NChw16n16c = dnnl_ABcd16a16b, + dnnl_NCw32n16c = dnnl_ABc32a16b, + dnnl_NChw32n16c = dnnl_ABcd32a16b, + dnnl_NChw16n32c = dnnl_ABcd16a32b, + dnnl_NCdhw32n16c = dnnl_ABcde32a16b, + dnnl_NCw32n32c = dnnl_ABc32a32b, + dnnl_NChw32n32c = dnnl_ABcd32a32b, + dnnl_NCdhw32n32c = dnnl_ABcde32a32b, + + // weights, 2D + dnnl_OI16i16o = dnnl_AB16b16a, + dnnl_OI16i32o = dnnl_AB16b32a, + dnnl_OI16i48o = dnnl_AB16b48a, + dnnl_OI16i64o = dnnl_AB16b64a, + dnnl_OI8i8o2i = dnnl_AB8b8a2b, + dnnl_OI8i16o2i = dnnl_AB8b16a2b, + dnnl_OI8i24o2i = dnnl_AB8b24a2b, + dnnl_OI8i32o2i = dnnl_AB8b32a2b, + dnnl_OI8i64o2i = dnnl_AB8b64a2b, + dnnl_OI4i8o4i = dnnl_AB4b8a4b, + dnnl_OI4i16o4i = dnnl_AB4b16a4b, + dnnl_OI4i24o4i = dnnl_AB4b24a4b, + dnnl_OI4i32o4i = dnnl_AB4b32a4b, + dnnl_OI4i64o4i = dnnl_AB4b64a4b, + dnnl_OI16i16o4i = dnnl_AB16b16a4b, + dnnl_OI8i32o = dnnl_AB8b32a, + dnnl_OI8i24o = dnnl_AB8b24a, + dnnl_OI8i16o = dnnl_AB8b16a, + dnnl_OI8i8o = dnnl_AB8b8a, + + // weights, 3D + dnnl_IOw8o8i = dnnl_BAc8a8b, + dnnl_IOw16o16i = dnnl_BAc16a16b, + dnnl_IOw16i16o = dnnl_BAc16b16a, + dnnl_OIw16i16o = dnnl_ABc16b16a, + dnnl_OwI16i16o = dnnl_AcB16b16a, + dnnl_OIw16i32o = dnnl_ABc16b32a, + dnnl_OwI16i32o = dnnl_AcB16b32a, + dnnl_OIw16i48o = dnnl_ABc16b48a, + dnnl_OwI16i48o = dnnl_AcB16b48a, + dnnl_OIw16i64o = dnnl_ABc16b64a, + dnnl_OwI16i64o = dnnl_AcB16b64a, + dnnl_OIw16o16i = dnnl_ABc16a16b, + dnnl_Oiw16o = dnnl_Abc16a, + dnnl_OIw4i8o4i = dnnl_ABc4b8a4b, + dnnl_OwI4i8o4i = dnnl_AcB4b8a4b, + dnnl_OIw4i16o4i = dnnl_ABc4b16a4b, + dnnl_OwI4i16o4i = dnnl_AcB4b16a4b, + dnnl_OIw4i24o4i = dnnl_ABc4b24a4b, + dnnl_OwI4i24o4i = dnnl_AcB4b24a4b, + dnnl_OIw4i32o4i = 
dnnl_ABc4b32a4b, + dnnl_OwI4i32o4i = dnnl_AcB4b32a4b, + dnnl_OIw4i64o4i = dnnl_ABc4b64a4b, + dnnl_OwI4i64o4i = dnnl_AcB4b64a4b, + dnnl_OIw2i8o4i = dnnl_ABc2b8a4b, + dnnl_OIw16i16o4i = dnnl_ABc16b16a4b, + dnnl_OIw16i16o2i = dnnl_ABc16b16a2b, + dnnl_OIw16o16i2o = dnnl_ABc16a16b2a, + dnnl_OIw4i4o = dnnl_ABc4b4a, + dnnl_OIw4o4i = dnnl_ABc4a4b, + dnnl_Oiw4o = dnnl_Abc4a, + dnnl_OIw8i8o2i = dnnl_ABc8b8a2b, + dnnl_OwI8i8o2i = dnnl_AcB8b8a2b, + dnnl_OIw8i16o2i = dnnl_ABc8b16a2b, + dnnl_OwI8i16o2i = dnnl_AcB8b16a2b, + dnnl_OIw8i24o2i = dnnl_ABc8b24a2b, + dnnl_OwI8i24o2i = dnnl_AcB8b24a2b, + dnnl_OIw8i32o2i = dnnl_ABc8b32a2b, + dnnl_OwI8i32o2i = dnnl_AcB8b32a2b, + dnnl_OIw8i64o2i = dnnl_ABc8b64a2b, + dnnl_OwI8i64o2i = dnnl_AcB8b64a2b, + dnnl_OIw8i8o = dnnl_ABc8b8a, + dnnl_OwI8i8o = dnnl_AcB8b8a, + dnnl_OIw8o16i2o = dnnl_ABc8a16b2a, + dnnl_IOw8o16i2o = dnnl_BAc8a16b2a, + dnnl_OIw8o8i = dnnl_ABc8a8b, + dnnl_OIw8o4i = dnnl_ABc8a4b, + dnnl_Owi16o = dnnl_Acb16a, + dnnl_OwI16o2i = dnnl_AcB16a2b, + dnnl_OwI16o4i = dnnl_AcB16a4b, + dnnl_Iwo8i = dnnl_Bca8b, + dnnl_IwO8i2o = dnnl_BcA8b2a, + dnnl_IwO8i4o = dnnl_BcA8b4a, + dnnl_Iwo16i = dnnl_Bca16b, + dnnl_IwO16i2o = dnnl_BcA16b2a, + dnnl_IwO16i4o = dnnl_BcA16b4a, + dnnl_Iwo24i = dnnl_Bca24b, + dnnl_IwO24i2o = dnnl_BcA24b2a, + dnnl_IwO24i4o = dnnl_BcA24b4a, + dnnl_Owi4o = dnnl_Acb4a, + dnnl_Owi8o = dnnl_Acb8a, + dnnl_OwI8o2i = dnnl_AcB8a2b, + dnnl_OIw8i32o = dnnl_ABc8b32a, + dnnl_OwI8i32o = dnnl_AcB8b32a, + dnnl_OIw8i24o = dnnl_ABc8b24a, + dnnl_OwI8i24o = dnnl_AcB8b24a, + dnnl_OIw8i16o = dnnl_ABc8b16a, + dnnl_OwI8i16o = dnnl_AcB8b16a, + dnnl_OwI8o4i = dnnl_AcB8a4b, + + // weights, 4D + dnnl_IOhw16i16o = dnnl_BAcd16b16a, + dnnl_IOhw8o8i = dnnl_BAcd8a8b, + dnnl_IOhw16o16i = dnnl_BAcd16a16b, + dnnl_Ohwi16o = dnnl_Acdb16a, + dnnl_OhwI16o2i = dnnl_AcdB16a2b, + dnnl_OhwI16o4i = dnnl_AcdB16a4b, + dnnl_Ihwo8i = dnnl_Bcda8b, + dnnl_IhwO8i2o = dnnl_BcdA8b2a, + dnnl_IhwO8i4o = dnnl_BcdA8b4a, + dnnl_Ihwo16i = dnnl_Bcda16b, + dnnl_IhwO16i2o = 
dnnl_BcdA16b2a, + dnnl_IhwO16i4o = dnnl_BcdA16b4a, + dnnl_Ihwo24i = dnnl_Bcda24b, + dnnl_IhwO24i2o = dnnl_BcdA24b2a, + dnnl_IhwO24i4o = dnnl_BcdA24b4a, + dnnl_Ohwi24o = dnnl_Acdb24a, + dnnl_Ohwi32o = dnnl_Acdb32a, + dnnl_Ohwi4o = dnnl_Acdb4a, + dnnl_Ohwi8o = dnnl_Acdb8a, + dnnl_OhwI8o2i = dnnl_AcdB8a2b, + dnnl_OhwI8o4i = dnnl_AcdB8a4b, + dnnl_OIhw16i16o = dnnl_ABcd16b16a, + dnnl_OhwI16i16o = dnnl_AcdB16b16a, + dnnl_OIhw16i32o = dnnl_ABcd16b32a, + dnnl_OhwI16i32o = dnnl_AcdB16b32a, + dnnl_OIhw16i48o = dnnl_ABcd16b48a, + dnnl_OhwI16i48o = dnnl_AcdB16b48a, + dnnl_OIhw16i64o = dnnl_ABcd16b64a, + dnnl_OhwI16i64o = dnnl_AcdB16b64a, + dnnl_OIhw16o16i = dnnl_ABcd16a16b, + dnnl_Oihw16o = dnnl_Abcd16a, + dnnl_OIhw4i8o4i = dnnl_ABcd4b8a4b, + dnnl_OhwI4i8o4i = dnnl_AcdB4b8a4b, + dnnl_OIhw4i16o4i = dnnl_ABcd4b16a4b, + dnnl_OhwI4i16o4i = dnnl_AcdB4b16a4b, + dnnl_OIhw4i24o4i = dnnl_ABcd4b24a4b, + dnnl_OhwI4i24o4i = dnnl_AcdB4b24a4b, + dnnl_OIhw4i32o4i = dnnl_ABcd4b32a4b, + dnnl_OhwI4i32o4i = dnnl_AcdB4b32a4b, + dnnl_OIhw4i64o4i = dnnl_ABcd4b64a4b, + dnnl_OhwI4i64o4i = dnnl_AcdB4b64a4b, + dnnl_OIhw16i16o4i = dnnl_ABcd16b16a4b, + dnnl_OIhw16i16o2i = dnnl_ABcd16b16a2b, + dnnl_OIhw16o16i2o = dnnl_ABcd16a16b2a, + dnnl_OIhw4i4o = dnnl_ABcd4b4a, + dnnl_OIhw4o4i = dnnl_ABcd4a4b, + dnnl_Oihw4o = dnnl_Abcd4a, + dnnl_OIhw8i8o2i = dnnl_ABcd8b8a2b, + dnnl_OhwI8i8o2i = dnnl_AcdB8b8a2b, + dnnl_OIhw8i16o2i = dnnl_ABcd8b16a2b, + dnnl_OhwI8i16o2i = dnnl_AcdB8b16a2b, + dnnl_OIhw8i32o2i = dnnl_ABcd8b32a2b, + dnnl_OhwI8i32o2i = dnnl_AcdB8b32a2b, + dnnl_OIhw8i24o2i = dnnl_ABcd8b24a2b, + dnnl_OhwI8i24o2i = dnnl_AcdB8b24a2b, + dnnl_OIhw8i64o2i = dnnl_ABcd8b64a2b, + dnnl_OhwI8i64o2i = dnnl_AcdB8b64a2b, + dnnl_OIhw8i8o = dnnl_ABcd8b8a, + dnnl_OhwI8i8o = dnnl_AcdB8b8a, + dnnl_OIhw8o16i2o = dnnl_ABcd8a16b2a, + dnnl_OIhw2i8o4i = dnnl_ABcd2b8a4b, + dnnl_IOhw8o16i2o = dnnl_BAcd8a16b2a, + dnnl_OIhw8o8i = dnnl_ABcd8a8b, + dnnl_OIhw8o4i = dnnl_ABcd8a4b, + dnnl_Owhi16o = dnnl_Adcb16a, + dnnl_OIhw8i32o = 
dnnl_ABcd8b32a, + dnnl_OhwI8i32o = dnnl_AcdB8b32a, + dnnl_OIhw8i24o = dnnl_ABcd8b24a, + dnnl_OhwI8i24o = dnnl_AcdB8b24a, + dnnl_OIhw8i16o = dnnl_ABcd8b16a, + dnnl_OhwI8i16o = dnnl_AcdB8b16a, + + // weights, 5D + dnnl_Odhwi16o = dnnl_Acdeb16a, + dnnl_OdhwI16o2i = dnnl_AcdeB16a2b, + dnnl_OdhwI16o4i = dnnl_AcdeB16a4b, + dnnl_Idhwo8i = dnnl_Bcdea8b, + dnnl_IdhwO8i2o = dnnl_BcdeA8b2a, + dnnl_IdhwO8i4o = dnnl_BcdeA8b4a, + dnnl_Idhwo16i = dnnl_Bcdea16b, + dnnl_IdhwO16i2o = dnnl_BcdeA16b2a, + dnnl_IdhwO16i4o = dnnl_BcdeA16b4a, + dnnl_Idhwo24i = dnnl_Bcdea24b, + dnnl_IdhwO24i2o = dnnl_BcdeA24b2a, + dnnl_IdhwO24i4o = dnnl_BcdeA24b4a, + dnnl_Odhwi4o = dnnl_Acdeb4a, + dnnl_Odhwi8o = dnnl_Acdeb8a, + dnnl_OdhwI8o2i = dnnl_AcdeB8a2b, + dnnl_OdhwI8o4i = dnnl_AcdeB8a4b, + dnnl_Odwhi16o = dnnl_Acedb16a, + dnnl_OIdhw16i16o = dnnl_ABcde16b16a, + dnnl_OdhwI16i16o = dnnl_AcdeB16b16a, + dnnl_OIdhw16i32o = dnnl_ABcde16b32a, + dnnl_OdhwI16i32o = dnnl_AcdeB16b32a, + dnnl_OIdhw16i48o = dnnl_ABcde16b48a, + dnnl_OdhwI16i48o = dnnl_AcdeB16b48a, + dnnl_OIdhw16i64o = dnnl_ABcde16b64a, + dnnl_OdhwI16i64o = dnnl_AcdeB16b64a, + dnnl_OIdhw16o16i = dnnl_ABcde16a16b, + dnnl_Oidhw16o = dnnl_Abcde16a, + dnnl_OIdhw4i4o = dnnl_ABcde4b4a, + dnnl_OIdhw4o4i = dnnl_ABcde4a4b, + dnnl_Oidhw4o = dnnl_Abcde4a, + dnnl_OIdhw8i8o2i = dnnl_ABcde8b8a2b, + dnnl_OdhwI8i8o2i = dnnl_AcdeB8b8a2b, + dnnl_OIdhw8i16o2i = dnnl_ABcde8b16a2b, + dnnl_OdhwI8i16o2i = dnnl_AcdeB8b16a2b, + dnnl_OIdhw8i32o2i = dnnl_ABcde8b32a2b, + dnnl_OdhwI8i32o2i = dnnl_AcdeB8b32a2b, + dnnl_OIdhw8i24o2i = dnnl_ABcde8b24a2b, + dnnl_OdhwI8i24o2i = dnnl_AcdeB8b24a2b, + dnnl_OIdhw8i64o2i = dnnl_ABcde8b64a2b, + dnnl_OdhwI8i64o2i = dnnl_AcdeB8b64a2b, + dnnl_OIdhw8i8o = dnnl_ABcde8b8a, + dnnl_OdhwI8i8o = dnnl_AcdeB8b8a, + dnnl_OIdhw8o16i2o = dnnl_ABcde8a16b2a, + dnnl_IOdhw8o16i2o = dnnl_BAcde8a16b2a, + dnnl_OIdhw4i8o4i = dnnl_ABcde4b8a4b, + dnnl_OdhwI4i8o4i = dnnl_AcdeB4b8a4b, + dnnl_OIdhw4i16o4i = dnnl_ABcde4b16a4b, + dnnl_OdhwI4i16o4i = dnnl_AcdeB4b16a4b, 
+ dnnl_OIdhw4i24o4i = dnnl_ABcde4b24a4b, + dnnl_OdhwI4i24o4i = dnnl_AcdeB4b24a4b, + dnnl_OIdhw4i32o4i = dnnl_ABcde4b32a4b, + dnnl_OdhwI4i32o4i = dnnl_AcdeB4b32a4b, + dnnl_OIdhw4i64o4i = dnnl_ABcde4b64a4b, + dnnl_OdhwI4i64o4i = dnnl_AcdeB4b64a4b, + dnnl_OIdhw16i16o4i = dnnl_ABcde16b16a4b, + dnnl_OIdhw16i16o2i = dnnl_ABcde16b16a2b, + dnnl_OIdhw2i8o4i = dnnl_ABcde2b8a4b, + dnnl_OIdhw8o8i = dnnl_ABcde8a8b, + dnnl_OIdhw8o4i = dnnl_ABcde8a4b, + dnnl_IOdhw16i16o = dnnl_BAcde16b16a, + dnnl_OIdhw4o8i8o4i = dnnl_ABcde4a8b8a4b, + dnnl_IOdhw8o8i = dnnl_BAcde8a8b, + dnnl_IOdhw16o16i = dnnl_BAcde16a16b, + dnnl_OIdhw16o16i2o = dnnl_ABcde16a16b2a, + dnnl_OIdhw8i32o = dnnl_ABcde8b32a, + dnnl_OdhwI8i32o = dnnl_AcdeB8b32a, + dnnl_OIdhw8i24o = dnnl_ABcde8b24a, + dnnl_OdhwI8i24o = dnnl_AcdeB8b24a, + dnnl_OIdhw8i16o = dnnl_ABcde8b16a, + dnnl_OdhwI8i16o = dnnl_AcdeB8b16a, + + // weights w/ groups, 3D + dnnl_Goiw16g = dnnl_Abcd16a, + dnnl_Goiw8g = dnnl_Abcd8a, + dnnl_Goiw4g = dnnl_Abcd4a, + dnnl_gIOw8o8i = dnnl_aCBd8b8c, + dnnl_gIOw16o16i = dnnl_aCBd16b16c, + dnnl_gIOw16i16o = dnnl_aCBd16c16b, + dnnl_gOIw16i16o = dnnl_aBCd16c16b, + dnnl_gOIw16o16i = dnnl_aBCd16b16c, + dnnl_gOiw16o = dnnl_aBcd16b, + dnnl_gOIw4i16o4i = dnnl_aBCd4c16b4c, + dnnl_gOIw2i8o4i = dnnl_aBCd2c8b4c, + dnnl_gOIw16i16o4i = dnnl_aBCd16c16b4c, + dnnl_gOIw16i16o2i = dnnl_aBCd16c16b2c, + dnnl_gOIw16o16i2o = dnnl_aBCd16b16c2b, + dnnl_gOIw4i4o = dnnl_aBCd4c4b, + dnnl_gOIw4o4i = dnnl_aBCd4b4c, + dnnl_gOiw4o = dnnl_aBcd4b, + dnnl_gOIw8i16o2i = dnnl_aBCd8c16b2c, + dnnl_gOIw8i8o = dnnl_aBCd8c8b, + dnnl_gOIw8o16i2o = dnnl_aBCd8b16c2b, + dnnl_gIOw8o16i2o = dnnl_aCBd8b16c2b, + dnnl_gOIw8o8i = dnnl_aBCd8b8c, + dnnl_gOIw8o4i = dnnl_aBCd8b4c, + dnnl_gOwi16o = dnnl_aBdc16b, + dnnl_gOwI16o2i = dnnl_aBdC16b2c, + dnnl_gOwI16o4i = dnnl_aBdC16b4c, + dnnl_gIwo8i = dnnl_aCdb8c, + dnnl_gIwO8i2o = dnnl_aCdB8c2b, + dnnl_gIwO8i4o = dnnl_aCdB8c4b, + dnnl_gIwo16i = dnnl_aCdb16c, + dnnl_gIwO16i2o = dnnl_aCdB16c2b, + dnnl_gIwO16i4o = dnnl_aCdB16c4b, 
+ dnnl_gIwo24i = dnnl_aCdb24c, + dnnl_gIwO24i2o = dnnl_aCdB24c2b, + dnnl_gIwO24i4o = dnnl_aCdB24c4b, + dnnl_gOwi4o = dnnl_aBdc4b, + dnnl_gOwi8o = dnnl_aBdc8b, + dnnl_gOwI8o2i = dnnl_aBdC8b2c, + dnnl_gOwI8o4i = dnnl_aBdC8b4c, + dnnl_Goiw32g = dnnl_Abcd32a, + dnnl_gOIw2i4o2i = dnnl_aBCd2c4b2c, + dnnl_gOIw2o4i2o = dnnl_aBCd2b4c2b, + dnnl_gOIw4i8o2i = dnnl_aBCd4c8b2c, + dnnl_gOIw4o8i2o = dnnl_aBCd4b8c2b, + dnnl_goIw4i = dnnl_abCd4c, + dnnl_goIw32i = dnnl_abCd32c, + + // weights w/ groups, 4D + dnnl_gIOhw16i16o = dnnl_aCBde16c16b, + dnnl_gIOhw8o8i = dnnl_aCBde8b8c, + dnnl_gIOhw16o16i = dnnl_aCBde16b16c, + dnnl_gOhwi16o = dnnl_aBdec16b, + dnnl_gOhwI16o2i = dnnl_aBdeC16b2c, + dnnl_gOhwI16o4i = dnnl_aBdeC16b4c, + dnnl_gIhwo8i = dnnl_aCdeb8c, + dnnl_gIhwO8i2o = dnnl_aCdeB8c2b, + dnnl_gIhwO8i4o = dnnl_aCdeB8c4b, + dnnl_gIhwo16i = dnnl_aCdeb16c, + dnnl_gIhwO16i2o = dnnl_aCdeB16c2b, + dnnl_gIhwO16i4o = dnnl_aCdeB16c4b, + dnnl_gIhwo24i = dnnl_aCdeb24c, + dnnl_gIhwO24i2o = dnnl_aCdeB24c2b, + dnnl_gIhwO24i4o = dnnl_aCdeB24c4b, + dnnl_gOhwi32o = dnnl_aBdec32b, + dnnl_gOhwi24o = dnnl_aBdec24b, + dnnl_gOhwI24o2i = dnnl_aBdeC24b2c, + dnnl_gOhwI24o4i = dnnl_aBdeC24b4c, + dnnl_gOhwi4o = dnnl_aBdec4b, + dnnl_gOhwi8o = dnnl_aBdec8b, + dnnl_gOhwI8o2i = dnnl_aBdeC8b2c, + dnnl_gOhwI8o4i = dnnl_aBdeC8b4c, + dnnl_Goihw16g = dnnl_Abcde16a, + dnnl_gOIhw16i16o = dnnl_aBCde16c16b, + dnnl_gOIhw16o16i = dnnl_aBCde16b16c, + dnnl_gOihw16o = dnnl_aBcde16b, + dnnl_gOIhw2i8o4i = dnnl_aBCde2c8b4c, + dnnl_gOIhw4i16o4i = dnnl_aBCde4c16b4c, + dnnl_gOIhw16i16o4i = dnnl_aBCde16c16b4c, + dnnl_gOIhw16i16o2i = dnnl_aBCde16c16b2c, + dnnl_gOIhw16o16i2o = dnnl_aBCde16b16c2b, + dnnl_gOIhw4i4o = dnnl_aBCde4c4b, + dnnl_gOIhw4o4i = dnnl_aBCde4b4c, + dnnl_gOihw4o = dnnl_aBcde4b, + dnnl_Goihw8g = dnnl_Abcde8a, + dnnl_Goihw4g = dnnl_Abcde4a, + dnnl_gOIhw8i16o2i = dnnl_aBCde8c16b2c, + dnnl_gOIhw8i8o = dnnl_aBCde8c8b, + dnnl_gOIhw8o16i2o = dnnl_aBCde8b16c2b, + dnnl_gIOhw8o16i2o = dnnl_aCBde8b16c2b, + dnnl_gOIhw8o8i = 
dnnl_aBCde8b8c, + dnnl_gOIhw8o4i = dnnl_aBCde8b4c, + dnnl_Goihw32g = dnnl_Abcde32a, + dnnl_gOwhi16o = dnnl_aBedc16b, + dnnl_goIhw4i = dnnl_abCde4c, + dnnl_goIhw32i = dnnl_abCde32c, + + dnnl_OIw4o8i8o4i = dnnl_ABc4a8b8a4b, + dnnl_OIhw4o8i8o4i = dnnl_ABcd4a8b8a4b, + dnnl_IOw4i8o8i4o = dnnl_BAc4b8a8b4a, + dnnl_IOhw4i8o8i4o = dnnl_BAcd4b8a8b4a, + dnnl_IOdhw4i8o8i4o = dnnl_BAcde4b8a8b4a, + + dnnl_OIhw2o8i8o2i = dnnl_ABcd2a8b8a2b, + dnnl_gOIw4o8i8o4i = dnnl_aBCd4b8c8b4c, + dnnl_gOIhw4o8i8o4i = dnnl_aBCde4b8c8b4c, + dnnl_gOIdhw4o8i8o4i = dnnl_aBCdef4b8c8b4c, + dnnl_gIOw4i8o8i4o = dnnl_aCBd4c8b8c4b, + dnnl_gIOhw4i8o8i4o = dnnl_aCBde4c8b8c4b, + dnnl_gIOdhw4i8o8i4o = dnnl_aCBdef4c8b8c4b, + dnnl_gOIhw2o8i8o2i = dnnl_aBCde2b8c8b2c, + dnnl_gOIhw2i4o2i = dnnl_aBCde2c4b2c, + dnnl_gOIhw2o4i2o = dnnl_aBCde2b4c2b, + dnnl_gOIhw4i8o2i = dnnl_aBCde4c8b2c, + dnnl_gOIhw4o8i2o = dnnl_aBCde4b8c2b, + + // weights w/ groups, 6D + dnnl_gIOdhw16i16o = dnnl_aCBdef16c16b, + dnnl_gIOdhw8o8i = dnnl_aCBdef8b8c, + dnnl_gIOdhw16o16i = dnnl_aCBdef16b16c, + dnnl_gOdhwi16o = dnnl_aBdefc16b, + dnnl_gOdhwI16o2i = dnnl_aBdefC16b2c, + dnnl_gOdhwI16o4i = dnnl_aBdefC16b4c, + dnnl_gIdhwo8i = dnnl_aCdefb8c, + dnnl_gIdhwO8i2o = dnnl_aCdefB8c2b, + dnnl_gIdhwO8i4o = dnnl_aCdefB8c4b, + dnnl_gIdhwo16i = dnnl_aCdefb16c, + dnnl_gIdhwO16i2o = dnnl_aCdefB16c2b, + dnnl_gIdhwO16i4o = dnnl_aCdefB16c4b, + dnnl_gIdhwo24i = dnnl_aCdefb24c, + dnnl_gIdhwO24i2o = dnnl_aCdefB24c2b, + dnnl_gIdhwO24i4o = dnnl_aCdefB24c4b, + dnnl_gOdhwi4o = dnnl_aBdefc4b, + dnnl_gOdhwi8o = dnnl_aBdefc8b, + dnnl_gOdhwI8o2i = dnnl_aBdefC8b2c, + dnnl_gOdhwI8o4i = dnnl_aBdefC8b4c, + dnnl_gOdwhi16o = dnnl_aBdfec16b, + dnnl_gOIdhw16i16o = dnnl_aBCdef16c16b, + dnnl_gOIdhw4i16o4i = dnnl_aBCdef4c16b4c, + dnnl_gOIdhw16i16o4i = dnnl_aBCdef16c16b4c, + dnnl_gOIdhw2i8o4i = dnnl_aBCdef2c8b4c, + dnnl_gOIdhw16i16o2i = dnnl_aBCdef16c16b2c, + dnnl_gOIdhw16o16i = dnnl_aBCdef16b16c, + dnnl_gOIdhw16o16i2o = dnnl_aBCdef16b16c2b, + dnnl_gOidhw16o = dnnl_aBcdef16b, + 
dnnl_gOIdhw4i4o = dnnl_aBCdef4c4b, + dnnl_gOIdhw4o4i = dnnl_aBCdef4b4c, + dnnl_gOidhw4o = dnnl_aBcdef4b, + dnnl_gOIdhw8i16o2i = dnnl_aBCdef8c16b2c, + dnnl_gOIdhw8i8o = dnnl_aBCdef8c8b, + dnnl_gOIdhw8o16i2o = dnnl_aBCdef8b16c2b, + dnnl_gIOdhw8o16i2o = dnnl_aCBdef8b16c2b, + dnnl_gOIdhw8o8i = dnnl_aBCdef8b8c, + dnnl_gOIdhw8o4i = dnnl_aBCdef8b4c, + dnnl_Goidhw16g = dnnl_Abcdef16a, + dnnl_Goidhw32g = dnnl_Abcdef32a, + dnnl_gOIdhw2i4o2i = dnnl_aBCdef2c4b2c, + dnnl_gOIdhw4i8o2i = dnnl_aBCdef4c8b2c, + dnnl_gOIdhw2o4i2o = dnnl_aBCdef2b4c2b, + dnnl_gOIdhw4o8i2o = dnnl_aBCdef4b8c2b, + dnnl_goIdhw4i = dnnl_abCdef4c, + dnnl_goIdhw32i = dnnl_abCdef32c, + + // weights, 3D + dnnl_Owi24o = dnnl_Acb24a, + dnnl_OwI24o2i = dnnl_AcB24a2b, + dnnl_OwI24o4i = dnnl_AcB24a4b, + dnnl_Owi32o = dnnl_Acb32a, + dnnl_OwI32o2i = dnnl_AcB32a2b, + dnnl_OwI32o4i = dnnl_AcB32a4b, + dnnl_Owi48o = dnnl_Acb48a, + dnnl_OwI48o2i = dnnl_AcB48a2b, + dnnl_OwI48o4i = dnnl_AcB48a4b, + dnnl_Owi64o = dnnl_Acb64a, + dnnl_OwI64o2i = dnnl_AcB64a2b, + dnnl_OwI64o4i = dnnl_AcB64a4b, + dnnl_Iwo32i = dnnl_Bca32b, + dnnl_IwO32i2o = dnnl_BcA32b2a, + dnnl_IwO32i4o = dnnl_BcA32b4a, + dnnl_Iwo48i = dnnl_Bca48b, + dnnl_IwO48i2o = dnnl_BcA48b2a, + dnnl_IwO48i4o = dnnl_BcA48b4a, + dnnl_Iwo64i = dnnl_Bca64b, + dnnl_IwO64i2o = dnnl_BcA64b2a, + dnnl_IwO64i4o = dnnl_BcA64b4a, + dnnl_wIo2i = dnnl_cBa2b, + dnnl_wIo4i = dnnl_cBa4b, + dnnl_gOwi24o = dnnl_aBdc24b, + dnnl_gOwI24o2i = dnnl_aBdC24b2c, + dnnl_gOwI24o4i = dnnl_aBdC24b4c, + dnnl_gOwi32o = dnnl_aBdc32b, + dnnl_gOwI32o2i = dnnl_aBdC32b2c, + dnnl_gOwI32o4i = dnnl_aBdC32b4c, + dnnl_gOwi48o = dnnl_aBdc48b, + dnnl_gOwI48o2i = dnnl_aBdC48b2c, + dnnl_gOwI48o4i = dnnl_aBdC48b4c, + dnnl_gOwi64o = dnnl_aBdc64b, + dnnl_gOwI64o2i = dnnl_aBdC64b2c, + dnnl_gOwI64o4i = dnnl_aBdC64b4c, + dnnl_gIwo32i = dnnl_aCdb32c, + dnnl_gIwO32i2o = dnnl_aCdB32c2b, + dnnl_gIwO32i4o = dnnl_aCdB32c4b, + dnnl_gIwo48i = dnnl_aCdb48c, + dnnl_gIwO48i2o = dnnl_aCdB48c2b, + dnnl_gIwO48i4o = dnnl_aCdB48c4b, + 
dnnl_gIwo64i = dnnl_aCdb64c, + dnnl_gIwO64i2o = dnnl_aCdB64c2b, + dnnl_gIwO64i4o = dnnl_aCdB64c4b, + dnnl_gwio = dnnl_adcb, + dnnl_gwIo2i = dnnl_adCb2c, + dnnl_gwIo4i = dnnl_adCb4c, + // weights, 4D + dnnl_OhwI24o = dnnl_Acdb24a, + dnnl_OhwI24o2i = dnnl_AcdB24a2b, + dnnl_OhwI24o4i = dnnl_AcdB24a4b, + dnnl_OhwI32o = dnnl_Acdb32a, + dnnl_OhwI32o2i = dnnl_AcdB32a2b, + dnnl_OhwI32o4i = dnnl_AcdB32a4b, + dnnl_Ohwi48o = dnnl_Acdb48a, + dnnl_OhwI48o2i = dnnl_AcdB48a2b, + dnnl_OhwI48o4i = dnnl_AcdB48a4b, + dnnl_Ohwi64o = dnnl_Acdb64a, + dnnl_OhwI64o2i = dnnl_AcdB64a2b, + dnnl_OhwI64o4i = dnnl_AcdB64a4b, + dnnl_Ihwo32i = dnnl_Bcda32b, + dnnl_IhwO32i2o = dnnl_BcdA32b2a, + dnnl_IhwO32i4o = dnnl_BcdA32b4a, + dnnl_Ihwo48i = dnnl_Bcda48b, + dnnl_IhwO48i2o = dnnl_BcdA48b2a, + dnnl_IhwO48i4o = dnnl_BcdA48b4a, + dnnl_Ihwo64i = dnnl_Bcda64b, + dnnl_IhwO64i2o = dnnl_BcdA64b2a, + dnnl_IhwO64i4o = dnnl_BcdA64b4a, + dnnl_hwIo2i = dnnl_cdBa2b, + dnnl_hwIo4i = dnnl_cdBa4b, + dnnl_gOhwI24o = dnnl_aBdec24b, + dnnl_gOhwI32o = dnnl_aBdec32b, + dnnl_gOhwI32o2i = dnnl_aBdeC32b2c, + dnnl_gOhwI32o4i = dnnl_aBdeC32b4c, + dnnl_gOhwi48o = dnnl_aBdec48b, + dnnl_gOhwI48o2i = dnnl_aBdeC48b2c, + dnnl_gOhwI48o4i = dnnl_aBdeC48b4c, + dnnl_gOhwi64o = dnnl_aBdec64b, + dnnl_gOhwI64o2i = dnnl_aBdeC64b2c, + dnnl_gOhwI64o4i = dnnl_aBdeC64b4c, + dnnl_gIhwo32i = dnnl_aCdeb32c, + dnnl_gIhwO32i2o = dnnl_aCdeB32c2b, + dnnl_gIhwO32i4o = dnnl_aCdeB32c4b, + dnnl_gIhwo48i = dnnl_aCdeb48c, + dnnl_gIhwO48i2o = dnnl_aCdeB48c2b, + dnnl_gIhwO48i4o = dnnl_aCdeB48c4b, + dnnl_gIhwo64i = dnnl_aCdeb64c, + dnnl_gIhwO64i2o = dnnl_aCdeB64c2b, + dnnl_gIhwO64i4o = dnnl_aCdeB64c4b, + dnnl_ghwio = dnnl_adecb, + dnnl_ghwIo2i = dnnl_adeCb2c, + dnnl_ghwIo4i = dnnl_adeCb4c, + // weights, 5D + dnnl_Odhwi24o = dnnl_Acdeb24a, + dnnl_OdhwI24o2i = dnnl_AcdeB24a2b, + dnnl_OdhwI24o4i = dnnl_AcdeB24a4b, + dnnl_Odhwi32o = dnnl_Acdeb32a, + dnnl_OdhwI32o2i = dnnl_AcdeB32a2b, + dnnl_OdhwI32o4i = dnnl_AcdeB32a4b, + dnnl_Odhwi48o = dnnl_Acdeb48a, + 
dnnl_OdhwI48o2i = dnnl_AcdeB48a2b, + dnnl_OdhwI48o4i = dnnl_AcdeB48a4b, + dnnl_Odhwi64o = dnnl_Acdeb64a, + dnnl_OdhwI64o2i = dnnl_AcdeB64a2b, + dnnl_OdhwI64o4i = dnnl_AcdeB64a4b, + dnnl_Idhwo32i = dnnl_Bcdea32b, + dnnl_IdhwO32i2o = dnnl_BcdeA32b2a, + dnnl_IdhwO32i4o = dnnl_BcdeA32b4a, + dnnl_Idhwo48i = dnnl_Bcdea48b, + dnnl_IdhwO48i2o = dnnl_BcdeA48b2a, + dnnl_IdhwO48i4o = dnnl_BcdeA48b4a, + dnnl_Idhwo64i = dnnl_Bcdea64b, + dnnl_IdhwO64i2o = dnnl_BcdeA64b2a, + dnnl_IdhwO64i4o = dnnl_BcdeA64b4a, + dnnl_dhwIo2i = dnnl_cdeBa2b, + dnnl_dhwIo4i = dnnl_cdeBa4b, + dnnl_gOdhwi24o = dnnl_aBdefc24b, + dnnl_gOdhwI24o2i = dnnl_aBdefC24b2c, + dnnl_gOdhwI24o4i = dnnl_aBdefC24b4c, + dnnl_gOdhwi32o = dnnl_aBdefc32b, + dnnl_gOdhwI32o2i = dnnl_aBdefC32b2c, + dnnl_gOdhwI32o4i = dnnl_aBdefC32b4c, + dnnl_gOdhwi48o = dnnl_aBdefc48b, + dnnl_gOdhwI48o2i = dnnl_aBdefC48b2c, + dnnl_gOdhwI48o4i = dnnl_aBdefC48b4c, + dnnl_gOdhwi64o = dnnl_aBdefc64b, + dnnl_gOdhwI64o2i = dnnl_aBdefC64b2c, + dnnl_gOdhwI64o4i = dnnl_aBdefC64b4c, + dnnl_gIdhwo32i = dnnl_aCdefb32c, + dnnl_gIdhwO32i2o = dnnl_aCdefB32c2b, + dnnl_gIdhwO32i4o = dnnl_aCdefB32c4b, + dnnl_gIdhwo48i = dnnl_aCdefb48c, + dnnl_gIdhwO48i2o = dnnl_aCdefB48c2b, + dnnl_gIdhwO48i4o = dnnl_aCdefB48c4b, + dnnl_gIdhwo64i = dnnl_aCdefb64c, + dnnl_gIdhwO64i2o = dnnl_aCdefB64c2b, + dnnl_gIdhwO64i4o = dnnl_aCdefB64c4b, + dnnl_gdhwio = dnnl_adefcb, + dnnl_gdhwIo2i = dnnl_adefCb2c, + dnnl_gdhwIo4i = dnnl_adefCb4c, + dnnl_OI16i32o4i = dnnl_AB16b32a4b, + dnnl_OI16i48o4i = dnnl_AB16b48a4b, + dnnl_OI16i64o4i = dnnl_AB16b64a4b, + dnnl_OI16i16o2i = dnnl_AB16b16a2b, + dnnl_OI16i32o2i = dnnl_AB16b32a2b, + dnnl_OI16i48o2i = dnnl_AB16b48a2b, + dnnl_OI16i64o2i = dnnl_AB16b64a2b, + dnnl_OIw16i32o4i = dnnl_ABc16b32a4b, + dnnl_OIw16i48o4i = dnnl_ABc16b48a4b, + dnnl_OIw16i64o4i = dnnl_ABc16b64a4b, + dnnl_OIw16i32o2i = dnnl_ABc16b32a2b, + dnnl_OIw16i48o2i = dnnl_ABc16b48a2b, + dnnl_OIw16i64o2i = dnnl_ABc16b64a2b, + dnnl_OIhw16i32o4i = dnnl_ABcd16b32a4b, + 
dnnl_OIhw16i48o4i = dnnl_ABcd16b48a4b, + dnnl_OIhw16i64o4i = dnnl_ABcd16b64a4b, + dnnl_OIhw16i32o2i = dnnl_ABcd16b32a2b, + dnnl_OIhw16i48o2i = dnnl_ABcd16b48a2b, + dnnl_OIhw16i64o2i = dnnl_ABcd16b64a2b, + dnnl_OIdhw16i32o4i = dnnl_ABcde16b32a4b, + dnnl_OIdhw16i48o4i = dnnl_ABcde16b48a4b, + dnnl_OIdhw16i64o4i = dnnl_ABcde16b64a4b, + dnnl_OIdhw16i32o2i = dnnl_ABcde16b32a2b, + dnnl_OIdhw16i48o2i = dnnl_ABcde16b48a2b, + dnnl_OIdhw16i64o2i = dnnl_ABcde16b64a2b, + dnnl_OwI16i16o2i = dnnl_AcB16b16a2b, + dnnl_OwI16i16o4i = dnnl_AcB16b16a4b, + dnnl_OhwI16i16o2i = dnnl_AcdB16b16a2b, + dnnl_OhwI16i16o4i = dnnl_AcdB16b16a4b, + dnnl_OdhwI16i16o2i = dnnl_AcdeB16b16a2b, + dnnl_OdhwI16i16o4i = dnnl_AcdeB16b16a4b, + dnnl_IwO16o16i2o = dnnl_BcA16a16b2a, + dnnl_IwO16o16i4o = dnnl_BcA16a16b4a, + dnnl_IhwO16o16i2o = dnnl_BcdA16a16b2a, + dnnl_IhwO16o16i4o = dnnl_BcdA16a16b4a, + dnnl_IdhwO16o16i2o = dnnl_BcdeA16a16b2a, + dnnl_IdhwO16o16i4o = dnnl_BcdeA16a16b4a, + dnnl_gOwI16i16o2i = dnnl_aBdC16c16b2c, + dnnl_gOwI16i16o4i = dnnl_aBdC16c16b4c, + dnnl_gOhwI16i16o2i = dnnl_aBdeC16c16b2c, + dnnl_gOhwI16i16o4i = dnnl_aBdeC16c16b4c, + dnnl_gOdhwI16i16o2i = dnnl_aBdefC16c16b2c, + dnnl_gOdhwI16i16o4i = dnnl_aBdefC16c16b4c, + dnnl_gIwO16o16i2o = dnnl_aCdB16b16c2b, + dnnl_gIwO16o16i4o = dnnl_aCdB16b16c4b, + dnnl_gIhwO16o16i2o = dnnl_aCdeB16b16c2b, + dnnl_gIhwO16o16i4o = dnnl_aCdeB16b16c4b, + dnnl_gIdhwO16o16i2o = dnnl_aCdefB16b16c2b, + dnnl_gIdhwO16o16i4o = dnnl_aCdefB16b16c4b, + dnnl_OwI16i32o2i = dnnl_AcB16b32a2b, + dnnl_OwI16i32o4i = dnnl_AcB16b32a4b, + dnnl_OwI16i48o2i = dnnl_AcB16b48a2b, + dnnl_OwI16i48o4i = dnnl_AcB16b48a4b, + dnnl_OwI16i64o2i = dnnl_AcB16b64a2b, + dnnl_OwI16i64o4i = dnnl_AcB16b64a4b, + dnnl_IwO16o32i2o = dnnl_BcA16a32b2a, + dnnl_IwO16o32i4o = dnnl_BcA16a32b4a, + dnnl_IwO16o48i2o = dnnl_BcA16a48b2a, + dnnl_IwO16o48i4o = dnnl_BcA16a48b4a, + dnnl_IwO16o64i2o = dnnl_BcA16a64b2a, + dnnl_IwO16o64i4o = dnnl_BcA16a64b4a, + dnnl_gOwI16i32o2i = dnnl_aBdC16c32b2c, + dnnl_gOwI16i32o4i = 
dnnl_aBdC16c32b4c, + dnnl_gOwI16i48o2i = dnnl_aBdC16c48b2c, + dnnl_gOwI16i48o4i = dnnl_aBdC16c48b4c, + dnnl_gOwI16i64o2i = dnnl_aBdC16c64b2c, + dnnl_gOwI16i64o4i = dnnl_aBdC16c64b4c, + dnnl_gIwO16o32i2o = dnnl_aCdB16b32c2b, + dnnl_gIwO16o32i4o = dnnl_aCdB16b32c4b, + dnnl_gIwO16o48i2o = dnnl_aCdB16b48c2b, + dnnl_gIwO16o48i4o = dnnl_aCdB16b48c4b, + dnnl_gIwO16o64i2o = dnnl_aCdB16b64c2b, + dnnl_gIwO16o64i4o = dnnl_aCdB16b64c4b, + dnnl_OhwI16i32o2i = dnnl_AcdB16b32a2b, + dnnl_OhwI16i32o4i = dnnl_AcdB16b32a4b, + dnnl_OhwI16i48o2i = dnnl_AcdB16b48a2b, + dnnl_OhwI16i48o4i = dnnl_AcdB16b48a4b, + dnnl_OhwI16i64o2i = dnnl_AcdB16b64a2b, + dnnl_OhwI16i64o4i = dnnl_AcdB16b64a4b, + dnnl_IhwO16o32i2o = dnnl_BcdA16a32b2a, + dnnl_IhwO16o32i4o = dnnl_BcdA16a32b4a, + dnnl_IhwO16o48i2o = dnnl_BcdA16a48b2a, + dnnl_IhwO16o48i4o = dnnl_BcdA16a48b4a, + dnnl_IhwO16o64i2o = dnnl_BcdA16a64b2a, + dnnl_IhwO16o64i4o = dnnl_BcdA16a64b4a, + dnnl_gOhwI16i32o2i = dnnl_aBdeC16c32b2c, + dnnl_gOhwI16i32o4i = dnnl_aBdeC16c32b4c, + dnnl_gOhwI16i48o2i = dnnl_aBdeC16c48b2c, + dnnl_gOhwI16i48o4i = dnnl_aBdeC16c48b4c, + dnnl_gOhwI16i64o2i = dnnl_aBdeC16c64b2c, + dnnl_gOhwI16i64o4i = dnnl_aBdeC16c64b4c, + dnnl_gIhwO16o32i2o = dnnl_aCdeB16b32c2b, + dnnl_gIhwO16o32i4o = dnnl_aCdeB16b32c4b, + dnnl_gIhwO16o48i2o = dnnl_aCdeB16b48c2b, + dnnl_gIhwO16o48i4o = dnnl_aCdeB16b48c4b, + dnnl_gIhwO16o64i2o = dnnl_aCdeB16b64c2b, + dnnl_gIhwO16o64i4o = dnnl_aCdeB16b64c4b, + dnnl_OdhwI16i32o2i = dnnl_AcdeB16b32a2b, + dnnl_OdhwI16i32o4i = dnnl_AcdeB16b32a4b, + dnnl_OdhwI16i48o2i = dnnl_AcdeB16b48a2b, + dnnl_OdhwI16i48o4i = dnnl_AcdeB16b48a4b, + dnnl_OdhwI16i64o2i = dnnl_AcdeB16b64a2b, + dnnl_OdhwI16i64o4i = dnnl_AcdeB16b64a4b, + dnnl_IdhwO16o32i2o = dnnl_BcdeA16a32b2a, + dnnl_IdhwO16o32i4o = dnnl_BcdeA16a32b4a, + dnnl_IdhwO16o48i2o = dnnl_BcdeA16a48b2a, + dnnl_IdhwO16o48i4o = dnnl_BcdeA16a48b4a, + dnnl_IdhwO16o64i2o = dnnl_BcdeA16a64b2a, + dnnl_IdhwO16o64i4o = dnnl_BcdeA16a64b4a, + dnnl_gOdhwI16i32o2i = dnnl_aBdefC16c32b2c, + 
dnnl_gOdhwI16i32o4i = dnnl_aBdefC16c32b4c, + dnnl_gOdhwI16i48o2i = dnnl_aBdefC16c48b2c, + dnnl_gOdhwI16i48o4i = dnnl_aBdefC16c48b4c, + dnnl_gOdhwI16i64o2i = dnnl_aBdefC16c64b2c, + dnnl_gOdhwI16i64o4i = dnnl_aBdefC16c64b4c, + dnnl_gIdhwO16o32i2o = dnnl_aCdefB16b32c2b, + dnnl_gIdhwO16o32i4o = dnnl_aCdefB16b32c4b, + dnnl_gIdhwO16o48i2o = dnnl_aCdefB16b48c2b, + dnnl_gIdhwO16o48i4o = dnnl_aCdefB16b48c4b, + dnnl_gIdhwO16o64i2o = dnnl_aCdefB16b64c2b, + dnnl_gIdhwO16o64i4o = dnnl_aCdefB16b64c4b, + dnnl_hwioG16g = dnnl_decbA16a, + dnnl_hwioG8g = dnnl_decbA8a, + dnnl_dhwioG16g = dnnl_defcbA16a, + dnnl_dhwioG8g = dnnl_defcbA8a, + dnnl_NCdhw40n16c = dnnl_ABcde40a16b, + dnnl_NCw40n16c = dnnl_ABc40a16b, + dnnl_NChw40n16c = dnnl_ABcd40a16b, + dnnl_NCw40n32c = dnnl_ABc40a32b, + dnnl_NChw40n32c = dnnl_ABcd40a32b, + dnnl_NCdhw40n32c = dnnl_ABcde40a32b, + dnnl_OIdhw4o8i8o2i = dnnl_ABcde4a8b8a2b, + dnnl_OIhw4o8i8o2i = dnnl_ABcd4a8b8a2b, + dnnl_OIw4o8i8o2i = dnnl_ABc4a8b8a2b, + dnnl_gOIdhw4o8i8o2i = dnnl_aBCdef4b8c8b2c, + dnnl_gOIhw4o8i8o2i = dnnl_aBCde4b8c8b2c, + dnnl_gOIw4o8i8o2i = dnnl_aBCd4b8c8b2c, + dnnl_IOdhw4i8o8i2o = dnnl_BAcde4b8a8b2a, + dnnl_IOhw4i8o8i2o = dnnl_BAcd4b8a8b2a, + dnnl_IOw4i8o8i2o = dnnl_BAc4b8a8b2a, + dnnl_gIOdhw4i8o8i2o = dnnl_aCBdef4c8b8c2b, + dnnl_gIOhw4i8o8i2o = dnnl_aCBde4c8b8c2b, + dnnl_gIOw4i8o8i2o = dnnl_aCBd4c8b8c2b, + dnnl_NCw2c32n8c = dnnl_ABc2b32a8b, + dnnl_NChw2c32n8c = dnnl_ABcd2b32a8b, + dnnl_NCdhw2c32n8c = dnnl_ABcde2b32a8b, + dnnl_OIw2i8o16i4o = dnnl_ABc2b8a16b4a, + dnnl_OIhw2i8o16i4o = dnnl_ABcd2b8a16b4a, + dnnl_OIdhw2i8o16i4o = dnnl_ABcde2b8a16b4a, + dnnl_OIw2o8i16o4i = dnnl_ABc2a8b16a4b, + dnnl_OIw2o8i16o2i = dnnl_ABc2a8b16a2b, + dnnl_IOw2i8o16i4o = dnnl_BAc2b8a16b4a, + dnnl_IOw2i8o16i2o = dnnl_BAc2b8a16b2a, + dnnl_OIhw2o8i16o4i = dnnl_ABcd2a8b16a4b, + dnnl_OIhw2o8i16o2i = dnnl_ABcd2a8b16a2b, + dnnl_IOhw2i8o16i4o = dnnl_BAcd2b8a16b4a, + dnnl_IOhw2i8o16i2o = dnnl_BAcd2b8a16b2a, + dnnl_OIdhw2o8i16o4i = dnnl_ABcde2a8b16a4b, + dnnl_OIdhw2o8i16o2i 
= dnnl_ABcde2a8b16a2b, + dnnl_IOdhw2i8o16i4o = dnnl_BAcde2b8a16b4a, + dnnl_IOdhw2i8o16i2o = dnnl_BAcde2b8a16b2a, + dnnl_gOIw2o8i16o2i = dnnl_aBCd2b8c16b2c, + dnnl_gIOw2i8o16i2o = dnnl_aCBd2c8b16c2b, + dnnl_gIOhw2i8o16i2o = dnnl_aBCde2c8b16c2b, + dnnl_gIOdhw2i8o16i2o = dnnl_aBCdef2c8b16c2b, + dnnl_gOIhw2o8i16o2i = dnnl_aBCde2b8c16b2c, + dnnl_gOIdhw2o8i16o2i = dnnl_aBCdef2b8c16b2c, + dnnl_gOIw2o8i16o4i = dnnl_aBCd2b8c16b4c, + dnnl_gOIhw2o8i16o4i = dnnl_aBCde2b8c16b4c, +} dnnl_format_tag_t; + +/// @} dnnl_api_memory + +/// @addtogroup dnnl_api_primitives +/// @{ +/// @addtogroup dnnl_api_primitives_common +/// @{ + +/// Kinds of propagation. +typedef enum { + // TODO: suggest renames + /// Undefined propagation type. + dnnl_prop_kind_undef = 0, + /// Forward data propagation (training mode). In this mode primitives + /// perform computations necessary for subsequent backward propagation. + dnnl_forward_training = 64, + /// Forward data propagation (inference mode). In this mode primitives + /// perform only computations that are necessary for inference and omit + /// computations that are necessary only for backward propagation. + dnnl_forward_inference = 96, + /// Forward data propagation (alias for @c dnnl_forward_training). + dnnl_forward = dnnl_forward_training, + /// Backward propagation (with respect to all parameters). + dnnl_backward = 128, + /// Backward data propagation. + dnnl_backward_data = 160, + /// Backward weights propagation. + dnnl_backward_weights = 192, + /// Backward bias propagation. + dnnl_backward_bias = 193, +} dnnl_prop_kind_t; + +/// Kinds of primitives. Used to implement a way to extend the library with new +/// primitives without changing the ABI. +typedef enum { + /// Undefined primitive + dnnl_undefined_primitive, + /// A reorder primitive. + dnnl_reorder, + /// A shuffle primitive. + dnnl_shuffle, + /// A (out-of-place) concat primitive. + dnnl_concat, + /// A sum primitive. + dnnl_sum, + /// A convolution primitive. 
+ dnnl_convolution, + /// A deconvolution primitive. + dnnl_deconvolution, + /// An element-wise primitive. + dnnl_eltwise, + /// An LRN primitive. + dnnl_lrn, + /// A batch normalization primitive. + dnnl_batch_normalization, + /// An inner product primitive. + dnnl_inner_product, + /// A rnn primitive. + dnnl_rnn, + /// A matrix multiplication primitive (internal). + dnnl_gemm, + /// A binary primitive. + dnnl_binary, + /// A matrix multiplication primitive. + dnnl_matmul, + /// A resampling primitive. + dnnl_resampling, + /// A pooling primitive. + dnnl_pooling, + /// A reduction primitive. + dnnl_reduction, + /// A PReLU primitive. + dnnl_prelu, + /// A softmax primitive. + dnnl_softmax, + /// A layer normalization primitive. + dnnl_layer_normalization, + /// A group normalization primitive. + dnnl_group_normalization, + + /// Parameter to allow internal only primitives without undefined behavior. + /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. + dnnl_primitive_kind_max = 0x7fff, +} dnnl_primitive_kind_t; + +/// Kinds of algorithms. 
+typedef enum { + dnnl_alg_kind_undef, + /// Direct convolution + dnnl_convolution_direct = 0x1, + /// Winograd convolution + dnnl_convolution_winograd = 0x2, + /// Convolution algorithm(either direct or Winograd) is chosen just in time + dnnl_convolution_auto = 0x3, + /// Direct deconvolution + dnnl_deconvolution_direct = 0xa, + /// Winograd deconvolution + dnnl_deconvolution_winograd = 0xb, + /// Eltwise: ReLU + dnnl_eltwise_relu = 0x20, + /// Eltwise: hyperbolic tangent non-linearity (tanh) + dnnl_eltwise_tanh, + /// Eltwise: exponential linear unit (elu) + dnnl_eltwise_elu, + /// Eltwise: square + dnnl_eltwise_square, + /// Eltwise: abs + dnnl_eltwise_abs, + /// Eltwise: square root + dnnl_eltwise_sqrt, + /// Eltwise: linear + dnnl_eltwise_linear, + /// Eltwise: soft_relu + dnnl_eltwise_soft_relu, + /// Eltwise: hardsigmoid + dnnl_eltwise_hardsigmoid, + /// Eltwise: logistic + dnnl_eltwise_logistic, + /// Eltwise: exponent + dnnl_eltwise_exp, + /// Eltwise: gelu + /// + /// @note Tanh approximation formula is used to approximate + /// the cumulative distribution function of a Gaussian here + dnnl_eltwise_gelu_tanh, + /// Eltwise: swish + dnnl_eltwise_swish, + /// Eltwise: natural logarithm + dnnl_eltwise_log, + /// Eltwise: clip + dnnl_eltwise_clip, + /// Eltwise: clip version 2 + dnnl_eltwise_clip_v2, + /// Eltwise: pow + dnnl_eltwise_pow, + /// Eltwise: erf-based gelu + dnnl_eltwise_gelu_erf, + /// Eltwise: round + dnnl_eltwise_round, + /// Eltwise: mish + dnnl_eltwise_mish, + /// Eltwise: hardswish + dnnl_eltwise_hardswish, + /// Eltwise: ReLU (dst for backward) + dnnl_eltwise_relu_use_dst_for_bwd = 0x100, + /// Eltwise: hyperbolic tangent non-linearity (tanh) (dst for backward) + dnnl_eltwise_tanh_use_dst_for_bwd, + /// Eltwise: exponential linear unit (elu) (dst for backward) + dnnl_eltwise_elu_use_dst_for_bwd, + /// Eltwise: square root (dst for backward) + dnnl_eltwise_sqrt_use_dst_for_bwd, + /// Eltwise: logistic (dst for backward) + 
dnnl_eltwise_logistic_use_dst_for_bwd, + /// Eltwise: exp (dst for backward) + dnnl_eltwise_exp_use_dst_for_bwd, + /// Eltwise: clip version 2 (dst for backward) + dnnl_eltwise_clip_v2_use_dst_for_bwd, + /// Max pooling + dnnl_pooling_max = 0x1ff, + /// Average pooling include padding + dnnl_pooling_avg_include_padding = 0x2ff, + /// Average pooling exclude padding + dnnl_pooling_avg_exclude_padding = 0x3ff, + /// Local response normalization (LRN) across multiple channels + dnnl_lrn_across_channels = 0xaff, + /// LRN within a single channel + dnnl_lrn_within_channel = 0xbff, + /// RNN cell + dnnl_vanilla_rnn = 0x1fff, + /// LSTM cell + dnnl_vanilla_lstm = 0x2fff, + /// GRU cell + dnnl_vanilla_gru = 0x3fff, + /// GRU cell with linear before reset + /// + /// Modification of original GRU cell. Differs from #dnnl_vanilla_gru + /// in how the new memory gate is calculated: + /// \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f] + /// Primitive expects 4 biases on input: + /// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$ + dnnl_lbr_gru = 0x4fff, + /// AUGRU cell + dnnl_vanilla_augru = 0x5fff, + /// AUGRU cell with linear before reset + dnnl_lbr_augru = 0x6fff, + /// Binary add + dnnl_binary_add = 0x1fff0, + /// Binary mul + dnnl_binary_mul = 0x1fff1, + /// Binary max + dnnl_binary_max = 0x1fff2, + /// Binary min + dnnl_binary_min = 0x1fff3, + /// Binary div + dnnl_binary_div = 0x1fff4, + /// Binary sub + dnnl_binary_sub = 0x1fff5, + /// Binary greater or equal + dnnl_binary_ge = 0x1fff6, + /// Binary greater than + dnnl_binary_gt = 0x1fff7, + /// Binary less or equal + dnnl_binary_le = 0x1fff8, + /// Binary less than + dnnl_binary_lt = 0x1fff9, + /// Binary equal + dnnl_binary_eq = 0x1fffa, + /// Binary not equal + dnnl_binary_ne = 0x1fffb, + /// Binary select + dnnl_binary_select = 0x1fffc, + /// Nearest Neighbor Resampling Method + dnnl_resampling_nearest = 0x2fff0, + /// Linear Resampling Method + dnnl_resampling_linear = 0x2fff1, + /// Reduction using max 
+ dnnl_reduction_max, + /// Reduction using min + dnnl_reduction_min, + /// Reduction using sum + dnnl_reduction_sum, + /// Reduction using mul + dnnl_reduction_mul, + /// Reduction using mean + dnnl_reduction_mean, + /// Reduction using lp norm + dnnl_reduction_norm_lp_max, + /// Reduction using lp norm + dnnl_reduction_norm_lp_sum, + /// Reduction using lp norm without final pth-root + dnnl_reduction_norm_lp_power_p_max, + /// Reduction using lp norm without final pth-root + dnnl_reduction_norm_lp_power_p_sum, + /// Softmax + dnnl_softmax_accurate = 0x30000, + /// Logsoftmax + dnnl_softmax_log, +} dnnl_alg_kind_t; + +/// Flags for normalization primitives. +typedef enum { + /// Use no normalization flags + /// + /// If specified + /// - on forward training propagation mean and variance are computed and + /// stored as output + /// - on backward propagation compute full derivative wrt data + /// - on backward propagation prop_kind == #dnnl_backward_data has the same + /// behavior as prop_kind == #dnnl_backward + dnnl_normalization_flags_none = 0x0U, + + /// Use global statistics + /// + /// If specified + /// - on forward propagation use mean and variance provided by user (input) + /// - on backward propagation reduces the amount of computations, since + /// mean and variance are considered as constants + /// + /// If not specified: + /// - on forward propagation mean and variance are computed and stored as + /// output + /// - on backward propagation compute full derivative wrt data + dnnl_use_global_stats = 0x1U, + + /// Use scale parameter + /// + /// If specified: + /// - on forward propagation use scale for the normalization results + /// - on backward propagation (for prop_kind == #dnnl_backward) compute + /// diff wrt scale (hence one extra output used) + dnnl_use_scale = 0x2U, + + /// Use shift parameter + /// + /// If specified: + /// - on forward propagation use shift (aka bias) for the normalization + /// results + /// - on backward propagation (for 
prop_kind == #dnnl_backward) compute + /// diff wrt shift (hence one extra output used) + dnnl_use_shift = 0x4U, + + /// Fuse with ReLU + /// + /// The flag implies negative slope being 0. On training this is the only + /// configuration supported. For inference, to use non-zero negative slope + /// consider using @ref dev_guide_attributes_post_ops. + /// + /// If specified: + /// - on inference this option behaves the same as if the primitive were + /// fused with ReLU using post ops API with zero negative slope. + /// - on training primitive requires workspace (required to be able to + /// perform backward pass) + dnnl_fuse_norm_relu = 0x8U, + + /// Fuse with Add and then fuse with ReLU + /// + /// If specified: + /// + /// - on forward propagation apply element-wise binary Add operation to + /// to the normalization results with an additional input tensor and then + /// apply ReLU with negative slope being 0. + /// - on training primitive requires workspace (required to be able to + /// perform backward pass). + /// - on backward propagation save the result of backward ReLU operation + /// with input tensor and workspace from forward pass to extra output + /// tensor and then perform backward normalization. + dnnl_fuse_norm_add_relu = 0x10U, + +} dnnl_normalization_flags_t; + +/// @} dnnl_api_primitives_common +/// @} dnnl_api_primitives + +/// @addtogroup dnnl_api_memory +/// @{ + +/// A wildcard value for dimensions that are unknown at a primitive creation +/// time. +#define DNNL_RUNTIME_DIM_VAL INT64_MIN + +/// A `size_t` counterpart of the DNNL_RUNTIME_DIM_VAL. +/// For instance, this value is returned by dnnl_memory_desc_get_size() if +/// either of the dimensions or strides equal to #DNNL_RUNTIME_DIM_VAL. 
+#define DNNL_RUNTIME_SIZE_VAL ((size_t)DNNL_RUNTIME_DIM_VAL) + +/// @cond DO_NOT_DOCUMENT_THIS +/// Hex representation for a **special** quiet NAN (!= NAN from math.h) +static const union { + unsigned u; + float f; +} DNNL_RUNTIME_F32_VAL_REP = {0x7fc000d0}; +/// @endcond + +/// A wildcard value for floating point values that are unknown at a primitive +/// creation time. +#define DNNL_RUNTIME_F32_VAL (DNNL_RUNTIME_F32_VAL_REP.f) + +/// @cond DO_NOT_DOCUMENT_THIS +static const int DNNL_RUNTIME_S32_VAL_REP = INT32_MIN; +/// @endcond + +/// A wildcard value for int32_t values that are unknown at a primitive creation +/// time. +#define DNNL_RUNTIME_S32_VAL DNNL_RUNTIME_S32_VAL_REP + +/// @struct dnnl_memory_desc +/// An opaque structure to describe a memory descriptor. +struct dnnl_memory_desc; + +/// A memory descriptor handle. +typedef struct dnnl_memory_desc *dnnl_memory_desc_t; + +/// A memory descriptor handle. +typedef const struct dnnl_memory_desc *const_dnnl_memory_desc_t; + +/// @struct dnnl_memory +/// An opaque structure to describe a memory. +struct dnnl_memory; + +/// A memory handle. +typedef struct dnnl_memory *dnnl_memory_t; + +/// A constant memory handle. +typedef const struct dnnl_memory *const_dnnl_memory_t; + +/// @} dnnl_api_memory + +/// @addtogroup dnnl_api_primitives +/// @{ + +/// @addtogroup dnnl_api_rnn +/// @{ + +/// Flags for RNN cell. +typedef enum { + /// Undefined RNN flags + dnnl_rnn_flags_undef = 0x0, + /// Do not add weights gradient to existing diff_weights memory + dnnl_rnn_flags_diff_weights_overwrite = 0x1, +} dnnl_rnn_flags_t; + +/// A direction of RNN primitive execution. +typedef enum { + /// Undefined RNN direction. + dnnl_rnn_direction_undef = 0, + /// Unidirectional execution of RNN primitive from left to right. + dnnl_unidirectional_left2right, + /// Unidirectional execution of RNN primitive from right to left. 
+ dnnl_unidirectional_right2left, + /// Bidirectional execution of RNN primitive with concatenation of the + /// results. + dnnl_bidirectional_concat, + /// Bidirectional execution of RNN primitive with summation of the + /// results. + dnnl_bidirectional_sum, +} dnnl_rnn_direction_t; + +/// @} dnnl_api_rnn + +/// @} dnnl_api_primitives + +/// @addtogroup dnnl_api_primitives +/// @{ +/// @addtogroup dnnl_api_primitives_common +/// @{ + +/// @struct dnnl_primitive_desc +/// @brief An opaque structure to describe a primitive descriptor. +struct dnnl_primitive_desc; + +/// @brief A primitive descriptor handle. +typedef struct dnnl_primitive_desc *dnnl_primitive_desc_t; + +/// @brief A constant primitive descriptor handle. +typedef const struct dnnl_primitive_desc *const_dnnl_primitive_desc_t; + +/// @} dnnl_api_primitives_common + +/// @addtogroup dnnl_api_attributes +/// @{ + +/// Scratchpad mode +typedef enum { + /// The library manages the scratchpad allocation according to the policy + /// specified by the `DNNL_ENABLE_CONCURRENT_EXEC` + /// [build option](@ref dev_guide_build_options) (default). + /// + /// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library + /// scratchpad is common to all primitives to reduce the memory footprint. + /// This configuration comes with limited thread-safety properties, namely + /// primitives can be created and executed in parallel but cannot migrate + /// between threads (in other words, each primitive should be executed in + /// the same thread it was created in). + /// + /// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is + /// private to each primitive. The memory footprint is larger than when + /// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF` but different primitives can be + /// created and run concurrently (the same primitive cannot be run + /// concurrently from two different threads though). 
+ dnnl_scratchpad_mode_library, + /// The user manages the scratchpad allocation by querying and providing + /// the scratchpad memory to primitives. This mode is thread-safe as long + /// as the scratchpad buffers are not used concurrently by two primitive + /// executions. + dnnl_scratchpad_mode_user, +} dnnl_scratchpad_mode_t; + +/// Rounding mode +typedef enum { + /// rounding mode dictated by the floating-point environment + dnnl_rounding_mode_environment, + /// stochastic rounding mode where a random bias is added to the + /// trailing mantissa bits before conversion. + dnnl_rounding_mode_stochastic, +} dnnl_rounding_mode_t; + +/// @struct dnnl_primitive_attr +/// @brief An opaque structure for primitive descriptor attributes. +/// +/// Attributes may contain: +/// - output scales (to scale the result prior to storing it to the memory) +struct dnnl_primitive_attr; + +/// @brief A primitive descriptor attributes handle that controls primitive +/// behavior. +typedef struct dnnl_primitive_attr *dnnl_primitive_attr_t; + +/// @brief A constant primitive descriptor attributes handle. +typedef const struct dnnl_primitive_attr *const_dnnl_primitive_attr_t; + +/// @struct dnnl_post_ops +/// @brief An opaque structure for a chain of post operations. +/// +/// dnnl_post_ops can be used to perform some (trivial) operations like +/// accumulation or eltwise after certain primitives like convolution. +/// +/// Post operations might be combined together, making a chain of post +/// operations. For instance one can configure convolution followed by +/// accumulation followed by eltwise. This might be especially beneficial +/// for residual learning blocks. +/// +/// @warning +/// Of course not all combinations are supported, so the user should handle +/// errors accordingly. +/// +/// Supported post operations: +/// - accumulation (base primitive: convolution) +/// - eltwise (base primitive: convolution) +struct dnnl_post_ops; + +/// @brief A post operation chain handle. 
+typedef struct dnnl_post_ops *dnnl_post_ops_t; + +/// @brief A constant post operation chain handle. +typedef const struct dnnl_post_ops *const_dnnl_post_ops_t; + +/// @} dnnl_api_attributes + +/// @addtogroup dnnl_api_primitives_common +/// @{ + +/// @struct dnnl_primitive +/// An opaque structure to describe a primitive. +struct dnnl_primitive; +/// A primitive handle. +typedef struct dnnl_primitive *dnnl_primitive_t; +/// A constant primitive handle. +typedef const struct dnnl_primitive *const_dnnl_primitive_t; + +/// Undefined argument. +#define DNNL_ARG_UNDEF 0 +/// Source argument #0. +#define DNNL_ARG_SRC_0 1 +/// A special mnemonic for source argument for primitives that have a +/// single source. An alias for #DNNL_ARG_SRC_0. +#define DNNL_ARG_SRC DNNL_ARG_SRC_0 +/// A special mnemonic for RNN input vector. An alias for +/// #DNNL_ARG_SRC_0. +#define DNNL_ARG_SRC_LAYER DNNL_ARG_SRC_0 +/// A special mnemonic for reorder source argument. An alias for +/// #DNNL_ARG_SRC_0. +#define DNNL_ARG_FROM DNNL_ARG_SRC_0 + +/// Source argument #1. +#define DNNL_ARG_SRC_1 2 +/// A special mnemonic for RNN input recurrent hidden state vector. An alias +/// for #DNNL_ARG_SRC_1. +#define DNNL_ARG_SRC_ITER DNNL_ARG_SRC_1 + +/// Source argument #2. +#define DNNL_ARG_SRC_2 3 +/// A special mnemonic for RNN input recurrent cell state vector. An alias for +/// #DNNL_ARG_SRC_2. +#define DNNL_ARG_SRC_ITER_C DNNL_ARG_SRC_2 + +/// Source argument #3. +#define DNNL_ARG_SRC_3 4 +/// A special mnemonic for RNN input recurrent cell attention vector. An alias for +/// #DNNL_ARG_SRC_3. +#define DNNL_ARG_AUGRU_ATTENTION DNNL_ARG_SRC_3 + +/// Destination argument #0. +#define DNNL_ARG_DST_0 17 +/// A special mnemonic for destination argument for primitives that have a +/// single destination. An alias for #DNNL_ARG_DST_0. +#define DNNL_ARG_DST DNNL_ARG_DST_0 +/// A special mnemonic for reorder destination argument. An alias for +/// #DNNL_ARG_DST_0. 
+#define DNNL_ARG_TO DNNL_ARG_DST_0 +/// A special mnemonic for RNN output vector. An alias for #DNNL_ARG_DST_0. +#define DNNL_ARG_DST_LAYER DNNL_ARG_DST_0 + +/// Destination argument #1. +#define DNNL_ARG_DST_1 18 +/// A special mnemonic for RNN input recurrent hidden state vector. An +/// alias for #DNNL_ARG_DST_1. +#define DNNL_ARG_DST_ITER DNNL_ARG_DST_1 + +/// Destination argument #2. +#define DNNL_ARG_DST_2 19 +/// A special mnemonic for LSTM output recurrent cell state vector. An +/// alias for #DNNL_ARG_DST_2. +#define DNNL_ARG_DST_ITER_C DNNL_ARG_DST_2 + +/// Weights argument #0. +#define DNNL_ARG_WEIGHTS_0 33 +/// A special mnemonic for primitives that have a single weights +/// argument. Alias for #DNNL_ARG_WEIGHTS_0. +#define DNNL_ARG_WEIGHTS DNNL_ARG_WEIGHTS_0 +/// A special mnemonic for RNN weights applied to the layer input. An +/// alias for #DNNL_ARG_WEIGHTS_0. +#define DNNL_ARG_WEIGHTS_LAYER DNNL_ARG_WEIGHTS_0 + +/// Weights argument #1. +#define DNNL_ARG_WEIGHTS_1 34 +/// A special mnemonic for RNN weights applied to the recurrent input. +/// An alias for #DNNL_ARG_WEIGHTS_1. +#define DNNL_ARG_WEIGHTS_ITER DNNL_ARG_WEIGHTS_1 + +/// Weights argument #2. +#define DNNL_ARG_WEIGHTS_2 35 +/// A special mnemonic for RNN weights applied to the peephole weights. +/// An alias for #DNNL_ARG_WEIGHTS_2. +#define DNNL_ARG_WEIGHTS_PEEPHOLE DNNL_ARG_WEIGHTS_2 + +/// Weights argument #3. +#define DNNL_ARG_WEIGHTS_3 36 +/// A special mnemonic for RNN weights applied to the projection weights. +/// An alias for #DNNL_ARG_WEIGHTS_3. +#define DNNL_ARG_WEIGHTS_PROJECTION DNNL_ARG_WEIGHTS_3 + +/// Bias tensor argument. +#define DNNL_ARG_BIAS 41 + +/// Mean values tensor argument. +#define DNNL_ARG_MEAN 49 +/// Variance values tensor argument. +#define DNNL_ARG_VARIANCE 50 + +/// A special mnemonic for scale argument of normalization primitives. +#define DNNL_ARG_SCALE 51 +/// A special mnemonic for shift argument of normalization primitives. 
+#define DNNL_ARG_SHIFT 52 + +/// Workspace tensor argument. Workspace is used to pass information +/// from forward propagation to backward propagation computations. +#define DNNL_ARG_WORKSPACE 64 +/// Scratchpad (temporary storage) tensor argument. +#define DNNL_ARG_SCRATCHPAD 80 + +/// Gradient (diff) of the source argument #0. +#define DNNL_ARG_DIFF_SRC_0 129 +/// A special mnemonic for primitives that have a single diff source argument. +/// An alias for #DNNL_ARG_DIFF_SRC_0. +#define DNNL_ARG_DIFF_SRC DNNL_ARG_DIFF_SRC_0 +/// A special mnemonic for gradient (diff) of RNN input vector. An alias for +/// #DNNL_ARG_DIFF_SRC_0. +#define DNNL_ARG_DIFF_SRC_LAYER DNNL_ARG_DIFF_SRC_0 + +/// Gradient (diff) of the source argument #1. +#define DNNL_ARG_DIFF_SRC_1 130 +/// A special mnemonic for gradient (diff) of RNN input recurrent hidden state +/// vector. An alias for #DNNL_ARG_DIFF_SRC_1. +#define DNNL_ARG_DIFF_SRC_ITER DNNL_ARG_DIFF_SRC_1 + +/// Gradient (diff) of the source argument #2. +#define DNNL_ARG_DIFF_SRC_2 131 +/// A special mnemonic for gradient (diff) of RNN input recurrent cell state +/// vector. An alias for #DNNL_ARG_DIFF_SRC_1. +#define DNNL_ARG_DIFF_SRC_ITER_C DNNL_ARG_DIFF_SRC_2 + +/// Gradient (diff) of the source argument #3. +#define DNNL_ARG_DIFF_SRC_3 132 +/// A special mnemonic for gradient (diff) of RNN input recurrent cell attention +/// vector. An alias for #DNNL_ARG_DIFF_SRC_3. +#define DNNL_ARG_DIFF_AUGRU_ATTENTION DNNL_ARG_DIFF_SRC_3 + +/// Gradient (diff) of the destination argument #0. +#define DNNL_ARG_DIFF_DST_0 145 +/// A special mnemonic for primitives that have a single diff destination +/// argument. An alias for #DNNL_ARG_DIFF_DST_0. +#define DNNL_ARG_DIFF_DST DNNL_ARG_DIFF_DST_0 +/// A special mnemonic for gradient (diff) of RNN output vector. An alias for +/// #DNNL_ARG_DIFF_DST_0. +#define DNNL_ARG_DIFF_DST_LAYER DNNL_ARG_DIFF_DST_0 + +/// Gradient (diff) of the destination argument #1. 
+#define DNNL_ARG_DIFF_DST_1 146 +/// A special mnemonic for gradient (diff) of RNN input recurrent hidden state +/// vector. An alias for #DNNL_ARG_DIFF_DST_1. +#define DNNL_ARG_DIFF_DST_ITER DNNL_ARG_DIFF_DST_1 + +/// Gradient (diff) of the destination argument #2. +#define DNNL_ARG_DIFF_DST_2 147 +/// A special mnemonic for gradient (diff) of RNN input recurrent cell state +/// vector. An alias for #DNNL_ARG_DIFF_DST_2. +#define DNNL_ARG_DIFF_DST_ITER_C DNNL_ARG_DIFF_DST_2 + +/// Gradient (diff) of the weights argument #0. +#define DNNL_ARG_DIFF_WEIGHTS_0 161 +/// A special mnemonic for primitives that have a single diff weights +/// argument. Alias for #DNNL_ARG_DIFF_WEIGHTS_0. +#define DNNL_ARG_DIFF_WEIGHTS DNNL_ARG_DIFF_WEIGHTS_0 +/// A special mnemonic for diff of RNN weights applied to the layer input. An +/// alias for #DNNL_ARG_DIFF_WEIGHTS_0. +#define DNNL_ARG_DIFF_WEIGHTS_LAYER DNNL_ARG_DIFF_WEIGHTS_0 + +/// Gradient (diff) of the weights argument #1. +#define DNNL_ARG_DIFF_WEIGHTS_1 162 +/// A special mnemonic for diff of RNN weights applied to the recurrent input. +/// An alias for #DNNL_ARG_DIFF_WEIGHTS_1. +#define DNNL_ARG_DIFF_WEIGHTS_ITER DNNL_ARG_DIFF_WEIGHTS_1 + +/// Gradient (diff) of the weights argument #2. +#define DNNL_ARG_DIFF_WEIGHTS_2 163 +/// A special mnemonic for diff of RNN weights applied to the peephole weights. +/// An alias for #DNNL_ARG_DIFF_WEIGHTS_2. +#define DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE DNNL_ARG_DIFF_WEIGHTS_2 + +/// Gradient (diff) of the weights argument #3. +#define DNNL_ARG_DIFF_WEIGHTS_3 164 +/// A special mnemonic for diff of RNN weights applied to the projection +/// weights. An alias for #DNNL_ARG_DIFF_WEIGHTS_3. +#define DNNL_ARG_DIFF_WEIGHTS_PROJECTION DNNL_ARG_DIFF_WEIGHTS_3 + +/// Gradient (diff) of the bias tensor argument. +#define DNNL_ARG_DIFF_BIAS 169 + +/// A special mnemonic for scale argument of normalization primitives. 
+#define DNNL_ARG_DIFF_SCALE 255 +/// A special mnemonic for shift argument of normalization primitives. +#define DNNL_ARG_DIFF_SHIFT 256 + +/// Rounding mode seed for stochastic rounding +/// Single seed needed independently of how many arguments need stochastic rounding +#define DNNL_ARG_ATTR_ROUNDING_SEED 508 + +/// Dropout mask output buffer. +#define DNNL_ARG_ATTR_DROPOUT_MASK 509 + +/// Dropout probability value passed via a buffer. +#define DNNL_ARG_ATTR_DROPOUT_PROBABILITY 510 + +/// Dropout RNG seed value passed via a buffer. +#define DNNL_ARG_ATTR_DROPOUT_SEED 511 + +/// Output scaling factors provided at execution time. +#define DNNL_ARG_ATTR_OUTPUT_SCALES 513 + +/// Starting index for source arguments for primitives that take a variable +/// number of source arguments. +#define DNNL_ARG_MULTIPLE_SRC 1024 +/// Starting index for destination arguments for primitives that produce a +/// variable number of destination arguments. +#define DNNL_ARG_MULTIPLE_DST 2048 + +/// Scaling factors provided at execution time. +#define DNNL_ARG_ATTR_SCALES 4096 + +/// Zero points provided at execution time. +#define DNNL_ARG_ATTR_ZERO_POINTS 8192 + +/// Arguments for fused depthwise convolution. +/// See @ref dev_guide_attributes_post_ops_depthwise_fusion +#define DNNL_ARG_ATTR_POST_OP_DW 16384 + +/// Starting point for a binary post operation. +#define DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE 32768 + +/// Arguments for a binary post operation. Up to 32 arguments are supported. +/// See @ref dev_guide_attributes_post_ops_binary_fusion +#define DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) \ + (DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE * ((idx) + 1)) + +/// A structure that contains an index and a memory object, and is used to pass +/// arguments to dnnl_primitive_execute(). +typedef struct { + int arg; ///< An argument index, e.g. 
DNNL_ARG_SRC + dnnl_memory_t memory; ///< Input/output memory +} dnnl_exec_arg_t; + +/// @} dnnl_api_primitives_common + +/// @addtogroup dnnl_api_primitives_common +/// @{ + +/// Primitive descriptor query specification +/// +/// For generic function dnnl_primitive_desc_query(), the type of result must +/// agree with the queried argument. The correspondence table: +/// +/// Query kind | Type of query result +/// --------------------------------|----------------------------- +/// dnnl_query_*_engine | #dnnl_engine_t * +/// #dnnl_query_primitive_kind | #dnnl_primitive_kind_t * +/// dnnl_query_*_s32 | int * +/// dnnl_query_*_s64 | #dnnl_dim_t * (same as int64_t *) +/// dnnl_query_*_f32 | float * +/// dnnl_query_*_f64 | double * +/// dnnl_query_*_str | const char ** +/// dnnl_query_*_md | #const_dnnl_memory_desc_t * +/// dnnl_query_*_pd | #const_dnnl_primitive_desc_t * +/// dnnl_query_cache_blob_id | const uint8_t ** +/// dnnl_query_strides | const #dnnl_dims_t ** +/// dnnl_query_dilations | const #dnnl_dims_t ** +/// dnnl_query_padding_l | const #dnnl_dims_t ** +/// dnnl_query_padding_r | const #dnnl_dims_t ** +/// dnnl_query_flags | unsigned * +/// dnnl_query_alg_kind | #dnnl_alg_kind_t * +/// dnnl_query_factors | const float ** +/// dnnl_query_cell_kind | #dnnl_alg_kind_t * +/// dnnl_query_direction | #dnnl_rnn_direction_t * +/// dnnl_query_activation_kind | #dnnl_alg_kind_t * +/// dnnl_query_kernel | const #dnnl_dims_t ** +/// dnnl_query_dims | const #dnnl_dims_t ** +/// dnnl_query_data_type | #dnnl_data_type_t * +/// dnnl_query_padded_dims | const #dnnl_dims_t ** +/// dnnl_query_padded_offsets | const #dnnl_dims_t ** +/// dnnl_query_format_kind | #dnnl_format_kind_t * +/// dnnl_query_inner_blks | const #dnnl_dims_t ** +/// dnnl_query_inner_idxs | const #dnnl_dims_t ** +/// dnnl_query_sparse_encoding | #dnnl_sparse_encoding_t * +/// +/// @note +/// Rule of thumb: all opaque types and structures are returned by +/// reference. All numbers are returned by value. 
+/// +/// @warning +/// All returned references point to constant objects and are valid only +/// during the lifetime of the queried primitive descriptor. Returned objects +/// must not be destroyed by the user. If you need to keep the object longer +/// than the lifetime of the queried primitive descriptor, use +/// dnnl_primitive_desc_clone() to make a copy. +typedef enum { + dnnl_query_undef = 0, ///< no query + + dnnl_query_engine, ///< execution engine + dnnl_query_primitive_kind, ///< primitive kind + + dnnl_query_num_of_inputs_s32, ///< number of inputs expected + dnnl_query_num_of_outputs_s32, ///< number of outputs expected + + dnnl_query_time_estimate_f64, ///< runtime estimation (seconds) + dnnl_query_memory_consumption_s64, ///< memory consumption -- extra + /// (scratch) memory, additional to + /// all inputs and outputs memory + /// (bytes) + + dnnl_query_scratchpad_engine, ///< scratchpad engine -- engine to be used + /// for creating scratchpad memory + + dnnl_query_impl_info_str, ///< implementation name + + dnnl_query_reorder_src_engine, ///< source engine + dnnl_query_reorder_dst_engine, ///< destination engine + + dnnl_query_prop_kind, ///< propagation kind + + dnnl_query_cache_blob_id_size_s64, ///< size of cache blob ID in bytes + dnnl_query_cache_blob_id, ///< cache blob ID (pointer to array) + + dnnl_query_strides, ///< strides + dnnl_query_dilations, ///< dilations + dnnl_query_padding_l, ///< left padding + dnnl_query_padding_r, ///< right padding + dnnl_query_epsilon_f32, ///< epsilon + dnnl_query_flags, ///< flags + dnnl_query_alg_kind, ///< algorithm kind + dnnl_query_alpha_f32, ///< alpha + dnnl_query_beta_f32, ///< beta + dnnl_query_axis_s32, ///< axis + dnnl_query_local_size_s64, ///< LRN parameter local size + dnnl_query_k_f32, ///< LRN parameter K + dnnl_query_p_f32, ///< Reduction parameter P + dnnl_query_factors, ///< Resampling parameter factors + dnnl_query_cell_kind, ///< RNN parameter cell kind + dnnl_query_direction, ///< 
RNN parameter direction + dnnl_query_activation_kind, ///< RNN parameter activation kind + dnnl_query_kernel, ///< Pooling parameter kernel + dnnl_query_group_size_s64, ///< Shuffle parameter group size + + // memory descriptor section + dnnl_query_some_md = 128, ///< stub + dnnl_query_src_md, ///< source memory desc + dnnl_query_diff_src_md, ///< source gradient memory desc + dnnl_query_weights_md, ///< weights memory descriptor desc + dnnl_query_diff_weights_md, ///< weights grad. memory desc + dnnl_query_dst_md, ///< destination memory desc + dnnl_query_diff_dst_md, ///< destination grad. memory desc + dnnl_query_workspace_md, ///< workspace memory desc + dnnl_query_scratchpad_md, ///< scratchpad memory desc + dnnl_query_exec_arg_md = 255, ///< memory desc of an execute argument + + dnnl_query_ndims_s32, ///< number of dimensions + dnnl_query_dims, ///< vector of dimensions + dnnl_query_data_type, ///< data type + dnnl_query_submemory_offset_s64, ///< submemory offset + dnnl_query_padded_dims, ///< vector of padded dimensions + dnnl_query_padded_offsets, ///< vector of padded offsets + dnnl_query_format_kind, ///< format kind + dnnl_query_inner_nblks_s32, ///< number of innermost blocks + dnnl_query_inner_blks, ///< vector of sizes of the innermost blocks + dnnl_query_inner_idxs, ///< vector of logical indices of the blocks +#ifdef DNNL_EXPERIMENTAL_SPARSE + dnnl_query_sparse_encoding, ///< Sparse encoding + dnnl_query_nnz_s64, ///< Number of non-zero entries + dnnl_query_num_handles_s32, ///< Number of buffers required for a memory +/// descriptor +#endif + // Max value to prevent UB for internal use only dnnl_query_t + dnnl_query_max = 0x7fff, +} dnnl_query_t; + +/// @} dnnl_api_primitives_common + +/// @} dnnl_api_primitives + +/// @addtogroup dnnl_api_service +/// @{ + +/// Disable profiling completely +#define DNNL_JIT_PROFILE_NONE 0u + +/// Enable VTune Profiler integration +#define DNNL_JIT_PROFILE_VTUNE 1u + +/// Enable Linux perf integration via perfmap 
files +#define DNNL_JIT_PROFILE_LINUX_PERFMAP 2u + +/// Enable Linux perf integration via jitdump files +#define DNNL_JIT_PROFILE_LINUX_JITDUMP 4u + +/// Instruct Linux perf integration via jitdump files to use TSC. @ref +/// DNNL_JIT_PROFILE_LINUX_JITDUMP must be set too for this to take effect. +#define DNNL_JIT_PROFILE_LINUX_JITDUMP_USE_TSC 8u + +/// Enable Linux perf integration (both jitdump and perfmap) +#define DNNL_JIT_PROFILE_LINUX_PERF \ + (DNNL_JIT_PROFILE_LINUX_JITDUMP | DNNL_JIT_PROFILE_LINUX_PERFMAP) + +/// CPU instruction set flags +typedef enum { + /// Library choice of ISA (excepting those listed as initial support) + dnnl_cpu_isa_default = 0x0, + + /// Intel Streaming SIMD Extensions 4.1 (Intel SSE4.1) + dnnl_cpu_isa_sse41 = 0x1, + + /// Intel Advanced Vector Extensions (Intel AVX) + dnnl_cpu_isa_avx = 0x3, + + /// Intel Advanced Vector Extensions 2 (Intel AVX2) + dnnl_cpu_isa_avx2 = 0x7, + + /// Intel AVX2 and Intel Deep Learning Boost (Intel DL Boost) support + dnnl_cpu_isa_avx2_vnni = 0xf, + + /// Intel AVX2 and Intel Deep Learning Boost (Intel DL Boost) + /// with 8-bit integer, float16 and bfloat16 support + dnnl_cpu_isa_avx2_vnni_2 = 0x1f, + + /// Intel AVX-512 subset for Intel Xeon Scalable processor family + /// and Intel Core processor family. + dnnl_cpu_isa_avx512_core = 0x27, + + /// Intel AVX-512 and Intel Deep Learning Boost (Intel DL Boost) support + /// for Intel Xeon Scalable processor family + /// and Intel Core processor family. + dnnl_cpu_isa_avx512_core_vnni = 0x67, + + /// Intel AVX-512, Intel DL Boost and bfloat16 support + /// for Intel Xeon Scalable processor family + /// and Intel Core processor family. + dnnl_cpu_isa_avx512_core_bf16 = 0xe7, + + /// Intel AVX-512 with float16, Intel DL Boost and bfloat16 support + /// for Intel Xeon Scalable processor family + /// and Intel Core processor family. + // TODO: Align avx10_1 values to internal representation. 
+ dnnl_cpu_isa_avx10_1_512 = 0x1ef, + /// @copydoc dnnl_cpu_isa_avx10_1_512 + dnnl_cpu_isa_avx512_core_fp16 = dnnl_cpu_isa_avx10_1_512, + + /// Intel AVX-512 with float16, Intel DL Boost and bfloat16 support and + /// Intel AMX with 8-bit integer and bfloat16 support + // TODO: Align avx10_1 values to internal representation. + dnnl_cpu_isa_avx10_1_512_amx = 0xfef, + /// @copydoc dnnl_cpu_isa_avx10_1_512_amx + dnnl_cpu_isa_avx512_core_amx = dnnl_cpu_isa_avx10_1_512_amx, + + /// Intel AVX-512 with float16, Intel DL Boost and bfloat16 support and + /// Intel AMX with 8-bit integer, bfloat16 and float16 support + // TODO: Align avx10_1 values to internal representation. + dnnl_cpu_isa_avx10_1_512_amx_fp16 = 0x1fef, + /// @copydoc dnnl_cpu_isa_avx10_1_512_amx_fp16 + dnnl_cpu_isa_avx512_core_amx_fp16 = dnnl_cpu_isa_avx10_1_512_amx_fp16, +} dnnl_cpu_isa_t; + +/// CPU ISA hints flags +typedef enum { + /// No hints (use default features) + dnnl_cpu_isa_no_hints = 0x0, + + /// Prefer to exclusively use Ymm registers for computations + dnnl_cpu_isa_prefer_ymm = 0x1, +} dnnl_cpu_isa_hints_t; + +/// @} dnnl_api_service + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif /* ONEAPI_DNNL_TYPES_H */ diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.h new file mode 100644 index 0000000000000000000000000000000000000000..50c8fd4a96e252f2f06d5cb4ad3012a407e14915 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.h @@ -0,0 +1,337 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @file +/// ukernel C API + +#ifndef ONEAPI_DNNL_DNNL_UKERNEL_H +#define ONEAPI_DNNL_DNNL_UKERNEL_H + +#include "oneapi/dnnl/dnnl.h" +#include "oneapi/dnnl/dnnl_ukernel_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_api_ukernel +/// @{ + +#ifdef DNNL_EXPERIMENTAL_UKERNEL + +/// Creates a ukernel attributes memory storage. +/// +/// @param attr_params Output ukernel attributes memory storage. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ukernel_attr_params_create( + dnnl_ukernel_attr_params_t *attr_params); + +/// Sets post-operations arguments to a storage. +/// +/// @param attr_params Memory pointers storage object. +/// @param post_ops_args A pointer to pointers of post_ops storages. Expected to +/// be packed together. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_post_ops_args( + dnnl_ukernel_attr_params_t attr_params, const void **post_ops_args); + +/// Sets tensor A scales argument to a storage. +/// +/// @param attr_params Memory pointers storage object. +/// @param a_scales Pointer to the scales storage. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_A_scales( + dnnl_ukernel_attr_params_t attr_params, const void *a_scales); + +/// Sets tensor B scales argument to a storage. +/// +/// If `dnnl_brgemm_set_B_scales` used mask of 2, then at least N values of +/// selected data type are expected. +/// +/// @param attr_params Memory pointers storage object. +/// @param b_scales Pointer to the scales storage. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_B_scales( + dnnl_ukernel_attr_params_t attr_params, const void *b_scales); + +/// Sets tensor D scales argument to a storage. +/// +/// @param attr_params Memory pointers storage object. +/// @param d_scales Pointer to the scales storage. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_D_scales( + dnnl_ukernel_attr_params_t attr_params, const void *d_scales); + +/// Destroys a ukernel attributes memory storage. +/// +/// @param attr_params Memory pointers storage object to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_ukernel_attr_params_destroy( + dnnl_ukernel_attr_params_t attr_params); + +/// @addtogroup dnnl_api_ukernel_brgemm +/// @{ + +/// Creates a BRGeMM ukernel object. Operates by the following formula: +/// `C = [A x B]`. +/// +/// @param brgemm Output BRGeMM ukernel object. +/// @param M Dimension M of tensor A. +/// @param N Dimension N of tensor B. +/// @param K Dimension K of tensors A and B. +/// @param batch_size Number of batches to process. +/// @param lda Leading dimension of tensor A. +/// @param ldb Leading dimension of tensor B. +/// @param ldc Leading dimension of tensor C. +/// @param a_dt Data type of tensor A. +/// @param b_dt Data type of tensor B. +/// @param c_dt Data type of tensor C. Must be dnnl_f32. 
+/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_create(dnnl_brgemm_t *brgemm, dnnl_dim_t M, + dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t batch_size, dnnl_dim_t lda, + dnnl_dim_t ldb, dnnl_dim_t ldc, dnnl_data_type_t a_dt, + dnnl_data_type_t b_dt, dnnl_data_type_t c_dt); + +/// Sets adding an intermediate result to the output tensor C instead of +/// writing: `C += [A x B]`. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param add_C Value to indicate addition. Can be `0` to skip addition, and +/// `1` to apply addition. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_set_add_C(dnnl_brgemm_t brgemm, int add_C); + +/// Sets post-operations to a BRGeMM ukernel object: `D = post-operations(C)`. +/// +/// Post-operations applies if one of the following holds: +/// * Non-empty attributes are specified. +/// * Output data type `d_dt` is different from accumulation data type `c_dt`. +/// +/// If any of conditions happens, the final call of the accumulation chain +/// must be `dnnl_brgemm_execute_postops`, and `dnnl_brgemm_execute`, otherwise. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param ldd Leading dimension of tensor D. +/// @param d_dt Data type of tensor D. +/// @param post_ops Primitive post operations attribute to extend the kernel +/// operations. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_set_post_ops(dnnl_brgemm_t brgemm, + dnnl_dim_t ldd, dnnl_data_type_t d_dt, const_dnnl_post_ops_t post_ops); + +/// Sets tensor A scales mask to a BRGeMM ukernel object. +/// +/// For quantization flavor tensor A scales apply to accumulation buffer once C +/// is ready. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param a_scale_mask Tensor A scale mask. Can be `0` only. 
+dnnl_status_t DNNL_API dnnl_brgemm_set_A_scales( + dnnl_brgemm_t brgemm, int a_scale_mask); + +/// Sets tensor B scales mask to a BRGeMM ukernel object. +/// +/// For quantization flavor tensor B scales apply to accumulation buffer once C +/// is ready. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param b_scale_mask Tensor B scale mask. Can be `0` and `2` only. +dnnl_status_t DNNL_API dnnl_brgemm_set_B_scales( + dnnl_brgemm_t brgemm, int b_scale_mask); + +/// Sets tensor D scales mask to a BRGeMM ukernel object. +/// +/// For quantization flavor tensor D scales apply after all post-ops are +/// applied. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param d_scale_mask Tensor D scale mask. Can be `0` only. +dnnl_status_t DNNL_API dnnl_brgemm_set_D_scales( + dnnl_brgemm_t brgemm, int d_scale_mask); + +/// Finalizes initialization of a BRGeMM ukernel object. +/// +/// This step is mandatory to query information from the object. +/// +/// @param brgemm Output BRGeMM ukernel object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_finalize(dnnl_brgemm_t brgemm); + +/// Returns the packing type expected by a tensor B of a BRGeMM ukernel object. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param pack_type Output packing type. Can be `dnnl_brgemm_no_pack` if +/// packing is not expected, and `dnnl_brgemm_pack_32`, otherwise. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_get_B_pack_type( + const_dnnl_brgemm_t brgemm, dnnl_pack_type_t *pack_type); + +/// Returns the size of a scratchpad memory needed for the BRGeMM ukernel +/// object. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param size Output size of a buffer required for the BRGeMM ukernel object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_brgemm_get_scratchpad_size( + const_dnnl_brgemm_t brgemm, size_t *size); + +/// Returns the flag indicating when the call to `dnnl_brgemm_execute_postops` +/// is valid. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param valid The flag indicating if `dnnl_brgemm_execute_postops` is valid +/// for a given ukernel object. `1` is for valid and `0`, otherwise. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_is_execute_postops_valid( + const_dnnl_brgemm_t brgemm, int *valid); + +/// Initializes the hardware-specific context. If no initialization required, +/// returns the success status. +/// +/// @param brgemm BRGeMM ukernel object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_set_hw_context(const_dnnl_brgemm_t brgemm); + +/// Releases the hardware-specific context. Must be used after all the execution +/// calls to BRGeMM ukernel objects. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_release_hw_context(); + +/// Generates an executable part of BRGeMM ukernel object. +/// @param brgemm BRGeMM ukernel object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_generate(dnnl_brgemm_t brgemm); + +/// Executes a BRGeMM ukernel object. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param A_ptr Base pointer to a tensor A. +/// @param B_ptr Base pointer to a tensor B. +/// @param A_B_offsets Pointer to the set of tensor A and tensor B offsets for +/// each batch; the set must be contiguous in memory. Single batch should +/// supply offsets for both tensors A and B simultaneously. The number of +/// batches must coincide with the `batch_size` value passed at the creation +/// stage. 
+/// @param C_ptr Pointer to a tensor C (accumulation buffer). +/// @param scratchpad_ptr Pointer to a scratchpad buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_execute(const_dnnl_brgemm_t brgemm, + const void *A_ptr, const void *B_ptr, const dnnl_dim_t *A_B_offsets, + void *C_ptr, void *scratchpad_ptr); + +/// Executes a BRGeMM ukernel object with post operations. +/// +/// @param brgemm BRGeMM ukernel object. +/// @param A Base pointer to a tensor A. +/// @param B Base pointer to a tensor B. +/// @param A_B_offsets Pointer to a set of tensor A and tensor B offsets for +/// each batch. A set must be contiguous in memory. A single batch should +/// supply offsets for both tensors A and B simultaneously. The number of +/// batches must coincide with the `batch_size` value passed at the creation +/// stage. +/// @param C_ptr Pointer to a tensor C (accumulation buffer). +/// @param D_ptr Pointer to a tensor D (output buffer). +/// @param scratchpad_ptr Pointer to a scratchpad buffer. +/// @param attr_params Ukernel attributes memory storage. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_execute_postops(const_dnnl_brgemm_t brgemm, + const void *A, const void *B, const dnnl_dim_t *A_B_offsets, + const void *C_ptr, void *D_ptr, void *scratchpad_ptr, + const_dnnl_ukernel_attr_params_t attr_params); + +/// Destroys a BRGeMM ukernel object. +/// +/// @param brgemm BRGeMM ukernel object to destroy. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_brgemm_destroy(dnnl_brgemm_t brgemm); + +/// Creates a transform object. +/// +/// @param transform Output transform object. +/// @param K Dimension K. +/// @param N Dimension N. +/// @param in_pack_type Input packing type. Must be one of +/// `dnnl_pack_type_no_trans`, or `dnnl_pack_type_trans`. 
+/// @param in_ld Input leading dimension. +/// @param out_ld Output leading dimension. When packing data, it specifies a +/// block by N dimension. +/// @param in_dt Input data type. +/// @param out_dt Output data type. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_transform_create(dnnl_transform_t *transform, + dnnl_dim_t K, dnnl_dim_t N, dnnl_pack_type_t in_pack_type, + dnnl_dim_t in_ld, dnnl_dim_t out_ld, dnnl_data_type_t in_dt, + dnnl_data_type_t out_dt); + +/// Generates an executable part of transform object. +/// @param transform Transform object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_transform_generate(dnnl_transform_t transform); + +/// Executes a transform object. +/// +/// @param transform Transform object. +/// @param in_ptr Pointer to an input buffer. +/// @param out_ptr Pointer to an output buffer. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. +dnnl_status_t DNNL_API dnnl_transform_execute( + const_dnnl_transform_t transform, const void *in_ptr, void *out_ptr); + +/// Destroys a transform object. +/// +/// @param transform Transform object. +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. 
+dnnl_status_t DNNL_API dnnl_transform_destroy(dnnl_transform_t transform); + +/// @} dnnl_api_ukernel_brgemm + +#endif + +/// @} dnnl_api_ukernel + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif /* ONEAPI_DNNL_DNNL_UKERNEL_H */ diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.hpp b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.hpp new file mode 100644 index 0000000000000000000000000000000000000000..284177f43e255ca70546beb3e52fa54ed1e15e05 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.hpp @@ -0,0 +1,465 @@ +/******************************************************************************* +* Copyright 2024-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +/// @file +/// ukernel C++ API + +#ifndef ONEAPI_DNNL_DNNL_UKERNEL_HPP +#define ONEAPI_DNNL_DNNL_UKERNEL_HPP + +#include "oneapi/dnnl/dnnl.hpp" +#include "oneapi/dnnl/dnnl_ukernel.h" + +/// @addtogroup dnnl_api oneDNN API +/// @{ + +/// oneDNN namespace +namespace dnnl { + +#ifdef DNNL_EXPERIMENTAL_UKERNEL + +/// @addtogroup dnnl_api_utils +/// @{ + +/// @cond DO_NOT_DOCUMENT_THIS + +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_brgemm_t p) { + return dnnl_brgemm_destroy(p); + } +}; + +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_transform_t p) { + return dnnl_transform_destroy(p); + } +}; + +template <> +struct handle_traits { + static dnnl_status_t destructor(dnnl_ukernel_attr_params_t p) { + return dnnl_ukernel_attr_params_destroy(p); + } +}; + +/// @endcond + +/// @} dnnl_api_utils + +#endif + +/// @addtogroup dnnl_api_ukernel Ukernels +/// Collection of ukernels +/// @{ + +/// ukernel namespace +namespace ukernel { + +#ifdef DNNL_EXPERIMENTAL_UKERNEL + +/// @addtogroup dnnl_api_ukernel_utils ukernel utils +/// ukernel utility functions +/// @{ + +/// Packing specification +enum class pack_type { + /// Undefined pack type. A guard value. + undef = dnnl_pack_type_undef, + /// Plain, not transposed layout. Similar to format_tag::ab. + no_trans = dnnl_pack_type_no_trans, + /// Plain, transposed layout. Similar to format_tag::ba. + trans = dnnl_pack_type_trans, + /// Packed by 32 bits along K dimension layout. + pack32 = dnnl_pack_type_pack32, +}; + +/// Ukernel attributes memory storage +struct attr_params : public handle { + /// Constructs a ukernel attributes memory storage. 
+ attr_params() { + dnnl_ukernel_attr_params_t c_params = nullptr; + dnnl_status_t status = dnnl_ukernel_attr_params_create(&c_params); + error::wrap_c_api( + status, "could not create an attributes memory storage"); + reset(c_params); + } + + /// Sets post-operations arguments to a storage. + /// + /// @param post_ops_args Pointer to pointers of post_ops storages. + /// Expected to be packed together. + void set_post_ops_args(const void **post_ops_args) { + dnnl_status_t status = dnnl_ukernel_attr_params_set_post_ops_args( + get(), post_ops_args); + if (status != dnnl_success) + error::wrap_c_api( + status, "could not set post operations arguments"); + } + + /// Sets tensor A scales arguments to a storage. + /// + /// @param a_scales Pointer to scales storage. + void set_A_scales(const void *a_scales) { + dnnl_status_t status + = dnnl_ukernel_attr_params_set_A_scales(get(), a_scales); + if (status != dnnl_success) + error::wrap_c_api(status, "could not set A scales argument"); + } + + /// Sets tensor B scales arguments to a storage. + /// + /// If @ref attr_params::set_B_scales used mask of 2, then at + /// least N values of selected data type are expected. + /// + /// @param b_scales Pointer to scales storage. + void set_B_scales(const void *b_scales) { + dnnl_status_t status + = dnnl_ukernel_attr_params_set_B_scales(get(), b_scales); + if (status != dnnl_success) + error::wrap_c_api(status, "could not set B scales argument"); + } + + /// Sets tensor D scales arguments to a storage. + /// + /// @param d_scales Pointer to scales storage. 
+ void set_D_scales(const void *d_scales) { + dnnl_status_t status + = dnnl_ukernel_attr_params_set_D_scales(get(), d_scales); + if (status != dnnl_success) + error::wrap_c_api(status, "could not set D scales argument"); + } +}; +/// @} dnnl_api_ukernel_utils + +/// @addtogroup dnnl_api_ukernel_brgemm BRGeMM ukernel +/// BRGeMM ukernel routines +/// @{ + +/// BRGeMM ukernel +struct brgemm : public handle { + /// Default constructor. Produces an empty object. + brgemm() = default; + + /// Constructs a BRGeMM ukernel object. Operates by the following formula: + /// `C = [A x B]`. + /// + /// @param M Dimension M of tensor A. + /// @param N Dimension N of tensor B. + /// @param K Dimension K of tensors A and B. + /// @param batch_size Number of batches to process. + /// @param lda Leading dimension of tensor A. + /// @param ldb Leading dimension of tensor B. + /// @param ldc Leading dimension of tensor C. + /// @param a_dt Data type of tensor A. + /// @param b_dt Data type of tensor B. + /// @param c_dt Data type of tensor C. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + brgemm(memory::dim M, memory::dim N, memory::dim K, memory::dim batch_size, + memory::dim lda, memory::dim ldb, memory::dim ldc, + memory::data_type a_dt, memory::data_type b_dt, + memory::data_type c_dt, bool allow_empty = false) { + + dnnl_brgemm_t brgemm = nullptr; + dnnl_status_t status = dnnl_brgemm_create(&brgemm, M, N, K, batch_size, + lda, ldb, ldc, memory::convert_to_c(a_dt), + memory::convert_to_c(b_dt), memory::convert_to_c(c_dt)); + + if (!allow_empty) + error::wrap_c_api( + status, "could not create a BRGeMM ukernel object"); + reset(brgemm); + } + + /// Sets adding an intermediate result to the output tensor C instead of + /// writing: `C += [A x B]`. + /// + /// @param add_C Value to indicate addition. 
`false` to skip addition, and + /// `true` to apply addition. + void set_add_C(bool add_C) { + dnnl_status_t status + = dnnl_brgemm_set_add_C(get(), static_cast(add_C)); + if (status != dnnl_success) + error::wrap_c_api(status, "could not set add_C attribute"); + } + + /// Sets post-operations to a BRGeMM ukernel object: + /// `D = post-operations(C)`. + /// + /// Post-operations applies if one of the following holds: + /// * Non-empty post-operations are specified. + /// * Output data type `d_dt` is different from accumulation data type + /// `c_dt`. + /// + /// @param ldd Leading dimension of tensor D. + /// @param d_dt Data type of tensor D. + /// @param po Primitive post-operation attributes to extend the kernel + /// operations. + void set_post_ops(memory::dim ldd, memory::data_type d_dt, + const post_ops &po = default_post_ops()) { + dnnl_status_t status = dnnl_brgemm_set_post_ops( + get(), ldd, memory::convert_to_c(d_dt), po.get()); + if (status != dnnl_success) + error::wrap_c_api(status, "could not set post operations"); + } + + /// Sets tensor A scales mask to a BRGeMM ukernel object. + /// + /// For quantization flavor tensor A scales apply to accumulation buffer + /// once C is ready. + /// + /// @param a_scale_mask Tensor A scale mask. Can be `0` only. + void set_A_scales(int a_scale_mask) { + dnnl_status_t status = dnnl_brgemm_set_A_scales(get(), a_scale_mask); + if (status != dnnl_success) + error::wrap_c_api(status, "could not set A scales"); + } + + /// Sets tensor B scales mask to a BRGeMM ukernel object. + /// + /// For quantization flavor tensor B scales apply to accumulation buffer + /// once C is ready. + /// + /// @param b_scale_mask Tensor B scale mask. Can be `0` and `2` only. + void set_B_scales(int b_scale_mask) { + dnnl_status_t status = dnnl_brgemm_set_B_scales(get(), b_scale_mask); + if (status != dnnl_success) + error::wrap_c_api(status, "could not set B scales"); + } + + /// Sets tensor D scales mask to a BRGeMM ukernel object. 
+ /// + /// For quantization flavor tensor D scales apply after all post-ops are + /// applied. + /// + /// @param d_scale_mask Tensor D scale mask. Can be `0` only. + void set_D_scales(int d_scale_mask) { + dnnl_status_t status = dnnl_brgemm_set_D_scales(get(), d_scale_mask); + if (status != dnnl_success) + error::wrap_c_api(status, "could not set D scales"); + } + + /// Finalizes initialization of a BRGeMM ukernel object. + /// + /// This step must be performed prior to querying information from the + /// object. + void finalize() { + dnnl_status_t status = dnnl_brgemm_finalize(get()); + if (status != dnnl_success) + error::wrap_c_api(status, "could not finalize an object"); + } + + /// Returns the packing type expected by a tensor B of a BRGeMM ukernel + /// object. + pack_type get_B_pack_type() const { + dnnl_pack_type_t c_pack_type; + dnnl_status_t status = dnnl_brgemm_get_B_pack_type(get(), &c_pack_type); + if (status != dnnl_success) + error::wrap_c_api(status, "could not query B pack type"); + + return static_cast(c_pack_type); + } + + /// Returns the size of a scratchpad memory needed for the BRGeMM ukernel + /// object. + size_t get_scratchpad_size() const { + size_t size; + dnnl_status_t status = dnnl_brgemm_get_scratchpad_size(get(), &size); + if (status != dnnl_success) + error::wrap_c_api(status, + "could not query a scratchpad size from a BRGeMM ukernel " + "object"); + return size; + } + + /// Returns the flag indicating when the call to execute with post + /// operations is valid. + /// + /// `True` is for a valid call, `false`, otherwise. + bool is_execute_postops_valid() const { + int valid; + dnnl_status_t status + = dnnl_brgemm_is_execute_postops_valid(get(), &valid); + if (status != dnnl_success) + error::wrap_c_api(status, + "could not query a flag for execute postops from a BRGeMM " + "ukernel object"); + return static_cast(valid); + } + + /// Initializes the hardware-specific context. 
Affects the global state for + /// all BRGeMM ukernel objects. If no initialization required, returns. + void set_hw_context() const { + dnnl_status_t status = dnnl_brgemm_set_hw_context(get()); + if (status != dnnl_success) + error::wrap_c_api(status, "could not set hardware context"); + } + + /// Releases the hardware-specific context. Affects the global state for + /// all BRGeMM ukernel objects. Must be used after all the execution calls + /// to BRGeMM ukernel objects. + static void release_hw_context() { + dnnl_status_t status = dnnl_brgemm_release_hw_context(); + if (status != dnnl_success) + error::wrap_c_api(status, "could not release hardware context"); + } + + /// Generates an executable part of BRGeMM ukernel object. + void generate() { + dnnl_status_t status = dnnl_brgemm_generate(get()); + if (status != dnnl_success) + error::wrap_c_api(status, "could not generate a kernel"); + } + + /// Executes a BRGeMM ukernel object. + /// + /// @param A Base pointer to a tensor A. + /// @param B Base pointer to a tensor B. + /// @param A_B_offsets Vector of pairs of tensors A and B offsets for + /// each batch. The number of batches must coincide with the + /// `batch_size` value passed at object construction stage. + /// @param C Pointer to a tensor C (accumulation buffer). + /// @param scratchpad Pointer to a scratchpad buffer. + void execute(const void *A, const void *B, + const std::vector> &A_B_offsets, + void *C, void *scratchpad) const { + // TODO: export batch_element to C API later for user to fill it and + // pass directly to the call. + dnnl_status_t status = dnnl_brgemm_execute(get(), A, B, + (const dnnl_dim_t *)A_B_offsets.data(), C, scratchpad); + if (status != dnnl_success) + error::wrap_c_api( + status, "could not execute a BRGeMM ukernel object"); + } + + /// Executes a BRGeMM ukernel object with post operations. + /// + /// @param A Base pointer to a tensor A. + /// @param B Base pointer to a tensor B. 
+ /// @param A_B_offsets Vector of pairs of tensors A and B offsets for + /// each batch. The number of batches must coincide with the + /// `batch_size` value passed at object construction stage. + /// @param C Pointer to a tensor C (accumulation buffer). + /// @param D Pointer to a tensor D (output buffer). + /// @param scratchpad Pointer to a scratchpad buffer. + /// @param params Post-op memory arguments. Must be passed If binary + /// post-op or scales were set. + void execute(const void *A, const void *B, + const std::vector> &A_B_offsets, + const void *C, void *D, void *scratchpad, + const attr_params ¶ms = default_attr_params()) const { + // TODO: export batch_element to C API later for user to fill it and + // pass directly to the call. + dnnl_status_t status = dnnl_brgemm_execute_postops(get(), A, B, + (const dnnl_dim_t *)A_B_offsets.data(), C, D, scratchpad, + params.get()); + if (status != dnnl_success) + error::wrap_c_api( + status, "could not execute a BRGeMM ukernel object"); + } + + /// Returns a constant reference to a static instance of default constructed + /// primitive post-operations attribute. + static const post_ops &default_post_ops() { + static const post_ops po; + return po; + } + + /// Returns a constant reference to a static instance of default constructed + /// ukernel attributes parameters. + static const attr_params &default_attr_params() { + static const attr_params ap; + return ap; + } +}; +/// @} dnnl_api_ukernel_brgemm + +/// @addtogroup dnnl_api_ukernel_transform Transform ukernel +/// Transform routines +/// @{ + +/// Transform ukernel +struct transform : public handle { + /// Default constructor. Produces an empty object. + transform() = default; + + /// Constructs a transform object. + /// + /// @param K Dimension K. + /// @param N Dimension N. + /// @param in_pack_type Input packing type. Must be one of + /// `pack_type::no_trans`, or `pack_type::trans`. + /// @param in_ld Input leading dimension. 
+ /// @param out_ld Output leading dimension. Specifies a block by N dimension + /// during data packing. + /// @param in_dt Input data type. + /// @param out_dt Output data type. + /// @param allow_empty A flag signifying whether construction is + /// allowed to fail without throwing an exception. In this case an + /// empty object will be produced. This flag is optional and + /// defaults to false. + transform(memory::dim K, memory::dim N, pack_type in_pack_type, + memory::dim in_ld, memory::dim out_ld, memory::data_type in_dt, + memory::data_type out_dt, bool allow_empty = false) { + + dnnl_transform_t transform = nullptr; + dnnl_status_t status = dnnl_transform_create(&transform, K, N, + static_cast(in_pack_type), in_ld, out_ld, + memory::convert_to_c(in_dt), memory::convert_to_c(out_dt)); + + if (!allow_empty) + error::wrap_c_api(status, + "could not create a BRGeMM ukernel packing B object"); + reset(transform); + } + + /// Generates an executable part of transform object. + void generate() { + dnnl_status_t status = dnnl_transform_generate(get()); + if (status != dnnl_success) + error::wrap_c_api(status, + "could not generate a BRGeMM ukernel packing B object"); + } + + /// Executes a transform object. + /// + /// @param in Pointer to an input buffer. + /// @param out Pointer to an output buffer. 
+ void execute(const void *in, void *out) const { + dnnl_status_t status = dnnl_transform_execute(get(), in, out); + if (status != dnnl_success) + error::wrap_c_api(status, + "could not execute a BRGeMM ukernel packing B object"); + } +}; + +/// @} dnnl_api_ukernel_transform + +#endif + +} // namespace ukernel + +/// @} dnnl_api_ukernel + +} // namespace dnnl + +/// @} dnnl_api + +#endif /* ONEAPI_DNNL_DNNL_UKERNEL_HPP */ diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel_types.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel_types.h new file mode 100644 index 0000000000000000000000000000000000000000..012edfd1278b148e3d8d434241d9c59f46585401 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel_types.h @@ -0,0 +1,93 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @file +/// ukernel C API types definitions + +#ifndef ONEAPI_DNNL_DNNL_UKERNEL_TYPES_H +#define ONEAPI_DNNL_DNNL_UKERNEL_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include "oneapi/dnnl/dnnl_types.h" + +/// @addtogroup dnnl_api +/// @{ + +/// @addtogroup dnnl_api_ukernel +/// @{ + +#ifdef DNNL_EXPERIMENTAL_UKERNEL + +/// Packing specification +typedef enum { + /// Undefined pack type. 
A guard value. + dnnl_pack_type_undef = 0, + /// Plain, not transposed layout. Similar to format_tag::ab. + dnnl_pack_type_no_trans, + /// Plain, transposed layout. Similar to format_tag::ba. + dnnl_pack_type_trans, + /// Packed by 32 bits along K dimension layout. + dnnl_pack_type_pack32, +} dnnl_pack_type_t; + +/// @struct dnnl_ukernel_attr_params +/// An opaque structure to describe ukernel attributes memory storage. +struct dnnl_ukernel_attr_params; + +/// A ukernel attributes memory storage handle. +typedef struct dnnl_ukernel_attr_params *dnnl_ukernel_attr_params_t; + +/// A constant ukernel attributes memory storage handle. +typedef const struct dnnl_ukernel_attr_params *const_dnnl_ukernel_attr_params_t; + +/// @addtogroup dnnl_api_ukernel_brgemm +/// @{ + +/// @struct dnnl_brgemm +/// An opaque structure to describe a brgemm ukernel. +struct dnnl_brgemm; + +/// A brgemm ukernel handle. +typedef struct dnnl_brgemm *dnnl_brgemm_t; + +/// A constant brgemm ukernel handle. +typedef const struct dnnl_brgemm *const_dnnl_brgemm_t; + +/// @struct dnnl_transform +/// An opaque structure to describe a transform routine. +struct dnnl_transform; + +/// A transform routine handle. +typedef struct dnnl_transform *dnnl_transform_t; + +/// A constant transform routine handle. 
+typedef const struct dnnl_transform *const_dnnl_transform_t; + +/// @} dnnl_api_ukernel_brgemm +#endif + +/// @} dnnl_api_ukernel + +/// @} dnnl_api + +#ifdef __cplusplus +} +#endif + +#endif /* ONEAPI_DNNL_DNNL_UKERNEL_TYPES_H */ diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_version.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_version.h new file mode 100644 index 0000000000000000000000000000000000000000..a518ac011ab0dc9bd2cc27648561d033c8e70b4e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_version.h @@ -0,0 +1,33 @@ +/******************************************************************************* +* Copyright 2019-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_VERSION_H +#define ONEAPI_DNNL_DNNL_VERSION_H + +// clang-format off + +/// Major version +#define DNNL_VERSION_MAJOR 3 + +/// Minor version +#define DNNL_VERSION_MINOR 7 + +/// Patch version +#define DNNL_VERSION_PATCH 1 + +// clang-format on + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_version_hash.h b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_version_hash.h new file mode 100644 index 0000000000000000000000000000000000000000..5b9418c6b8257da66ca707e1fdbfd68a78928888 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/oneapi/dnnl/dnnl_version_hash.h @@ -0,0 +1,31 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ONEAPI_DNNL_DNNL_VERSION_HASH_H +#define ONEAPI_DNNL_DNNL_VERSION_HASH_H + +// clang-format off + +/// Note: this macro and header file were moved to a separate instance to avoid +/// incremental build issues as moving from commit to commit would trigger a +/// complete library rebuild. Including a generated header file in a single +/// translation unit makes this problem go away. 
+/// Git commit hash +#define DNNL_VERSION_HASH "8d263e693366ef8db40acc569cc7d8edf644556d" + +// clang-format on + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/class.h b/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/class.h new file mode 100644 index 0000000000000000000000000000000000000000..b990507d629b4260d66d51e23a7f34a0fa465c9e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/class.h @@ -0,0 +1,767 @@ +/* + pybind11/detail/class.h: Python C API implementation details for py::class_ + + Copyright (c) 2017 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include +#include + +#include "exception_translation.h" + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) + +#if !defined(PYPY_VERSION) +# define PYBIND11_BUILTIN_QUALNAME +# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) +#else +// In PyPy, we still set __qualname__ so that we can produce reliable function type +// signatures; in CPython this macro expands to nothing: +# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) \ + setattr((PyObject *) obj, "__qualname__", nameobj) +#endif + +inline std::string get_fully_qualified_tp_name(PyTypeObject *type) { +#if !defined(PYPY_VERSION) + return type->tp_name; +#else + auto module_name = handle((PyObject *) type).attr("__module__").cast(); + if (module_name == PYBIND11_BUILTINS_MODULE) + return type->tp_name; + else + return std::move(module_name) + "." + type->tp_name; +#endif +} + +inline PyTypeObject *type_incref(PyTypeObject *type) { + Py_INCREF(type); + return type; +} + +#if !defined(PYPY_VERSION) + +/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance. 
+extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) { + return PyProperty_Type.tp_descr_get(self, cls, cls); +} + +/// `pybind11_static_property.__set__()`: Just like the above `__get__()`. +extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) { + PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj); + return PyProperty_Type.tp_descr_set(self, cls, value); +} + +// Forward declaration to use in `make_static_property_type()` +inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type); + +/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()` + methods are modified to always use the object type instead of a concrete instance. + Return value: New reference. */ +inline PyTypeObject *make_static_property_type() { + constexpr auto *name = "pybind11_static_property"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto *heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); + if (!heap_type) { + pybind11_fail("make_static_property_type(): error allocating type!"); + } + + heap_type->ht_name = name_obj.inc_ref().ptr(); +# ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +# endif + + auto *type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyProperty_Type); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + type->tp_descr_get = pybind11_static_get; + type->tp_descr_set = pybind11_static_set; + +# if PY_VERSION_HEX >= 0x030C0000 + // Since Python-3.12 property-derived types are required to + // have dynamic attributes (to set `__doc__`) + 
enable_dynamic_attributes(heap_type); +# endif + + if (PyType_Ready(type) < 0) { + pybind11_fail("make_static_property_type(): failure in PyType_Ready()!"); + } + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + return type; +} + +#else // PYPY + +/** PyPy has some issues with the above C API, so we evaluate Python code instead. + This function will only be called once so performance isn't really a concern. + Return value: New reference. */ +inline PyTypeObject *make_static_property_type() { + auto d = dict(); + PyObject *result = PyRun_String(R"(\ +class pybind11_static_property(property): + def __get__(self, obj, cls): + return property.__get__(self, cls, cls) + + def __set__(self, obj, value): + cls = obj if isinstance(obj, type) else type(obj) + property.__set__(self, cls, value) +)", + Py_file_input, + d.ptr(), + d.ptr()); + if (result == nullptr) + throw error_already_set(); + Py_DECREF(result); + return (PyTypeObject *) d["pybind11_static_property"].cast().release().ptr(); +} + +#endif // PYPY + +/** Types with static properties need to handle `Type.static_prop = x` in a specific way. + By default, Python replaces the `static_property` itself, but for wrapped C++ types + we need to call `static_property.__set__()` in order to propagate the new value to + the underlying C++ data structure. */ +extern "C" inline int pybind11_meta_setattro(PyObject *obj, PyObject *name, PyObject *value) { + // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw + // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`). + PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); + + // The following assignment combinations are possible: + // 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)` + // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop` + // 3. 
`Type.regular_attribute = value` --> setattro: regular attribute assignment + auto *const static_prop = (PyObject *) get_internals().static_property_type; + const auto call_descr_set = (descr != nullptr) && (value != nullptr) + && (PyObject_IsInstance(descr, static_prop) != 0) + && (PyObject_IsInstance(value, static_prop) == 0); + if (call_descr_set) { + // Call `static_property.__set__()` instead of replacing the `static_property`. +#if !defined(PYPY_VERSION) + return Py_TYPE(descr)->tp_descr_set(descr, obj, value); +#else + if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) { + Py_DECREF(result); + return 0; + } else { + return -1; + } +#endif + } else { + // Replace existing attribute. + return PyType_Type.tp_setattro(obj, name, value); + } +} + +/** + * Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing + * methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function, + * when called on a class, or a PyMethod, when called on an instance. Override that behaviour here + * to do a special case bypass for PyInstanceMethod_Types. + */ +extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) { + PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); + if (descr && PyInstanceMethod_Check(descr)) { + Py_INCREF(descr); + return descr; + } + return PyType_Type.tp_getattro(obj, name); +} + +/// metaclass `__call__` function that is used to create all pybind11 objects. 
+extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, PyObject *kwargs) { + + // use the default metaclass call to create/initialize the object + PyObject *self = PyType_Type.tp_call(type, args, kwargs); + if (self == nullptr) { + return nullptr; + } + + // Ensure that the base __init__ function(s) were called + values_and_holders vhs(self); + for (const auto &vh : vhs) { + if (!vh.holder_constructed() && !vhs.is_redundant_value_and_holder(vh)) { + PyErr_Format(PyExc_TypeError, + "%.200s.__init__() must be called when overriding __init__", + get_fully_qualified_tp_name(vh.type->type).c_str()); + Py_DECREF(self); + return nullptr; + } + } + + return self; +} + +/// Cleanup the type-info for a pybind11-registered type. +extern "C" inline void pybind11_meta_dealloc(PyObject *obj) { + with_internals([obj](internals &internals) { + auto *type = (PyTypeObject *) obj; + + // A pybind11-registered type will: + // 1) be found in internals.registered_types_py + // 2) have exactly one associated `detail::type_info` + auto found_type = internals.registered_types_py.find(type); + if (found_type != internals.registered_types_py.end() && found_type->second.size() == 1 + && found_type->second[0]->type == type) { + + auto *tinfo = found_type->second[0]; + auto tindex = std::type_index(*tinfo->cpptype); + internals.direct_conversions.erase(tindex); + + if (tinfo->module_local) { + get_local_internals().registered_types_cpp.erase(tindex); + } else { + internals.registered_types_cpp.erase(tindex); + } + internals.registered_types_py.erase(tinfo->type); + + // Actually just `std::erase_if`, but that's only available in C++20 + auto &cache = internals.inactive_override_cache; + for (auto it = cache.begin(), last = cache.end(); it != last;) { + if (it->first == (PyObject *) tinfo->type) { + it = cache.erase(it); + } else { + ++it; + } + } + + delete tinfo; + } + }); + + PyType_Type.tp_dealloc(obj); +} + +/** This metaclass is assigned by default to all pybind11 
types and is required in order + for static properties to function correctly. Users may override this using `py::metaclass`. + Return value: New reference. */ +inline PyTypeObject *make_default_metaclass() { + constexpr auto *name = "pybind11_type"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto *heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); + if (!heap_type) { + pybind11_fail("make_default_metaclass(): error allocating metaclass!"); + } + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto *type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyType_Type); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + type->tp_call = pybind11_meta_call; + + type->tp_setattro = pybind11_meta_setattro; + type->tp_getattro = pybind11_meta_getattro; + + type->tp_dealloc = pybind11_meta_dealloc; + + if (PyType_Ready(type) < 0) { + pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!"); + } + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + return type; +} + +/// For multiple inheritance types we need to recursively register/deregister base pointers for any +/// base classes with pointers that are difference from the instance value pointer so that we can +/// correctly recognize an offset base class pointer. This calls a function with any offset base +/// ptrs. 
+inline void traverse_offset_bases(void *valueptr, + const detail::type_info *tinfo, + instance *self, + bool (*f)(void * /*parentptr*/, instance * /*self*/)) { + for (handle h : reinterpret_borrow(tinfo->type->tp_bases)) { + if (auto *parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) { + for (auto &c : parent_tinfo->implicit_casts) { + if (c.first == tinfo->cpptype) { + auto *parentptr = c.second(valueptr); + if (parentptr != valueptr) { + f(parentptr, self); + } + traverse_offset_bases(parentptr, parent_tinfo, self, f); + break; + } + } + } + } +} + +inline bool register_instance_impl(void *ptr, instance *self) { + with_instance_map(ptr, [&](instance_map &instances) { instances.emplace(ptr, self); }); + return true; // unused, but gives the same signature as the deregister func +} +inline bool deregister_instance_impl(void *ptr, instance *self) { + return with_instance_map(ptr, [&](instance_map &instances) { + auto range = instances.equal_range(ptr); + for (auto it = range.first; it != range.second; ++it) { + if (self == it->second) { + instances.erase(it); + return true; + } + } + return false; + }); +} + +inline void register_instance(instance *self, void *valptr, const type_info *tinfo) { + register_instance_impl(valptr, self); + if (!tinfo->simple_ancestors) { + traverse_offset_bases(valptr, tinfo, self, register_instance_impl); + } +} + +inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) { + bool ret = deregister_instance_impl(valptr, self); + if (!tinfo->simple_ancestors) { + traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl); + } + return ret; +} + +/// Instance creation function for all pybind11 types. It allocates the internal instance layout +/// for holding C++ objects and holders. Allocation is done lazily (the first time the instance is +/// cast to a reference or pointer), and initialization is done by an `__init__` function. 
+inline PyObject *make_new_instance(PyTypeObject *type) { +#if defined(PYPY_VERSION) + // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first + // inherited object is a plain Python type (i.e. not derived from an extension type). Fix it. + ssize_t instance_size = static_cast(sizeof(instance)); + if (type->tp_basicsize < instance_size) { + type->tp_basicsize = instance_size; + } +#endif + PyObject *self = type->tp_alloc(type, 0); + auto *inst = reinterpret_cast(self); + // Allocate the value/holder internals: + inst->allocate_layout(); + + return self; +} + +/// Instance creation function for all pybind11 types. It only allocates space for the +/// C++ object, but doesn't call the constructor -- an `__init__` function must do that. +extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) { + return make_new_instance(type); +} + +/// An `__init__` function constructs the C++ object. Users should provide at least one +/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the +/// following default function will be used which simply throws an exception. 
+extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) { + PyTypeObject *type = Py_TYPE(self); + std::string msg = get_fully_qualified_tp_name(type) + ": No constructor defined!"; + set_error(PyExc_TypeError, msg.c_str()); + return -1; +} + +inline void add_patient(PyObject *nurse, PyObject *patient) { + auto *instance = reinterpret_cast(nurse); + instance->has_patients = true; + Py_INCREF(patient); + + with_internals([&](internals &internals) { internals.patients[nurse].push_back(patient); }); +} + +inline void clear_patients(PyObject *self) { + auto *instance = reinterpret_cast(self); + std::vector patients; + + with_internals([&](internals &internals) { + auto pos = internals.patients.find(self); + + if (pos == internals.patients.end()) { + pybind11_fail( + "FATAL: Internal consistency check failed: Invalid clear_patients() call."); + } + + // Clearing the patients can cause more Python code to run, which + // can invalidate the iterator. Extract the vector of patients + // from the unordered_map first. + patients = std::move(pos->second); + internals.patients.erase(pos); + }); + + instance->has_patients = false; + for (PyObject *&patient : patients) { + Py_CLEAR(patient); + } +} + +/// Clears all internal data from the instance and removes it from registered instances in +/// preparation for deallocation. +inline void clear_instance(PyObject *self) { + auto *instance = reinterpret_cast(self); + + // Deallocate any values/holders, if present: + for (auto &v_h : values_and_holders(instance)) { + if (v_h) { + + // We have to deregister before we call dealloc because, for virtual MI types, we still + // need to be able to get the parent pointers. 
+ if (v_h.instance_registered() + && !deregister_instance(instance, v_h.value_ptr(), v_h.type)) { + pybind11_fail( + "pybind11_object_dealloc(): Tried to deallocate unregistered instance!"); + } + + if (instance->owned || v_h.holder_constructed()) { + v_h.type->dealloc(v_h); + } + } + } + // Deallocate the value/holder layout internals: + instance->deallocate_layout(); + + if (instance->weakrefs) { + PyObject_ClearWeakRefs(self); + } + + PyObject **dict_ptr = _PyObject_GetDictPtr(self); + if (dict_ptr) { + Py_CLEAR(*dict_ptr); + } + + if (instance->has_patients) { + clear_patients(self); + } +} + +/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc` +/// to destroy the C++ object itself, while the rest is Python bookkeeping. +extern "C" inline void pybind11_object_dealloc(PyObject *self) { + auto *type = Py_TYPE(self); + + // If this is a GC tracked object, untrack it first + // Note that the track call is implicitly done by the + // default tp_alloc, which we never override. + if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) { + PyObject_GC_UnTrack(self); + } + + clear_instance(self); + + type->tp_free(self); + +#if PY_VERSION_HEX < 0x03080000 + // `type->tp_dealloc != pybind11_object_dealloc` means that we're being called + // as part of a derived type's dealloc, in which case we're not allowed to decref + // the type here. For cross-module compatibility, we shouldn't compare directly + // with `pybind11_object_dealloc`, but with the common one stashed in internals. + auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base; + if (type->tp_dealloc == pybind11_object_type->tp_dealloc) + Py_DECREF(type); +#else + // This was not needed before Python 3.8 (Python issue 35810) + // https://github.com/pybind/pybind11/issues/1946 + Py_DECREF(type); +#endif +} + +std::string error_string(); + +/** Create the type which can be used as a common base for all classes. 
This is + needed in order to satisfy Python's requirements for multiple inheritance. + Return value: New reference. */ +inline PyObject *make_object_base_type(PyTypeObject *metaclass) { + constexpr auto *name = "pybind11_object"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto *heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) { + pybind11_fail("make_object_base_type(): error allocating type!"); + } + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto *type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyBaseObject_Type); + type->tp_basicsize = static_cast(sizeof(instance)); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + type->tp_new = pybind11_object_new; + type->tp_init = pybind11_object_init; + type->tp_dealloc = pybind11_object_dealloc; + + /* Support weak references (needed for the keep_alive feature) */ + type->tp_weaklistoffset = offsetof(instance, weakrefs); + + if (PyType_Ready(type) < 0) { + pybind11_fail("PyType_Ready failed in make_object_base_type(): " + error_string()); + } + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + return (PyObject *) heap_type; +} + +/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`. 
+extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) { +#if PY_VERSION_HEX >= 0x030D0000 + PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); +#endif +// https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse +#if PY_VERSION_HEX >= 0x03090000 + Py_VISIT(Py_TYPE(self)); +#endif + return 0; +} + +/// dynamic_attr: Allow the GC to clear the dictionary. +extern "C" inline int pybind11_clear(PyObject *self) { +#if PY_VERSION_HEX >= 0x030D0000 + PyObject_ClearManagedDict(self); +#else + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#endif + return 0; +} + +/// Give instances of this type a `__dict__` and opt into garbage collection. +inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { + auto *type = &heap_type->ht_type; + type->tp_flags |= Py_TPFLAGS_HAVE_GC; +#if PY_VERSION_HEX < 0x030B0000 + type->tp_dictoffset = type->tp_basicsize; // place dict at the end + type->tp_basicsize += (ssize_t) sizeof(PyObject *); // and allocate enough space for it +#else + type->tp_flags |= Py_TPFLAGS_MANAGED_DICT; +#endif + type->tp_traverse = pybind11_traverse; + type->tp_clear = pybind11_clear; + + static PyGetSetDef getset[] + = {{"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + type->tp_getset = getset; +} + +/// buffer_protocol: Fill in the view as specified by flags. +extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) { + // Look for a `get_buffer` implementation in this type's info or any bases (following MRO). 
+ type_info *tinfo = nullptr; + for (auto type : reinterpret_borrow(Py_TYPE(obj)->tp_mro)) { + tinfo = get_type_info((PyTypeObject *) type.ptr()); + if (tinfo && tinfo->get_buffer) { + break; + } + } + if (view == nullptr || !tinfo || !tinfo->get_buffer) { + if (view) { + view->obj = nullptr; + } + set_error(PyExc_BufferError, "pybind11_getbuffer(): Internal error"); + return -1; + } + std::memset(view, 0, sizeof(Py_buffer)); + buffer_info *info = nullptr; + try { + info = tinfo->get_buffer(obj, tinfo->get_buffer_data); + } catch (...) { + try_translate_exceptions(); + raise_from(PyExc_BufferError, "Error getting buffer"); + return -1; + } + if (info == nullptr) { + pybind11_fail("FATAL UNEXPECTED SITUATION: tinfo->get_buffer() returned nullptr."); + } + + if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE && info->readonly) { + delete info; + // view->obj = nullptr; // Was just memset to 0, so not necessary + set_error(PyExc_BufferError, "Writable buffer requested for readonly storage"); + return -1; + } + view->obj = obj; + view->ndim = 1; + view->internal = info; + view->buf = info->ptr; + view->itemsize = info->itemsize; + view->len = view->itemsize; + for (auto s : info->shape) { + view->len *= s; + } + view->readonly = static_cast(info->readonly); + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(info->format.c_str()); + } + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + view->ndim = (int) info->ndim; + view->strides = info->strides.data(); + view->shape = info->shape.data(); + } + Py_INCREF(view->obj); + return 0; +} + +/// buffer_protocol: Release the resources of the buffer. +extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) { + delete (buffer_info *) view->internal; +} + +/// Give this type a buffer interface. 
+inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) { + heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer; + + heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer; + heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer; +} + +/** Create a brand new Python type according to the `type_record` specification. + Return value: New reference. */ +inline PyObject *make_new_python_type(const type_record &rec) { + auto name = reinterpret_steal(PYBIND11_FROM_STRING(rec.name)); + + auto qualname = name; + if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) { + qualname = reinterpret_steal( + PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr())); + } + + object module_; + if (rec.scope) { + if (hasattr(rec.scope, "__module__")) { + module_ = rec.scope.attr("__module__"); + } else if (hasattr(rec.scope, "__name__")) { + module_ = rec.scope.attr("__name__"); + } + } + + const auto *full_name = c_str( +#if !defined(PYPY_VERSION) + module_ ? str(module_).cast() + "." + rec.name : +#endif + rec.name); + + char *tp_doc = nullptr; + if (rec.doc && options::show_user_defined_docstrings()) { + /* Allocate memory for docstring (Python will free this later on) */ + size_t size = std::strlen(rec.doc) + 1; +#if PY_VERSION_HEX >= 0x030D0000 + tp_doc = (char *) PyMem_MALLOC(size); +#else + tp_doc = (char *) PyObject_MALLOC(size); +#endif + std::memcpy((void *) tp_doc, rec.doc, size); + } + + auto &internals = get_internals(); + auto bases = tuple(rec.bases); + auto *base = (bases.empty()) ? internals.instance_base : bases[0].ptr(); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto *metaclass + = rec.metaclass.ptr() ? 
(PyTypeObject *) rec.metaclass.ptr() : internals.default_metaclass; + + auto *heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) { + pybind11_fail(std::string(rec.name) + ": Unable to create type object!"); + } + + heap_type->ht_name = name.release().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = qualname.inc_ref().ptr(); +#endif + + auto *type = &heap_type->ht_type; + type->tp_name = full_name; + type->tp_doc = tp_doc; + type->tp_base = type_incref((PyTypeObject *) base); + type->tp_basicsize = static_cast(sizeof(instance)); + if (!bases.empty()) { + type->tp_bases = bases.release().ptr(); + } + + /* Don't inherit base __init__ */ + type->tp_init = pybind11_object_init; + + /* Supported protocols */ + type->tp_as_number = &heap_type->as_number; + type->tp_as_sequence = &heap_type->as_sequence; + type->tp_as_mapping = &heap_type->as_mapping; + type->tp_as_async = &heap_type->as_async; + + /* Flags */ + type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE; + if (!rec.is_final) { + type->tp_flags |= Py_TPFLAGS_BASETYPE; + } + + if (rec.dynamic_attr) { + enable_dynamic_attributes(heap_type); + } + + if (rec.buffer_protocol) { + enable_buffer_protocol(heap_type); + } + + if (rec.custom_type_setup_callback) { + rec.custom_type_setup_callback(heap_type); + } + + if (PyType_Ready(type) < 0) { + pybind11_fail(std::string(rec.name) + ": PyType_Ready failed: " + error_string()); + } + + assert(!rec.dynamic_attr || PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + + /* Register type with the parent scope */ + if (rec.scope) { + setattr(rec.scope, rec.name, (PyObject *) type); + } else { + Py_INCREF(type); // Keep it alive forever (reference leak) + } + + if (module_) { // Needed by pydoc + setattr((PyObject *) type, "__module__", module_); + } + + PYBIND11_SET_OLDPY_QUALNAME(type, qualname); + + return (PyObject *) type; +} + +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git 
a/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/common.h b/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/common.h new file mode 100644 index 0000000000000000000000000000000000000000..c51d1d60bc739bfd67efc64fb517f0f0c2d7592a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/common.h @@ -0,0 +1,1287 @@ +/* + pybind11/detail/common.h -- Basic macros + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#define PYBIND11_VERSION_MAJOR 2 +#define PYBIND11_VERSION_MINOR 13 +#define PYBIND11_VERSION_PATCH 6 + +// Similar to Python's convention: https://docs.python.org/3/c-api/apiabiversion.html +// Additional convention: 0xD = dev +#define PYBIND11_VERSION_HEX 0x020D0600 + +// Define some generic pybind11 helper macros for warning management. +// +// Note that compiler-specific push/pop pairs are baked into the +// PYBIND11_NAMESPACE_BEGIN/PYBIND11_NAMESPACE_END pair of macros. Therefore manual +// PYBIND11_WARNING_PUSH/PYBIND11_WARNING_POP are usually only needed in `#include` sections. +// +// If you find you need to suppress a warning, please try to make the suppression as local as +// possible using these macros. Please also be sure to push/pop with the pybind11 macros. Please +// only use compiler specifics if you need to check specific versions, e.g. Apple Clang vs. vanilla +// Clang. +#if defined(_MSC_VER) +# define PYBIND11_COMPILER_MSVC +# define PYBIND11_PRAGMA(...) __pragma(__VA_ARGS__) +# define PYBIND11_WARNING_PUSH PYBIND11_PRAGMA(warning(push)) +# define PYBIND11_WARNING_POP PYBIND11_PRAGMA(warning(pop)) +#elif defined(__INTEL_COMPILER) +# define PYBIND11_COMPILER_INTEL +# define PYBIND11_PRAGMA(...) 
_Pragma(#__VA_ARGS__) +# define PYBIND11_WARNING_PUSH PYBIND11_PRAGMA(warning push) +# define PYBIND11_WARNING_POP PYBIND11_PRAGMA(warning pop) +#elif defined(__clang__) +# define PYBIND11_COMPILER_CLANG +# define PYBIND11_PRAGMA(...) _Pragma(#__VA_ARGS__) +# define PYBIND11_WARNING_PUSH PYBIND11_PRAGMA(clang diagnostic push) +# define PYBIND11_WARNING_POP PYBIND11_PRAGMA(clang diagnostic pop) +#elif defined(__GNUC__) +# define PYBIND11_COMPILER_GCC +# define PYBIND11_PRAGMA(...) _Pragma(#__VA_ARGS__) +# define PYBIND11_WARNING_PUSH PYBIND11_PRAGMA(GCC diagnostic push) +# define PYBIND11_WARNING_POP PYBIND11_PRAGMA(GCC diagnostic pop) +#endif + +#ifdef PYBIND11_COMPILER_MSVC +# define PYBIND11_WARNING_DISABLE_MSVC(name) PYBIND11_PRAGMA(warning(disable : name)) +#else +# define PYBIND11_WARNING_DISABLE_MSVC(name) +#endif + +#ifdef PYBIND11_COMPILER_CLANG +# define PYBIND11_WARNING_DISABLE_CLANG(name) PYBIND11_PRAGMA(clang diagnostic ignored name) +#else +# define PYBIND11_WARNING_DISABLE_CLANG(name) +#endif + +#ifdef PYBIND11_COMPILER_GCC +# define PYBIND11_WARNING_DISABLE_GCC(name) PYBIND11_PRAGMA(GCC diagnostic ignored name) +#else +# define PYBIND11_WARNING_DISABLE_GCC(name) +#endif + +#ifdef PYBIND11_COMPILER_INTEL +# define PYBIND11_WARNING_DISABLE_INTEL(name) PYBIND11_PRAGMA(warning disable name) +#else +# define PYBIND11_WARNING_DISABLE_INTEL(name) +#endif + +#define PYBIND11_NAMESPACE_BEGIN(name) \ + namespace name { \ + PYBIND11_WARNING_PUSH + +#define PYBIND11_NAMESPACE_END(name) \ + PYBIND11_WARNING_POP \ + } + +// Robust support for some features and loading modules compiled against different pybind versions +// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute +// on the main `pybind11` namespace. 
+#if !defined(PYBIND11_NAMESPACE) +# ifdef __GNUG__ +# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden"))) +# else +# define PYBIND11_NAMESPACE pybind11 +# endif +#endif + +#if !(defined(_MSC_VER) && __cplusplus == 199711L) +# if __cplusplus >= 201402L +# define PYBIND11_CPP14 +# if __cplusplus >= 201703L +# define PYBIND11_CPP17 +# if __cplusplus >= 202002L +# define PYBIND11_CPP20 +// Please update tests/pybind11_tests.cpp `cpp_std()` when adding a macro here. +# endif +# endif +# endif +#elif defined(_MSC_VER) && __cplusplus == 199711L +// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully +// implemented). Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 +// or newer. +# if _MSVC_LANG >= 201402L +# define PYBIND11_CPP14 +# if _MSVC_LANG > 201402L +# define PYBIND11_CPP17 +# if _MSVC_LANG >= 202002L +# define PYBIND11_CPP20 +# endif +# endif +# endif +#endif + +#if defined(PYBIND11_CPP20) +# define PYBIND11_CONSTINIT constinit +# define PYBIND11_DTOR_CONSTEXPR constexpr +#else +# define PYBIND11_CONSTINIT +# define PYBIND11_DTOR_CONSTEXPR +#endif + +// Compiler version assertions +#if defined(__INTEL_COMPILER) +# if __INTEL_COMPILER < 1800 +# error pybind11 requires Intel C++ compiler v18 or newer +# elif __INTEL_COMPILER < 1900 && defined(PYBIND11_CPP14) +# error pybind11 supports only C++11 with Intel C++ compiler v18. Use v19 or newer for C++14. 
+# endif +/* The following pragma cannot be pop'ed: + https://community.intel.com/t5/Intel-C-Compiler/Inline-and-no-inline-warning/td-p/1216764 */ +# pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline" +#elif defined(__clang__) && !defined(__apple_build_version__) +# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3) +# error pybind11 requires clang 3.3 or newer +# endif +#elif defined(__clang__) +// Apple changes clang version macros to its Xcode version; the first Xcode release based on +// (upstream) clang 3.3 was Xcode 5: +# if __clang_major__ < 5 +# error pybind11 requires Xcode/clang 5.0 or newer +# endif +#elif defined(__GNUG__) +# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8) +# error pybind11 requires gcc 4.8 or newer +# endif +#elif defined(_MSC_VER) +# if _MSC_VER < 1910 +# error pybind11 2.10+ requires MSVC 2017 or newer +# endif +#endif + +#if !defined(PYBIND11_EXPORT) +# if defined(WIN32) || defined(_WIN32) +# define PYBIND11_EXPORT __declspec(dllexport) +# else +# define PYBIND11_EXPORT __attribute__((visibility("default"))) +# endif +#endif + +#if !defined(PYBIND11_EXPORT_EXCEPTION) +# if defined(__apple_build_version__) +# define PYBIND11_EXPORT_EXCEPTION PYBIND11_EXPORT +# else +# define PYBIND11_EXPORT_EXCEPTION +# endif +#endif + +// For CUDA, GCC7, GCC8: +// PYBIND11_NOINLINE_FORCED is incompatible with `-Wattributes -Werror`. +// When defining PYBIND11_NOINLINE_FORCED, it is best to also use `-Wno-attributes`. +// However, the measured shared-library size saving when using noinline are only +// 1.7% for CUDA, -0.2% for GCC7, and 0.0% for GCC8 (using -DCMAKE_BUILD_TYPE=MinSizeRel, +// the default under pybind11/tests). +#if !defined(PYBIND11_NOINLINE_FORCED) \ + && (defined(__CUDACC__) || (defined(__GNUC__) && (__GNUC__ == 7 || __GNUC__ == 8))) +# define PYBIND11_NOINLINE_DISABLED +#endif + +// The PYBIND11_NOINLINE macro is for function DEFINITIONS. 
+// In contrast, FORWARD DECLARATIONS should never use this macro: +// https://stackoverflow.com/questions/9317473/forward-declaration-of-inline-functions +#if defined(PYBIND11_NOINLINE_DISABLED) // Option for maximum portability and experimentation. +# define PYBIND11_NOINLINE inline +#elif defined(_MSC_VER) +# define PYBIND11_NOINLINE __declspec(noinline) inline +#else +# define PYBIND11_NOINLINE __attribute__((noinline)) inline +#endif + +#if defined(__MINGW32__) +// For unknown reasons all PYBIND11_DEPRECATED member trigger a warning when declared +// whether it is used or not +# define PYBIND11_DEPRECATED(reason) +#elif defined(PYBIND11_CPP14) +# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]] +#else +# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason))) +#endif + +#if defined(PYBIND11_CPP17) +# define PYBIND11_MAYBE_UNUSED [[maybe_unused]] +#elif defined(_MSC_VER) && !defined(__clang__) +# define PYBIND11_MAYBE_UNUSED +#else +# define PYBIND11_MAYBE_UNUSED __attribute__((__unused__)) +#endif + +/* Don't let Python.h #define (v)snprintf as macro because they are implemented + properly in Visual Studio since 2015. */ +#if defined(_MSC_VER) +# define HAVE_SNPRINTF 1 +#endif + +/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode +#if defined(_MSC_VER) +PYBIND11_WARNING_PUSH +PYBIND11_WARNING_DISABLE_MSVC(4505) +// C4505: 'PySlice_GetIndicesEx': unreferenced local function has been removed (PyPy only) +# if defined(_DEBUG) && !defined(Py_DEBUG) +// Workaround for a VS 2022 issue. +// NOTE: This workaround knowingly violates the Python.h include order requirement: +// https://docs.python.org/3/c-api/intro.html#include-files +// See https://github.com/pybind/pybind11/pull/3497 for full context. 
+# include +# if _MSVC_STL_VERSION >= 143 +# include +# endif +# define PYBIND11_DEBUG_MARKER +# undef _DEBUG +# endif +#endif + +// https://en.cppreference.com/w/c/chrono/localtime +#if defined(__STDC_LIB_EXT1__) && !defined(__STDC_WANT_LIB_EXT1__) +# define __STDC_WANT_LIB_EXT1__ +#endif + +#ifdef __has_include +// std::optional (but including it in c++14 mode isn't allowed) +# if defined(PYBIND11_CPP17) && __has_include() +# define PYBIND11_HAS_OPTIONAL 1 +# endif +// std::experimental::optional (but not allowed in c++11 mode) +# if defined(PYBIND11_CPP14) && (__has_include() && \ + !__has_include()) +# define PYBIND11_HAS_EXP_OPTIONAL 1 +# endif +// std::variant +# if defined(PYBIND11_CPP17) && __has_include() +# define PYBIND11_HAS_VARIANT 1 +# endif +#elif defined(_MSC_VER) && defined(PYBIND11_CPP17) +# define PYBIND11_HAS_OPTIONAL 1 +# define PYBIND11_HAS_VARIANT 1 +#endif + +#if defined(PYBIND11_CPP17) +# if defined(__has_include) +# if __has_include() +# define PYBIND11_HAS_STRING_VIEW +# endif +# elif defined(_MSC_VER) +# define PYBIND11_HAS_STRING_VIEW +# endif +#endif + +#include +#if PY_VERSION_HEX < 0x03070000 +# error "PYTHON < 3.7 IS UNSUPPORTED. pybind11 v2.12 was the last to support Python 3.6." 
+#endif +#include +#include + +/* Python #defines overrides on all sorts of core functions, which + tends to weak havok in C++ codebases that expect these to work + like regular functions (potentially with several overloads) */ +#if defined(isalnum) +# undef isalnum +# undef isalpha +# undef islower +# undef isspace +# undef isupper +# undef tolower +# undef toupper +#endif + +#if defined(copysign) +# undef copysign +#endif + +#if defined(PYBIND11_NUMPY_1_ONLY) +# define PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED +#endif + +#if defined(PYPY_VERSION) && !defined(PYBIND11_SIMPLE_GIL_MANAGEMENT) +# define PYBIND11_SIMPLE_GIL_MANAGEMENT +#endif + +#if defined(_MSC_VER) +# if defined(PYBIND11_DEBUG_MARKER) +# define _DEBUG +# undef PYBIND11_DEBUG_MARKER +# endif +PYBIND11_WARNING_POP +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__has_include) +# if __has_include() +# include +# endif +#endif + +// Must be after including or one of the other headers specified by the standard +#if defined(__cpp_lib_char8_t) && __cpp_lib_char8_t >= 201811L +# define PYBIND11_HAS_U8STRING +#endif + +// See description of PR #4246: +#if !defined(PYBIND11_NO_ASSERT_GIL_HELD_INCREF_DECREF) && !defined(NDEBUG) \ + && !defined(PYPY_VERSION) && !defined(PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF) +# define PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF +#endif + +// #define PYBIND11_STR_LEGACY_PERMISSIVE +// If DEFINED, pybind11::str can hold PyUnicodeObject or PyBytesObject +// (probably surprising and never documented, but this was the +// legacy behavior until and including v2.6.x). As a side-effect, +// pybind11::isinstance() is true for both pybind11::str and +// pybind11::bytes. +// If UNDEFINED, pybind11::str can only hold PyUnicodeObject, and +// pybind11::isinstance() is true only for pybind11::str. 
+// However, for Python 2 only (!), the pybind11::str caster +// implicitly decoded bytes to PyUnicodeObject. This was to ease +// the transition from the legacy behavior to the non-permissive +// behavior. + +/// Compatibility macros for Python 2 / Python 3 versions TODO: remove +#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr) +#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check +#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION +#define PYBIND11_BYTES_CHECK PyBytes_Check +#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString +#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize +#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize +#define PYBIND11_BYTES_AS_STRING PyBytes_AsString +#define PYBIND11_BYTES_SIZE PyBytes_Size +#define PYBIND11_LONG_CHECK(o) PyLong_Check(o) +#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o) +#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) (o)) +#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) (o)) +#define PYBIND11_BYTES_NAME "bytes" +#define PYBIND11_STRING_NAME "str" +#define PYBIND11_SLICE_OBJECT PyObject +#define PYBIND11_FROM_STRING PyUnicode_FromString +#define PYBIND11_STR_TYPE ::pybind11::str +#define PYBIND11_BOOL_ATTR "__bool__" +#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool) +#define PYBIND11_BUILTINS_MODULE "builtins" +// Providing a separate declaration to make Clang's -Wmissing-prototypes happy. +// See comment for PYBIND11_MODULE below for why this is marked "maybe unused". 
+#define PYBIND11_PLUGIN_IMPL(name) \ + extern "C" PYBIND11_MAYBE_UNUSED PYBIND11_EXPORT PyObject *PyInit_##name(); \ + extern "C" PYBIND11_EXPORT PyObject *PyInit_##name() + +#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code +#define PYBIND11_STRINGIFY(x) #x +#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x) +#define PYBIND11_CONCAT(first, second) first##second +#define PYBIND11_ENSURE_INTERNALS_READY pybind11::detail::get_internals(); + +#define PYBIND11_CHECK_PYTHON_VERSION \ + { \ + const char *compiled_ver \ + = PYBIND11_TOSTRING(PY_MAJOR_VERSION) "." PYBIND11_TOSTRING(PY_MINOR_VERSION); \ + const char *runtime_ver = Py_GetVersion(); \ + size_t len = std::strlen(compiled_ver); \ + if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \ + || (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \ + PyErr_Format(PyExc_ImportError, \ + "Python version mismatch: module was compiled for Python %s, " \ + "but the interpreter version is incompatible: %s.", \ + compiled_ver, \ + runtime_ver); \ + return nullptr; \ + } \ + } + +#define PYBIND11_CATCH_INIT_EXCEPTIONS \ + catch (pybind11::error_already_set & e) { \ + pybind11::raise_from(e, PyExc_ImportError, "initialization failed"); \ + return nullptr; \ + } \ + catch (const std::exception &e) { \ + ::pybind11::set_error(PyExc_ImportError, e.what()); \ + return nullptr; \ + } + +/** \rst + ***Deprecated in favor of PYBIND11_MODULE*** + + This macro creates the entry point that will be invoked when the Python interpreter + imports a plugin library. Please create a `module_` in the function body and return + the pointer to its underlying Python object at the end. + + .. 
code-block:: cpp + + PYBIND11_PLUGIN(example) { + pybind11::module_ m("example", "pybind11 example plugin"); + /// Set up bindings here + return m.ptr(); + } +\endrst */ +#define PYBIND11_PLUGIN(name) \ + PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \ + static PyObject *pybind11_init(); \ + PYBIND11_PLUGIN_IMPL(name) { \ + PYBIND11_CHECK_PYTHON_VERSION \ + PYBIND11_ENSURE_INTERNALS_READY \ + try { \ + return pybind11_init(); \ + } \ + PYBIND11_CATCH_INIT_EXCEPTIONS \ + } \ + PyObject *pybind11_init() + +/** \rst + This macro creates the entry point that will be invoked when the Python interpreter + imports an extension module. The module name is given as the first argument and it + should not be in quotes. The second macro argument defines a variable of type + `py::module_` which can be used to initialize the module. + + The entry point is marked as "maybe unused" to aid dead-code detection analysis: + since the entry point is typically only looked up at runtime and not referenced + during translation, it would otherwise appear as unused ("dead") code. + + .. code-block:: cpp + + PYBIND11_MODULE(example, m) { + m.doc() = "pybind11 example module"; + + // Add bindings here + m.def("foo", []() { + return "Hello, World!"; + }); + } + + The third macro argument is optional (available since 2.13.0), and can be used to + mark the extension module as safe to run without the GIL under a free-threaded CPython + interpreter. Passing this argument has no effect on other interpreters. + + .. code-block:: cpp + + PYBIND11_MODULE(example, m, py::mod_gil_not_used()) { + m.doc() = "pybind11 example module safe to run without the GIL"; + + // Add bindings here + m.def("foo", []() { + return "Hello, Free-threaded World!"; + }); + } + +\endrst */ +PYBIND11_WARNING_PUSH +PYBIND11_WARNING_DISABLE_CLANG("-Wgnu-zero-variadic-macro-arguments") +#define PYBIND11_MODULE(name, variable, ...) 
\ + static ::pybind11::module_::module_def PYBIND11_CONCAT(pybind11_module_def_, name) \ + PYBIND11_MAYBE_UNUSED; \ + PYBIND11_MAYBE_UNUSED \ + static void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ &); \ + PYBIND11_PLUGIN_IMPL(name) { \ + PYBIND11_CHECK_PYTHON_VERSION \ + PYBIND11_ENSURE_INTERNALS_READY \ + auto m = ::pybind11::module_::create_extension_module( \ + PYBIND11_TOSTRING(name), \ + nullptr, \ + &PYBIND11_CONCAT(pybind11_module_def_, name), \ + ##__VA_ARGS__); \ + try { \ + PYBIND11_CONCAT(pybind11_init_, name)(m); \ + return m.ptr(); \ + } \ + PYBIND11_CATCH_INIT_EXCEPTIONS \ + } \ + void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ & (variable)) +PYBIND11_WARNING_POP + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +using ssize_t = Py_ssize_t; +using size_t = std::size_t; + +template +inline ssize_t ssize_t_cast(const IntType &val) { + static_assert(sizeof(IntType) <= sizeof(ssize_t), "Implicit narrowing is not permitted."); + return static_cast(val); +} + +/// Approach used to cast a previously unknown C++ instance into a Python object +enum class return_value_policy : uint8_t { + /** This is the default return value policy, which falls back to the policy + return_value_policy::take_ownership when the return value is a pointer. + Otherwise, it uses return_value::move or return_value::copy for rvalue + and lvalue references, respectively. See below for a description of what + all of these different policies do. */ + automatic = 0, + + /** As above, but use policy return_value_policy::reference when the return + value is a pointer. This is the default conversion policy for function + arguments when calling Python functions manually from C++ code (i.e. via + handle::operator()). You probably won't need to use this. */ + automatic_reference, + + /** Reference an existing object (i.e. do not create a new copy) and take + ownership. 
Python will call the destructor and delete operator when the + object's reference count reaches zero. Undefined behavior ensues when + the C++ side does the same.. */ + take_ownership, + + /** Create a new copy of the returned object, which will be owned by + Python. This policy is comparably safe because the lifetimes of the two + instances are decoupled. */ + copy, + + /** Use std::move to move the return value contents into a new instance + that will be owned by Python. This policy is comparably safe because the + lifetimes of the two instances (move source and destination) are + decoupled. */ + move, + + /** Reference an existing object, but do not take ownership. The C++ side + is responsible for managing the object's lifetime and deallocating it + when it is no longer used. Warning: undefined behavior will ensue when + the C++ side deletes an object that is still referenced and used by + Python. */ + reference, + + /** This policy only applies to methods and properties. It references the + object without taking ownership similar to the above + return_value_policy::reference policy. In contrast to that policy, the + function or property's implicit this argument (called the parent) is + considered to be the owner of the return value (the child). + pybind11 then couples the lifetime of the parent to the child via a + reference relationship that ensures that the parent cannot be garbage + collected while Python is still using the child. More advanced + variations of this scheme are also possible using combinations of + return_value_policy::reference and the keep_alive call policy */ + reference_internal +}; + +PYBIND11_NAMESPACE_BEGIN(detail) + +inline static constexpr int log2(size_t n, int k = 0) { + return (n <= 1) ? k : log2(n >> 1, k + 1); +} + +// Returns the size as a multiple of sizeof(void *), rounded up. 
+inline static constexpr size_t size_in_ptrs(size_t s) { + return 1 + ((s - 1) >> log2(sizeof(void *))); +} + +/** + * The space to allocate for simple layout instance holders (see below) in multiple of the size of + * a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required + * to holder either a std::unique_ptr or std::shared_ptr (which is almost always + * sizeof(std::shared_ptr)). + */ +constexpr size_t instance_simple_holder_in_ptrs() { + static_assert(sizeof(std::shared_ptr) >= sizeof(std::unique_ptr), + "pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs"); + return size_in_ptrs(sizeof(std::shared_ptr)); +} + +// Forward declarations +struct type_info; +struct value_and_holder; + +struct nonsimple_values_and_holders { + void **values_and_holders; + uint8_t *status; +}; + +/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof') +struct instance { + PyObject_HEAD + /// Storage for pointers and holder; see simple_layout, below, for a description + union { + void *simple_value_holder[1 + instance_simple_holder_in_ptrs()]; + nonsimple_values_and_holders nonsimple; + }; + /// Weak references + PyObject *weakrefs; + /// If true, the pointer is owned which means we're free to manage it with a holder. + bool owned : 1; + /** + * An instance has two possible value/holder layouts. + * + * Simple layout (when this flag is true), means the `simple_value_holder` is set with a + * pointer and the holder object governing that pointer, i.e. [val1*][holder]. This layout is + * applied whenever there is no python-side multiple inheritance of bound C++ types *and* the + * type's holder will fit in the default space (which is large enough to hold either a + * std::unique_ptr or std::shared_ptr). 
+ * + * Non-simple layout applies when using custom holders that require more space than + * `shared_ptr` (which is typically the size of two pointers), or when multiple inheritance is + * used on the python side. Non-simple layout allocates the required amount of memory to have + * multiple bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is + * set to a pointer to allocated space of the required space to hold a sequence of value + * pointers and holders followed `status`, a set of bit flags (1 byte each), i.e. + * [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple + * of `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the beginning of + * the [bb...] block (but not independently allocated). + * + * Status bits indicate whether the associated holder is constructed (& + * status_holder_constructed) and whether the value pointer is registered (& + * status_instance_registered) in `registered_instances`. + */ + bool simple_layout : 1; + /// For simple layout, tracks whether the holder has been constructed + bool simple_holder_constructed : 1; + /// For simple layout, tracks whether the instance is registered in `registered_instances` + bool simple_instance_registered : 1; + /// If true, get_internals().patients has an entry for this object + bool has_patients : 1; + + /// Initializes all of the above type/values/holders data (but not the instance values + /// themselves) + void allocate_layout(); + + /// Destroys/deallocates all of the above + void deallocate_layout(); + + /// Returns the value_and_holder wrapper for the given type (or the first, if `find_type` + /// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if + /// `throw_if_missing` is false. 
+ value_and_holder get_value_and_holder(const type_info *find_type = nullptr, + bool throw_if_missing = true); + + /// Bit values for the non-simple status flags + static constexpr uint8_t status_holder_constructed = 1; + static constexpr uint8_t status_instance_registered = 2; +}; + +static_assert(std::is_standard_layout::value, + "Internal error: `pybind11::detail::instance` is not standard layout!"); + +/// from __cpp_future__ import (convenient aliases from C++14/17) +#if defined(PYBIND11_CPP14) +using std::conditional_t; +using std::enable_if_t; +using std::remove_cv_t; +using std::remove_reference_t; +#else +template +using enable_if_t = typename std::enable_if::type; +template +using conditional_t = typename std::conditional::type; +template +using remove_cv_t = typename std::remove_cv::type; +template +using remove_reference_t = typename std::remove_reference::type; +#endif + +#if defined(PYBIND11_CPP20) +using std::remove_cvref; +using std::remove_cvref_t; +#else +template +struct remove_cvref { + using type = remove_cv_t>; +}; +template +using remove_cvref_t = typename remove_cvref::type; +#endif + +/// Example usage: is_same_ignoring_cvref::value +template +using is_same_ignoring_cvref = std::is_same, U>; + +/// Index sequences +#if defined(PYBIND11_CPP14) +using std::index_sequence; +using std::make_index_sequence; +#else +template +struct index_sequence {}; +template +struct make_index_sequence_impl : make_index_sequence_impl {}; +template +struct make_index_sequence_impl<0, S...> { + using type = index_sequence; +}; +template +using make_index_sequence = typename make_index_sequence_impl::type; +#endif + +/// Make an index sequence of the indices of true arguments +template +struct select_indices_impl { + using type = ISeq; +}; +template +struct select_indices_impl, I, B, Bs...> + : select_indices_impl, index_sequence>, + I + 1, + Bs...> {}; +template +using select_indices = typename select_indices_impl, 0, Bs...>::type; + +/// Backports of 
std::bool_constant and std::negation to accommodate older compilers +template +using bool_constant = std::integral_constant; +template +struct negation : bool_constant {}; + +// PGI/Intel cannot detect operator delete with the "compatible" void_t impl, so +// using the new one (C++14 defect, so generally works on newer compilers, even +// if not in C++17 mode) +#if defined(__PGIC__) || defined(__INTEL_COMPILER) +template +using void_t = void; +#else +template +struct void_t_impl { + using type = void; +}; +template +using void_t = typename void_t_impl::type; +#endif + +/// Compile-time all/any/none of that check the boolean value of all template types +#if defined(__cpp_fold_expressions) && !(defined(_MSC_VER) && (_MSC_VER < 1916)) +template +using all_of = bool_constant<(Ts::value && ...)>; +template +using any_of = bool_constant<(Ts::value || ...)>; +#elif !defined(_MSC_VER) +template +struct bools {}; +template +using all_of = std::is_same, bools>; +template +using any_of = negation...>>; +#else +// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit +// at a slight loss of compilation efficiency). +template +using all_of = std::conjunction; +template +using any_of = std::disjunction; +#endif +template +using none_of = negation>; + +template class... Predicates> +using satisfies_all_of = all_of...>; +template class... Predicates> +using satisfies_any_of = any_of...>; +template class... 
Predicates> +using satisfies_none_of = none_of...>; + +/// Strip the class from a method type +template +struct remove_class {}; +template +struct remove_class { + using type = R(A...); +}; +template +struct remove_class { + using type = R(A...); +}; +#ifdef __cpp_noexcept_function_type +template +struct remove_class { + using type = R(A...); +}; +template +struct remove_class { + using type = R(A...); +}; +#endif +/// Helper template to strip away type modifiers +template +struct intrinsic_type { + using type = T; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +using intrinsic_t = typename intrinsic_type::type; + +/// Helper type to replace 'void' in some expressions +struct void_type {}; + +/// Helper template which holds a list of types +template +struct type_list {}; + +/// Compile-time integer sum +#ifdef __cpp_fold_expressions +template +constexpr size_t constexpr_sum(Ts... ns) { + return (0 + ... + size_t{ns}); +} +#else +constexpr size_t constexpr_sum() { return 0; } +template +constexpr size_t constexpr_sum(T n, Ts... ns) { + return size_t{n} + constexpr_sum(ns...); +} +#endif + +PYBIND11_NAMESPACE_BEGIN(constexpr_impl) +/// Implementation details for constexpr functions +constexpr int first(int i) { return i; } +template +constexpr int first(int i, T v, Ts... vs) { + return v ? i : first(i + 1, vs...); +} + +constexpr int last(int /*i*/, int result) { return result; } +template +constexpr int last(int i, int result, T v, Ts... vs) { + return last(i + 1, v ? 
i : result, vs...); +} +PYBIND11_NAMESPACE_END(constexpr_impl) + +/// Return the index of the first type in Ts which satisfies Predicate. +/// Returns sizeof...(Ts) if none match. +template